174 lines
5.2 KiB
Julia
174 lines
5.2 KiB
Julia
|
# functions that find operations on the initial graph
|
||
|
|
||
|
using Base.Threads
|
||
|
|
||
|
"""
    insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks::Dict{Node, SpinLock})

Record the node fusion `nf` in `operations.nodeFusions` and append it to the
`operations` collection of each of its three input nodes (`nf.input[1]` to
`nf.input[3]`). Every per-node append happens while holding that node's lock from
`locks`; the push to `operations.nodeFusions` itself is not locked.

Returns the result of the last `push!`, i.e. the third input node's `operations`.
"""
function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks::Dict{Node, SpinLock})
    push!(operations.nodeFusions, nf)

    first_node = nf.input[1]
    second_node = nf.input[2]
    third_node = nf.input[3]

    lock(locks[first_node]) do
        push!(first_node.operations, nf)
    end
    lock(locks[second_node]) do
        push!(second_node.operations, nf)
    end
    # the value of this final locked push is also the function's return value
    return lock(locks[third_node]) do
        push!(third_node.operations, nf)
    end
end
|
||
|
|
||
|
"""
    insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock})

Record the node reduction `nr` in `operations.nodeReductions` and append it to the
`operations` collection of every node in `nr.input`. Every per-node append happens
while holding that node's lock from `locks`; the push to
`operations.nodeReductions` itself is not locked.

Returns `nothing`.
"""
function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock})
    push!(operations.nodeReductions, nr)

    foreach(nr.input) do member
        lock(locks[member]) do
            push!(member.operations, nr)
        end
    end
    return nothing
end
|
||
|
|
||
|
"""
    insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock})

Record the node split `ns` in `operations.nodeSplits` and append it to the
`operations` collection of its single input node `ns.input`. The per-node append
happens while holding that node's lock from `locks`; the push to
`operations.nodeSplits` itself is not locked.

Returns the result of the locked `push!`, i.e. the input node's `operations`.
"""
function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock})
    push!(operations.nodeSplits, ns)

    target = ns.input
    return lock(locks[target]) do
        push!(target.operations, ns)
    end
end
|
||
|
|
||
|
"""
    nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock})

Insert every node reduction contained in the nested vector `nodeReductions` by
calling `insert_operation!(operations, op, locks)` on each, in order: outer vector
first, then the elements of each inner vector.

Returns `nothing`.
"""
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock})
    # flatten walks the inner vectors one after another, matching a nested loop
    for reduction in Iterators.flatten(nodeReductions)
        insert_operation!(operations, reduction, locks)
    end
    return nothing
end
|
||
|
|
||
|
"""
    nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock})

Insert every node fusion contained in the nested vector `nodeFusions` by calling
`insert_operation!(operations, op, locks)` on each, in order: outer vector first,
then the elements of each inner vector.

Returns `nothing`.
"""
function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock})
    # flatten walks the inner vectors one after another, matching a nested loop
    for fusion in Iterators.flatten(nodeFusions)
        insert_operation!(operations, fusion, locks)
    end
    return nothing
end
|
||
|
|
||
|
"""
    ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock})

Insert every node split contained in the nested vector `nodeSplits` by calling
`insert_operation!(operations, op, locks)` on each, in order: outer vector first,
then the elements of each inner vector.

Returns `nothing`.
"""
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock})
    # flatten walks the inner vectors one after another, matching a nested loop
    for split in Iterators.flatten(nodeSplits)
        insert_operation!(operations, split, locks)
    end
    return nothing
end
|
||
|
|
||
|
# function to generate all possible operations on the graph
|
||
|
"""
    generate_options(graph::DAG)

Find the node reductions, node fusions and node splits that are possible on `graph`
and insert them into `graph.possibleOperations` (which, through `insert_operation!`,
also appends them to the `operations` of every involved node).

The search runs in three `@threads` phases (reductions, fusions, splits). After each
phase a task is scheduled that inserts that phase's results; all three tasks are
waited on before returning. `graph.dirtyNodes` is emptied as part of the call.

Returns `nothing`.
"""
function generate_options(graph::DAG)
    # one lock per node, handed to the insertion functions to guard per-node pushes
    locks = Dict{Node, SpinLock}()
    for n in graph.nodes
        locks[n] = SpinLock()
    end

    # Work is divided into `nchunks` chunks and each chunk owns exactly one result
    # buffer. Buffers are indexed by chunk number and never by `threadid()`: thread
    # ids are not guaranteed to lie in `1:nthreads()` (e.g. with interactive
    # threads) and a task may migrate between threads, so `buffer[threadid()]` can
    # go out of bounds or be shared unexpectedly.
    nchunks = nthreads()
    generatedFusions = [Vector{NodeFusion}() for _ in 1:nchunks]
    generatedReductions = [Vector{NodeReduction}() for _ in 1:nchunks]
    generatedSplits = [Vector{NodeSplit}() for _ in 1:nchunks]

    # make sure the graph is fully generated through
    apply_all!(graph)

    # --- find possible node reductions ---

    # Partition the nodes into partner sets and distribute those sets round-robin
    # over the chunks. Every node ends up in at most one set, so no node reduction
    # is generated twice.
    nodePartitions = [Vector{Set{Node}}() for _ in 1:nchunks]
    nodeSet = copy(graph.nodes)

    partitionPointer = 1
    rotatePointer(i) = (i % nchunks) + 1

    while !isempty(nodeSet)
        nodes = partners(first(nodeSet))
        setdiff!(nodeSet, nodes)

        if length(nodes) == 1
            # nothing to reduce here anyways
            continue
        end

        partitionPointer = rotatePointer(partitionPointer)
        push!(nodePartitions[partitionPointer], nodes)
    end

    @threads for i in eachindex(nodePartitions)
        reductions = generatedReductions[i]

        for partners_ in nodePartitions[i]
            reductionVector = nothing

            # take one node out of the set and compare it against the remaining ones
            node = pop!(partners_)
            t = typeof(node)

            # possible reductions are with nodes that are partners, i.e. parents of children
            for partner in partners_
                if t != typeof(partner)
                    continue
                end

                if !can_reduce(node, partner)
                    continue
                end

                if reductionVector === nothing
                    # only when there's at least one reduction partner, insert the vector
                    reductionVector = Vector{Node}()
                    push!(reductionVector, node)
                end

                push!(reductionVector, partner)
            end

            if reductionVector !== nothing
                push!(reductions, NodeReduction(reductionVector))
            end
        end
    end

    # schedule a task for node reduction insertion
    nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks)
    schedule(nr_task)

    # --- find possible node fusions ---
    nodeArray = collect(graph.nodes)
    nodeCount = length(nodeArray)
    chunkSize = cld(nodeCount, nchunks)
    # index range of `nodeArray` owned by chunk `c`; empty once the array is exhausted
    chunkRange(c) = ((c - 1) * chunkSize + 1):min(c * chunkSize, nodeCount)

    @threads for c in 1:nchunks
        fusions = generatedFusions[c]

        for idx in chunkRange(c)
            node = nodeArray[idx]

            if !(node isa DataTaskNode)
                continue
            end

            if length(node.parents) != 1
                # data node can only have a single parent
                continue
            end
            parent_node = first(node.parents)

            if length(node.children) != 1
                # this node is an entry node or has multiple children which should not be possible
                continue
            end
            child_node = first(node.children)

            if length(child_node.parents) != 1
                continue
            end

            push!(fusions, NodeFusion((child_node, node, parent_node)))
        end
    end

    # schedule a task for node fusion insertion
    nf_task = @task nf_insertion!(graph.possibleOperations, generatedFusions, locks)
    schedule(nf_task)

    # --- find possible node splits ---
    @threads for c in 1:nchunks
        splits = generatedSplits[c]

        for idx in chunkRange(c)
            node = nodeArray[idx]
            if can_split(node)
                push!(splits, NodeSplit(node))
            end
        end
    end

    # schedule a task for node split insertion
    ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits, locks)
    schedule(ns_task)

    empty!(graph.dirtyNodes)

    # all insertions must have finished before the operations are considered complete
    wait(nr_task)
    wait(nf_task)
    wait(ns_task)

    return nothing
end
|