229 lines
5.5 KiB
Julia
Raw Normal View History

# functions that find operations on the initial graph
using Base.Threads
2023-08-25 10:48:22 +02:00
"""
    insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})

Register the node fusion `nf` on the three nodes in `nf.input`.

`nf` is pushed onto the `nodeFusions` collection of the first and third input node, each
while holding that node's lock from `locks`. The second input node's `nodeFusion` field
is set to `nf` without taking a lock.

Returns `nothing`.
"""
function insert_operation!(
    nf::NodeFusion,
    locks::Dict{ComputeTaskNode, SpinLock},
)
    first_node = nf.input[1]
    middle_node = nf.input[2]
    last_node = nf.input[3]

    # push `nf` onto `node.nodeFusions` while holding the lock registered for `node`
    function register_locked(node)
        node_lock = locks[node]
        lock(node_lock)
        try
            push!(node.nodeFusions, nf)
        finally
            unlock(node_lock)
        end
        return nothing
    end

    register_locked(first_node)
    middle_node.nodeFusion = nf
    register_locked(last_node)

    return nothing
end
"""
    insert_operation!(nr::NodeReduction)

Set the `nodeReduction` field of every node in `nr.input` to `nr`.

Returns `nothing`.
"""
function insert_operation!(nr::NodeReduction)
    foreach(nr.input) do node
        node.nodeReduction = nr
    end
    return nothing
end
"""
    insert_operation!(ns::NodeSplit)

Set the `nodeSplit` field of the node stored in `ns.input` to `ns`.

Returns `nothing`.
"""
function insert_operation!(ns::NodeSplit)
    target = ns.input
    target.nodeSplit = ns
    return nothing
end
2023-08-25 10:48:22 +02:00
"""
    nr_insertion!(
        operations::PossibleOperations,
        nodeReductions::Vector{Vector{NodeReduction}},
    )

Add every node reduction in `nodeReductions` to `operations.nodeReductions` and call
`insert_operation!` on each of them.

The merge into `operations.nodeReductions` runs in a separately scheduled task, while the
`insert_operation!` calls are made from a `@threads` loop over the inner vectors. The
task is waited on before returning `nothing`.
"""
function nr_insertion!(
    operations::PossibleOperations,
    nodeReductions::Vector{Vector{NodeReduction}},
)
    # reserve room for all incoming reductions up front
    expected = mapreduce(length, +, nodeReductions; init = 0)
    sizehint!(operations.nodeReductions, expected)

    merger = @task begin
        for batch in nodeReductions
            union!(operations.nodeReductions, Set(batch))
        end
    end
    schedule(merger)

    @threads for batch in nodeReductions
        foreach(insert_operation!, batch)
    end

    wait(merger)
    return nothing
end
2023-08-25 10:48:22 +02:00
"""
    nf_insertion!(
        graph::DAG,
        operations::PossibleOperations,
        nodeFusions::Vector{Vector{NodeFusion}},
    )

Add every node fusion in `nodeFusions` to `operations.nodeFusions` and call
`insert_operation!` on each of them.

One `SpinLock` is created per `ComputeTaskNode` in `graph.nodes`; the resulting
dictionary is handed to `insert_operation!`. The merge into `operations.nodeFusions`
runs in a separately scheduled task which is waited on before returning `nothing`.
"""
function nf_insertion!(
    graph::DAG,
    operations::PossibleOperations,
    nodeFusions::Vector{Vector{NodeFusion}},
)
    # reserve room for all incoming fusions up front
    expected = mapreduce(length, +, nodeFusions; init = 0)
    sizehint!(operations.nodeFusions, expected)

    merger = @task begin
        for batch in nodeFusions
            union!(operations.nodeFusions, Set(batch))
        end
    end
    schedule(merger)

    # one lock per compute task node, built before the threaded loop starts
    locks = Dict{ComputeTaskNode, SpinLock}()
    for candidate in graph.nodes
        candidate isa ComputeTaskNode || continue
        locks[candidate] = SpinLock()
    end

    @threads for batch in nodeFusions
        for fusion in batch
            insert_operation!(fusion, locks)
        end
    end

    wait(merger)
    return nothing
end
2023-08-25 10:48:22 +02:00
"""
    ns_insertion!(
        operations::PossibleOperations,
        nodeSplits::Vector{Vector{NodeSplit}},
    )

Add every node split in `nodeSplits` to `operations.nodeSplits` and call
`insert_operation!` on each of them.

The merge into `operations.nodeSplits` runs in a separately scheduled task, while the
`insert_operation!` calls are made from a `@threads` loop over the inner vectors. The
task is waited on before returning `nothing`.
"""
function ns_insertion!(
    operations::PossibleOperations,
    nodeSplits::Vector{Vector{NodeSplit}},
)
    # reserve room for all incoming splits up front
    expected = mapreduce(length, +, nodeSplits; init = 0)
    sizehint!(operations.nodeSplits, expected)

    merger = @task begin
        for batch in nodeSplits
            union!(operations.nodeSplits, Set(batch))
        end
    end
    schedule(merger)

    @threads for batch in nodeSplits
        foreach(insert_operation!, batch)
    end

    wait(merger)
    return nothing
end
# function to generate all possible operations on the graph
"""
    generate_options(graph::DAG)

Find the possible node reductions, node fusions and node splits on `graph` and hand them
to `nr_insertion!`, `nf_insertion!` and `ns_insertion!`, which store them in
`graph.possibleOperations`.

Steps, in order:

1. `apply_all!(graph)` is called and every node is passed to `sort_node!`.
2. Node reductions: for every node with more than one parent, the parents are inserted
   into a `NodeTrie`; each vector collected from the trie becomes a `NodeReduction`,
   unless its first element was already used for another reduction.
3. Node fusions: for every `DataTaskNode` with exactly one parent and exactly one child,
   where that child has exactly one parent, a `NodeFusion((child, node, parent))` is
   generated.
4. Node splits: a `NodeSplit(node)` is generated for every node for which
   `can_split(node)` is true.

The insertion of each kind of operation is started as its own task as soon as the search
for that kind has finished. `graph.dirtyNodes` is emptied, all three tasks are waited on,
and `nothing` is returned.

The nodes are processed in contiguous chunks and each chunk writes to the result buffers
at its own index. The buffers are deliberately not indexed by `threadid()`, because
`threadid()` may be larger than `nthreads()` and is not guaranteed to stay the same for
the duration of a loop iteration.
"""
function generate_options(graph::DAG)
    # make sure the graph is fully generated through
    apply_all!(graph)

    nodeArray = collect(graph.nodes)

    # sort all nodes
    @threads for node in nodeArray
        sort_node!(node)
    end

    # at most `nthreads()` chunks; every chunk index owns one buffer per operation kind
    chunkSize = max(1, cld(length(nodeArray), nthreads()))
    chunks = collect(Iterators.partition(nodeArray, chunkSize))

    generatedFusions = [Vector{NodeFusion}() for _ in eachindex(chunks)]
    generatedReductions = [Vector{NodeReduction}() for _ in eachindex(chunks)]
    generatedSplits = [Vector{NodeSplit}() for _ in eachindex(chunks)]

    checkedNodes = Set{Node}()
    checkedNodesLock = SpinLock()

    # --- find possible node reductions ---
    @threads for i in eachindex(chunks)
        reductions = generatedReductions[i]
        for node in chunks[i]
            # we're looking for nodes with multiple parents, those parents can then
            # potentially reduce with one another
            length(node.parents) <= 1 && continue

            # sort the parents into equivalence classes
            trie = NodeTrie()
            for candidate in node.parents
                insert!(trie, candidate)
            end

            for nrVec in collect(trie)
                # parent sets are ordered and any node can only be part of one
                # nodeReduction, so a NodeReduction is uniquely identifiable by its first
                # element; this prevents duplicate nodeReductions being generated
                isNew = lock(checkedNodesLock) do
                    if nrVec[1] in checkedNodes
                        return false
                    end
                    push!(checkedNodes, nrVec[1])
                    return true
                end
                isNew || continue

                push!(reductions, NodeReduction(nrVec))
            end
        end
    end

    # launch task for node reduction insertion
    nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
    schedule(nr_task)

    # --- find possible node fusions ---
    @threads for i in eachindex(chunks)
        fusions = generatedFusions[i]
        for node in chunks[i]
            node isa DataTaskNode || continue

            # data node can only have a single parent
            length(node.parents) == 1 || continue

            # otherwise this node is an entry node or has multiple children, which
            # should not be possible
            length(node.children) == 1 || continue

            parent_node = first(node.parents)
            child_node = first(node.children)
            length(child_node.parents) == 1 || continue

            push!(fusions, NodeFusion((child_node, node, parent_node)))
        end
    end

    # launch task for node fusion insertion
    nf_task =
        @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
    schedule(nf_task)

    # --- find possible node splits ---
    @threads for i in eachindex(chunks)
        splits = generatedSplits[i]
        for node in chunks[i]
            if can_split(node)
                push!(splits, NodeSplit(node))
            end
        end
    end

    # launch task for node split insertion
    ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
    schedule(ns_task)

    empty!(graph.dirtyNodes)

    wait(nr_task)
    wait(nf_task)
    wait(ns_task)

    return nothing
end