Multithreading for Node Reductions

This commit is contained in:
2023-08-21 13:29:55 +02:00
parent 2e96e6520e
commit a7fb15c95b
5 changed files with 129 additions and 56 deletions

View File

@ -9,18 +9,40 @@ function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks
lock(locks[n1]) do; push!(nf.input[1].operations, nf); end
lock(locks[n2]) do; push!(nf.input[2].operations, nf); end
lock(locks[n3]) do; push!(nf.input[3].operations, nf); end
return nothing
end
function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock})
push!(operations.nodeReductions, nr)
first = true
for n in nr.input
lock(locks[n]) do; push!(n.operations, nr); end
skip_duplicate = false
# careful here, this is a manual lock
lock(locks[n])
if first
first = false
for op in n.operations
if typeof(op) <: NodeReduction
skip_duplicate = true
break
end
end
if skip_duplicate
unlock(locks[n])
break
end
end
push!(n.operations, nr)
unlock(locks[n])
end
push!(operations.nodeReductions, nr)
return nothing
end
function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock})
push!(operations.nodeSplits, ns)
lock(locks[ns.input]) do; push!(ns.input.operations, ns); end
return nothing
end
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock})
@ -29,6 +51,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
insert_operation!(operations, op, locks)
end
end
return nothing
end
function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock})
@ -37,6 +60,7 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto
insert_operation!(operations, op, locks)
end
end
return nothing
end
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock})
@ -45,6 +69,7 @@ function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector
insert_operation!(operations, op, locks)
end
end
return nothing
end
# function to generate all possible operations on the graph
@ -61,73 +86,51 @@ function generate_options(graph::DAG)
# make sure the graph is fully generated through
apply_all!(graph)
nodeArray = collect(graph.nodes)
# sort all nodes
println("Sorting...")
@time @threads for node in nodeArray
sort_node!(node)
end
# --- find possible node reductions ---
# find some useful partition of nodes without generating duplicate node reductions
nodePartitions = [Vector{Set{Node}}() for _ in 1:nthreads()]
avgNodes = 0. # the average number of nodes across all the node partitions
nodeSet = copy(graph.nodes)
partitionPointer = 1
rotatePointer(i) = (i % nthreads()) + 1
while !isempty(nodeSet)
# cycle partition pointer to a set with fewer than average nodes
nodes = partners(first(nodeSet))
setdiff!(nodeSet, nodes)
if length(nodes) == 1
# nothing to reduce here anyways
println("Node Reductions...")
@time @threads for node in nodeArray
# we're looking for nodes with multiple parents, those parents can then potentially reduce with one another
if (length(node.parents) <= 1)
continue
end
partitionPointer = rotatePointer(partitionPointer)
candidates = parents(node)
push!(nodePartitions[partitionPointer], nodes)
avgNodes = avgNodes + length(nodes) / nthreads()
end
nodeReductions = Set{Set{Node}}()
@threads for partition in nodePartitions
for partners_ in partition
reductionVector = nothing
node = pop!(partners_)
# sort into equivalence classes
#**TODO** check that only same types can reduce
trie = NodeTrie()
t = typeof(node)
for candidate in candidates
# insert into trie
insert!(trie, candidate)
end
# possible reductions are with nodes that are partners, i.e. parents of children
for partner in partners_
if (t != typeof(partner))
continue
end
nodeReductions = collect(trie)
if !can_reduce(node, partner)
continue
end
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
push!(reductionVector, node)
end
push!(reductionVector, partner)
end
if reductionVector !== nothing
push!(generatedReductions[threadid()], NodeReduction(reductionVector))
end
for nrSet in nodeReductions
push!(generatedReductions[threadid()], NodeReduction(collect(nrSet)))
end
end
# launch thread for node reduction insertion
# removeduplicates
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks)
schedule(nr_task)
# --- find possible node fusions ---
nodeArray = collect(graph.nodes)
@threads for node in nodeArray
println("Node Fusions...")
@time @threads for node in nodeArray
if (typeof(node) <: DataTaskNode)
if length(node.parents) != 1
# data node can only have a single parent
@ -153,7 +156,8 @@ function generate_options(graph::DAG)
schedule(nf_task)
# find possible node splits
@threads for node in nodeArray
println("Node Splits...")
@time @threads for node in nodeArray
if (can_split(node))
push!(generatedSplits[threadid()], NodeSplit(node))
end
@ -165,9 +169,12 @@ function generate_options(graph::DAG)
empty!(graph.dirtyNodes)
wait(nr_task)
wait(nf_task)
wait(ns_task)
println("Waiting...")
@time begin
wait(nr_task)
wait(nf_task)
wait(ns_task)
end
return nothing
end