Multithreading for Node Reductions
parent 2e96e6520e
commit a7fb15c95b
@@ -4,6 +4,7 @@ authors = ["Anton Reinhard <anton.reinhard@proton.me>"]
version = "0.1.0"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -17,6 +17,8 @@ import Base.in
import Base.copy
import Base.isempty
import Base.delete!
import Base.insert!
import Base.collect


include("tasks.jl")
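Presumably these two imports are added so that the new NodeTrie type further down can extend Base.insert! and Base.collect with its own methods. A minimal, self-contained illustration of that mechanism; the IntBag type below is made up for this sketch and is not part of the commit:

import Base.insert!
import Base.collect

struct IntBag            # hypothetical example type
    items::Vector{Int}
end

# once the Base functions are imported, new methods can be attached to them
insert!(b::IntBag, x::Int) = (push!(b.items, x); b)
collect(b::IntBag) = copy(b.items)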
@@ -54,6 +54,16 @@ function partners(node::Node)
    return result
end

# alternative version to partners(Node), avoiding allocation of a new set
# works on the given set and returns nothing
function partners(node::Node, set::Set{Node})
    push!(set, node)
    for child in node.children
        union!(set, child.parents)
    end
    return nothing
end

is_entry_node(node::Node) = length(node.children) == 0
is_exit_node(node::Node) = length(node.parents) == 0
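One way to take advantage of the allocation-free variant is to keep a single scratch set alive across many lookups. The loop below is only an illustrative sketch; the names scratch and nodes_of_interest are not from the commit:

scratch = Set{Node}()
for node in nodes_of_interest    # any collection of Node
    empty!(scratch)              # reset contents but keep the allocated storage
    partners(node, scratch)      # fills scratch with the node and its partners
    # ... inspect scratch here before the next iteration ...
end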
@@ -9,18 +9,40 @@ function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks
    lock(locks[n1]) do; push!(nf.input[1].operations, nf); end
    lock(locks[n2]) do; push!(nf.input[2].operations, nf); end
    lock(locks[n3]) do; push!(nf.input[3].operations, nf); end
    return nothing
end

function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock})
    push!(operations.nodeReductions, nr)
    first = true
    for n in nr.input
        lock(locks[n]) do; push!(n.operations, nr); end
        skip_duplicate = false
        # careful here, this is a manual lock
        lock(locks[n])
        if first
            first = false
            for op in n.operations
                if typeof(op) <: NodeReduction
                    skip_duplicate = true
                    break
                end
            end
            if skip_duplicate
                unlock(locks[n])
                break
            end
        end

        push!(n.operations, nr)
        unlock(locks[n])
    end
    push!(operations.nodeReductions, nr)
    return nothing
end
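For reference, a minimal sketch of the two locking styles used above, assuming a Base.Threads.SpinLock: the do-block form releases the lock automatically, while the manual form has to release it on every exit path (as the early break above does explicitly, or via try/finally as sketched here):

using Base.Threads: SpinLock

l = SpinLock()

# do-block form: the lock is released even if the body throws
lock(l) do
    # ... critical section ...
end

# manual form: pair lock/unlock explicitly
lock(l)
try
    # ... critical section ...
finally
    unlock(l)
end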

function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock})
    push!(operations.nodeSplits, ns)
    lock(locks[ns.input]) do; push!(ns.input.operations, ns); end
    return nothing
end

function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock})
@@ -29,6 +51,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
            insert_operation!(operations, op, locks)
        end
    end
    return nothing
end

function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock})
@@ -37,6 +60,7 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto
            insert_operation!(operations, op, locks)
        end
    end
    return nothing
end

function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock})
@@ -45,6 +69,7 @@ function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector
            insert_operation!(operations, op, locks)
        end
    end
    return nothing
end

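All of these helpers assume one lock per node. The diff does not show how that table is built, so the following construction is only an assumption of what it might look like:

using Base.Threads: SpinLock

# hypothetical setup: one SpinLock per node of the graph
locks = Dict{Node, SpinLock}()
for n in graph.nodes
    locks[n] = SpinLock()
end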
# function to generate all possible operations on the graph
@@ -61,73 +86,51 @@ function generate_options(graph::DAG)
    # make sure the graph is fully generated through
    apply_all!(graph)

    nodeArray = collect(graph.nodes)

    # sort all nodes
    println("Sorting...")
    @time @threads for node in nodeArray
        sort_node!(node)
    end

    # --- find possible node reductions ---

    # find some useful partition of nodes without generating duplicate node reductions
    nodePartitions = [Vector{Set{Node}}() for _ in 1:nthreads()]
    avgNodes = 0. # the average number of nodes across all the node partitions
    nodeSet = copy(graph.nodes)

    partitionPointer = 1
    rotatePointer(i) = (i % nthreads()) + 1

    while !isempty(nodeSet)
        # cycle partition pointer to a set with fewer than average nodes
        nodes = partners(first(nodeSet))
        setdiff!(nodeSet, nodes)

        if length(nodes) == 1
            # nothing to reduce here anyways
    println("Node Reductions...")
    @time @threads for node in nodeArray
        # we're looking for nodes with multiple parents, those parents can then potentially reduce with one another
        if (length(node.parents) <= 1)
            continue
        end

        partitionPointer = rotatePointer(partitionPointer)
        candidates = parents(node)

        push!(nodePartitions[partitionPointer], nodes)
        avgNodes = avgNodes + length(nodes) / nthreads()
    end
        nodeReductions = Set{Set{Node}}()

    @threads for partition in nodePartitions
        for partners_ in partition
            reductionVector = nothing

            node = pop!(partners_)
        # sort into equivalence classes
        #**TODO** check that only same types can reduce
        trie = NodeTrie()

            t = typeof(node)
        for candidate in candidates
            # insert into trie
            insert!(trie, candidate)
        end

            # possible reductions are with nodes that are partners, i.e. parents of children
            for partner in partners_
                if (t != typeof(partner))
                    continue
                end
        nodeReductions = collect(trie)

                if !can_reduce(node, partner)
                    continue
                end

                if reductionVector === nothing
                    # only when there's at least one reduction partner, insert the vector
                    reductionVector = Vector{Node}()
                    push!(reductionVector, node)
                end

                push!(reductionVector, partner)
            end

            if reductionVector !== nothing
                push!(generatedReductions[threadid()], NodeReduction(reductionVector))
            end
        for nrSet in nodeReductions
            push!(generatedReductions[threadid()], NodeReduction(collect(nrSet)))
        end
    end


    # launch thread for node reduction insertion
    # remove duplicates
    nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks)
    schedule(nr_task)

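The per-thread collection pattern used above, condensed into a sketch (assuming Base.Threads is in scope and the per-thread vectors are merged later by the insertion task):

using Base.Threads: @threads, nthreads, threadid

# one private result vector per thread, so no locking is needed while generating
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]

@threads for node in nodeArray
    # ... build a NodeReduction nr for this node, if there is one ...
    # push!(generatedReductions[threadid()], nr)
end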
    # --- find possible node fusions ---
    nodeArray = collect(graph.nodes)

    @threads for node in nodeArray
    println("Node Fusions...")
    @time @threads for node in nodeArray
        if (typeof(node) <: DataTaskNode)
            if length(node.parents) != 1
                # data node can only have a single parent
@@ -153,7 +156,8 @@ function generate_options(graph::DAG)
    schedule(nf_task)

    # find possible node splits
    @threads for node in nodeArray
    println("Node Splits...")
    @time @threads for node in nodeArray
        if (can_split(node))
            push!(generatedSplits[threadid()], NodeSplit(node))
        end
@@ -165,9 +169,12 @@ function generate_options(graph::DAG)

    empty!(graph.dirtyNodes)

    wait(nr_task)
    wait(nf_task)
    wait(ns_task)
    println("Waiting...")
    @time begin
        wait(nr_task)
        wait(nf_task)
        wait(ns_task)
    end

    return nothing
end
@@ -7,3 +7,56 @@ function bytes_to_human_readable(bytes::Int64)
    end
    return string(round(bytes, sigdigits=4), " ", units[unit_index])
end

# Trie data structure for node reduction, inserts nodes by children
# Assumes that given nodes have ordered sets of children (see sort_node)
mutable struct NodeTrie
    value::Set{Node}
    children::Dict{UUID, NodeTrie}
end

NodeTrie() = NodeTrie(Set{Node}(), Dict{UUID, NodeTrie}())

function insert_helper!(trie::NodeTrie, node::Node, depth::Int)
    if (length(node.children) == depth)
        push!(trie.value, node)
        return nothing
    end

    depth = depth + 1
    id = node.children[depth].id
    if (!haskey(trie.children, id))
        trie.children[id] = NodeTrie()
    end
    insert_helper!(trie.children[id], node, depth)
end

function insert!(trie::NodeTrie, node::Node)
    insert_helper!(trie, node, 0)
end

function collect_helper(trie::NodeTrie, acc::Set{Set{Node}})
    if (length(trie.value) >= 2)
        push!(acc, trie.value)
    end

    for (id, child) in trie.children
        collect_helper(child, acc)
    end
    return nothing
end

# returns all sets of multiple nodes that have accumulated in leaves
function collect(trie::NodeTrie)
    acc = Set{Set{Node}}()
    collect_helper(trie, acc)
    return acc
end
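A possible usage sketch for the trie; candidates stands in for any collection of nodes whose children have already been ordered with sort_node!:

trie = NodeTrie()
for n in candidates
    insert!(trie, n)
end

# every returned set holds at least two nodes with identical child lists,
# i.e. one candidate node reduction
for nodeSet in collect(trie)
    # e.g. NodeReduction(collect(nodeSet))
end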

function lt_nodes(n1::Node, n2::Node)
    return n1.id < n2.id
end

function sort_node!(node::Node)
    sort!(node.children, lt=lt_nodes)
end
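lt_nodes can also be handed to other Base sorting calls; an illustrative one-liner that is not part of this commit:

nodesById = sort(collect(graph.nodes), lt=lt_nodes)    # deterministic node order by UUID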