@@ -16,11 +16,16 @@ function apply_all!(graph::DAG)
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::Operation)
|
||||
error("Unknown operation type!")
|
||||
return error("Unknown operation type!")
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::NodeFusion)
|
||||
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
|
||||
diff = node_fusion!(
|
||||
graph,
|
||||
operation.input[1],
|
||||
operation.input[2],
|
||||
operation.input[3],
|
||||
)
|
||||
return AppliedNodeFusion(operation, diff)
|
||||
end
|
||||
|
||||
@@ -36,7 +41,7 @@ end
|
||||
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedOperation)
|
||||
error("Unknown operation type!")
|
||||
return error("Unknown operation type!")
|
||||
end
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
|
||||
@@ -74,7 +79,12 @@ function revert_diff!(graph::DAG, diff::Diff)
|
||||
end
|
||||
|
||||
# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph
|
||||
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
function node_fusion!(
|
||||
graph::DAG,
|
||||
n1::ComputeTaskNode,
|
||||
n2::DataTaskNode,
|
||||
n3::ComputeTaskNode,
|
||||
)
|
||||
# @assert is_valid_node_fusion_input(graph, n1, n2, n3)
|
||||
|
||||
# clear snapshot
|
||||
@@ -97,7 +107,8 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
|
||||
remove_node!(graph, n3)
|
||||
|
||||
# create new node with the fused compute task
|
||||
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
|
||||
new_node =
|
||||
ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
|
||||
insert_node!(graph, new_node)
|
||||
|
||||
# use a set for combined children of n1 and n3 to not get duplicates
|
||||
@@ -136,7 +147,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
n1 = nodes[1]
|
||||
n1_children = children(n1)
|
||||
|
||||
|
||||
n1_parents = Set(n1.parents)
|
||||
new_parents = Set{Node}()
|
||||
|
||||
|
@@ -3,113 +3,113 @@
|
||||
# function to find node fusions involving the given node if it's a data node
|
||||
# pushes the found fusion everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
# if there is already a fusion here, skip
|
||||
if !ismissing(node.nodeFusion)
|
||||
return nothing
|
||||
end
|
||||
# if there is already a fusion here, skip
|
||||
if !ismissing(node.nodeFusion)
|
||||
return nothing
|
||||
end
|
||||
|
||||
if length(node.parents) != 1 || length(node.children) != 1
|
||||
return nothing
|
||||
end
|
||||
if length(node.parents) != 1 || length(node.children) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
child_node = first(node.children)
|
||||
parent_node = first(node.parents)
|
||||
child_node = first(node.children)
|
||||
parent_node = first(node.parents)
|
||||
|
||||
if !(child_node in graph) || !(parent_node in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end
|
||||
if !(child_node in graph) || !(parent_node in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end
|
||||
|
||||
if length(child_node.parents) != 1
|
||||
return nothing
|
||||
end
|
||||
if length(child_node.parents) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
nf = NodeFusion((child_node, node, parent_node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(child_node.nodeFusions, nf)
|
||||
node.nodeFusion = nf
|
||||
push!(parent_node.nodeFusions, nf)
|
||||
nf = NodeFusion((child_node, node, parent_node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(child_node.nodeFusions, nf)
|
||||
node.nodeFusion = nf
|
||||
push!(parent_node.nodeFusions, nf)
|
||||
|
||||
return nothing
|
||||
return nothing
|
||||
end
|
||||
|
||||
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
# just find fusions in neighbouring DataTaskNodes
|
||||
for child in node.children
|
||||
find_fusions!(graph, child)
|
||||
end
|
||||
# just find fusions in neighbouring DataTaskNodes
|
||||
for child in node.children
|
||||
find_fusions!(graph, child)
|
||||
end
|
||||
|
||||
for parent in node.parents
|
||||
find_fusions!(graph, parent)
|
||||
end
|
||||
for parent in node.parents
|
||||
find_fusions!(graph, parent)
|
||||
end
|
||||
|
||||
return nothing
|
||||
return nothing
|
||||
end
|
||||
|
||||
function find_reductions!(graph::DAG, node::Node)
|
||||
# there can only be one reduction per node, avoid adding duplicates
|
||||
if !ismissing(node.nodeReduction)
|
||||
return nothing
|
||||
end
|
||||
# there can only be one reduction per node, avoid adding duplicates
|
||||
if !ismissing(node.nodeReduction)
|
||||
return nothing
|
||||
end
|
||||
|
||||
reductionVector = nothing
|
||||
# possible reductions are with nodes that are partners, i.e. parents of children
|
||||
partners_ = partners(node)
|
||||
delete!(partners_, node)
|
||||
for partner in partners_
|
||||
if partner ∉ graph.nodes
|
||||
error("Partner is not part of the graph")
|
||||
end
|
||||
reductionVector = nothing
|
||||
# possible reductions are with nodes that are partners, i.e. parents of children
|
||||
partners_ = partners(node)
|
||||
delete!(partners_, node)
|
||||
for partner in partners_
|
||||
if partner ∉ graph.nodes
|
||||
error("Partner is not part of the graph")
|
||||
end
|
||||
|
||||
if can_reduce(node, partner)
|
||||
if Set(node.children) != Set(partner.children)
|
||||
error("Not equal children")
|
||||
end
|
||||
if reductionVector === nothing
|
||||
# only when there's at least one reduction partner, insert the vector
|
||||
reductionVector = Vector{Node}()
|
||||
push!(reductionVector, node)
|
||||
end
|
||||
if can_reduce(node, partner)
|
||||
if Set(node.children) != Set(partner.children)
|
||||
error("Not equal children")
|
||||
end
|
||||
if reductionVector === nothing
|
||||
# only when there's at least one reduction partner, insert the vector
|
||||
reductionVector = Vector{Node}()
|
||||
push!(reductionVector, node)
|
||||
end
|
||||
|
||||
push!(reductionVector, partner)
|
||||
end
|
||||
end
|
||||
push!(reductionVector, partner)
|
||||
end
|
||||
end
|
||||
|
||||
if reductionVector !== nothing
|
||||
nr = NodeReduction(reductionVector)
|
||||
push!(graph.possibleOperations.nodeReductions, nr)
|
||||
for node in reductionVector
|
||||
if !ismissing(node.nodeReduction)
|
||||
# it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now
|
||||
# this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations
|
||||
invalidate_caches!(graph, node.nodeReduction)
|
||||
end
|
||||
node.nodeReduction = nr
|
||||
end
|
||||
end
|
||||
if reductionVector !== nothing
|
||||
nr = NodeReduction(reductionVector)
|
||||
push!(graph.possibleOperations.nodeReductions, nr)
|
||||
for node in reductionVector
|
||||
if !ismissing(node.nodeReduction)
|
||||
# it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now
|
||||
# this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations
|
||||
invalidate_caches!(graph, node.nodeReduction)
|
||||
end
|
||||
node.nodeReduction = nr
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
return nothing
|
||||
end
|
||||
|
||||
function find_splits!(graph::DAG, node::Node)
|
||||
if !ismissing(node.nodeSplit)
|
||||
return nothing
|
||||
end
|
||||
if !ismissing(node.nodeSplit)
|
||||
return nothing
|
||||
end
|
||||
|
||||
if (can_split(node))
|
||||
ns = NodeSplit(node)
|
||||
push!(graph.possibleOperations.nodeSplits, ns)
|
||||
node.nodeSplit = ns
|
||||
end
|
||||
if (can_split(node))
|
||||
ns = NodeSplit(node)
|
||||
push!(graph.possibleOperations.nodeSplits, ns)
|
||||
node.nodeSplit = ns
|
||||
end
|
||||
|
||||
return nothing
|
||||
return nothing
|
||||
end
|
||||
|
||||
# "clean" the operations on a dirty node
|
||||
function clean_node!(graph::DAG, node::Node)
|
||||
sort_node!(node)
|
||||
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
find_splits!(graph, node)
|
||||
sort_node!(node)
|
||||
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
return find_splits!(graph, node)
|
||||
end
|
||||
|
@@ -2,204 +2,227 @@
|
||||
|
||||
using Base.Threads
|
||||
|
||||
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
|
||||
n1 = nf.input[1]; n2 = nf.input[2]; n3 = nf.input[3]
|
||||
function insert_operation!(
|
||||
nf::NodeFusion,
|
||||
locks::Dict{ComputeTaskNode, SpinLock},
|
||||
)
|
||||
n1 = nf.input[1]
|
||||
n2 = nf.input[2]
|
||||
n3 = nf.input[3]
|
||||
|
||||
lock(locks[n1]) do; push!(nf.input[1].nodeFusions, nf); end
|
||||
nf.input[2].nodeFusion = nf
|
||||
lock(locks[n3]) do; push!(nf.input[3].nodeFusions, nf); end
|
||||
return nothing
|
||||
lock(locks[n1]) do
|
||||
return push!(nf.input[1].nodeFusions, nf)
|
||||
end
|
||||
n2.nodeFusion = nf
|
||||
lock(locks[n3]) do
|
||||
return push!(nf.input[3].nodeFusions, nf)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function insert_operation!(nr::NodeReduction)
|
||||
for n in nr.input
|
||||
n.nodeReduction = nr
|
||||
end
|
||||
return nothing
|
||||
for n in nr.input
|
||||
n.nodeReduction = nr
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function insert_operation!(ns::NodeSplit)
|
||||
ns.input.nodeSplit = ns
|
||||
return nothing
|
||||
ns.input.nodeSplit = ns
|
||||
return nothing
|
||||
end
|
||||
|
||||
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
|
||||
total_len = 0
|
||||
for vec in nodeReductions
|
||||
total_len += length(vec)
|
||||
end
|
||||
sizehint!(operations.nodeReductions, total_len)
|
||||
function nr_insertion!(
|
||||
operations::PossibleOperations,
|
||||
nodeReductions::Vector{Vector{NodeReduction}},
|
||||
)
|
||||
total_len = 0
|
||||
for vec in nodeReductions
|
||||
total_len += length(vec)
|
||||
end
|
||||
sizehint!(operations.nodeReductions, total_len)
|
||||
|
||||
t = @task for vec in nodeReductions
|
||||
union!(operations.nodeReductions, Set(vec))
|
||||
end
|
||||
schedule(t)
|
||||
t = @task for vec in nodeReductions
|
||||
union!(operations.nodeReductions, Set(vec))
|
||||
end
|
||||
schedule(t)
|
||||
|
||||
@threads for vec in nodeReductions
|
||||
for op in vec
|
||||
insert_operation!(op)
|
||||
end
|
||||
end
|
||||
@threads for vec in nodeReductions
|
||||
for op in vec
|
||||
insert_operation!(op)
|
||||
end
|
||||
end
|
||||
|
||||
wait(t)
|
||||
wait(t)
|
||||
|
||||
return nothing
|
||||
return nothing
|
||||
end
|
||||
|
||||
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
|
||||
total_len = 0
|
||||
for vec in nodeFusions
|
||||
total_len += length(vec)
|
||||
end
|
||||
sizehint!(operations.nodeFusions, total_len)
|
||||
|
||||
t = @task for vec in nodeFusions
|
||||
union!(operations.nodeFusions, Set(vec))
|
||||
end
|
||||
schedule(t)
|
||||
function nf_insertion!(
|
||||
graph::DAG,
|
||||
operations::PossibleOperations,
|
||||
nodeFusions::Vector{Vector{NodeFusion}},
|
||||
)
|
||||
total_len = 0
|
||||
for vec in nodeFusions
|
||||
total_len += length(vec)
|
||||
end
|
||||
sizehint!(operations.nodeFusions, total_len)
|
||||
|
||||
locks = Dict{ComputeTaskNode, SpinLock}()
|
||||
for n in graph.nodes
|
||||
if (typeof(n) <: ComputeTaskNode)
|
||||
locks[n] = SpinLock()
|
||||
end
|
||||
end
|
||||
t = @task for vec in nodeFusions
|
||||
union!(operations.nodeFusions, Set(vec))
|
||||
end
|
||||
schedule(t)
|
||||
|
||||
@threads for vec in nodeFusions
|
||||
for op in vec
|
||||
insert_operation!(op, locks)
|
||||
end
|
||||
end
|
||||
locks = Dict{ComputeTaskNode, SpinLock}()
|
||||
for n in graph.nodes
|
||||
if (typeof(n) <: ComputeTaskNode)
|
||||
locks[n] = SpinLock()
|
||||
end
|
||||
end
|
||||
|
||||
wait(t)
|
||||
@threads for vec in nodeFusions
|
||||
for op in vec
|
||||
insert_operation!(op, locks)
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
wait(t)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
|
||||
total_len = 0
|
||||
for vec in nodeSplits
|
||||
total_len += length(vec)
|
||||
end
|
||||
sizehint!(operations.nodeSplits, total_len)
|
||||
function ns_insertion!(
|
||||
operations::PossibleOperations,
|
||||
nodeSplits::Vector{Vector{NodeSplit}},
|
||||
)
|
||||
total_len = 0
|
||||
for vec in nodeSplits
|
||||
total_len += length(vec)
|
||||
end
|
||||
sizehint!(operations.nodeSplits, total_len)
|
||||
|
||||
t = @task for vec in nodeSplits
|
||||
union!(operations.nodeSplits, Set(vec))
|
||||
end
|
||||
schedule(t)
|
||||
t = @task for vec in nodeSplits
|
||||
union!(operations.nodeSplits, Set(vec))
|
||||
end
|
||||
schedule(t)
|
||||
|
||||
@threads for vec in nodeSplits
|
||||
for op in vec
|
||||
insert_operation!(op)
|
||||
end
|
||||
end
|
||||
@threads for vec in nodeSplits
|
||||
for op in vec
|
||||
insert_operation!(op)
|
||||
end
|
||||
end
|
||||
|
||||
wait(t)
|
||||
wait(t)
|
||||
|
||||
return nothing
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to generate all possible operations on the graph
|
||||
function generate_options(graph::DAG)
|
||||
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
|
||||
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
|
||||
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
|
||||
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
|
||||
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
|
||||
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
|
||||
|
||||
# make sure the graph is fully generated through
|
||||
apply_all!(graph)
|
||||
# make sure the graph is fully generated through
|
||||
apply_all!(graph)
|
||||
|
||||
nodeArray = collect(graph.nodes)
|
||||
nodeArray = collect(graph.nodes)
|
||||
|
||||
# sort all nodes
|
||||
@threads for node in nodeArray
|
||||
sort_node!(node)
|
||||
end
|
||||
# sort all nodes
|
||||
@threads for node in nodeArray
|
||||
sort_node!(node)
|
||||
end
|
||||
|
||||
checkedNodes = Set{Node}()
|
||||
checkedNodesLock = SpinLock()
|
||||
# --- find possible node reductions ---
|
||||
@threads for node in nodeArray
|
||||
# we're looking for nodes with multiple parents, those parents can then potentially reduce with one another
|
||||
if (length(node.parents) <= 1)
|
||||
continue
|
||||
end
|
||||
checkedNodes = Set{Node}()
|
||||
checkedNodesLock = SpinLock()
|
||||
# --- find possible node reductions ---
|
||||
@threads for node in nodeArray
|
||||
# we're looking for nodes with multiple parents, those parents can then potentially reduce with one another
|
||||
if (length(node.parents) <= 1)
|
||||
continue
|
||||
end
|
||||
|
||||
candidates = node.parents
|
||||
candidates = node.parents
|
||||
|
||||
# sort into equivalence classes
|
||||
trie = NodeTrie()
|
||||
# sort into equivalence classes
|
||||
trie = NodeTrie()
|
||||
|
||||
for candidate in candidates
|
||||
# insert into trie
|
||||
insert!(trie, candidate)
|
||||
end
|
||||
for candidate in candidates
|
||||
# insert into trie
|
||||
insert!(trie, candidate)
|
||||
end
|
||||
|
||||
nodeReductions = collect(trie)
|
||||
nodeReductions = collect(trie)
|
||||
|
||||
for nrVec in nodeReductions
|
||||
# parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is uniquely identifiable by its first element
|
||||
# this prevents duplicate nodeReductions being generated
|
||||
lock(checkedNodesLock)
|
||||
if (nrVec[1] in checkedNodes)
|
||||
for nrVec in nodeReductions
|
||||
# parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is uniquely identifiable by its first element
|
||||
# this prevents duplicate nodeReductions being generated
|
||||
lock(checkedNodesLock)
|
||||
if (nrVec[1] in checkedNodes)
|
||||
unlock(checkedNodesLock)
|
||||
continue
|
||||
else
|
||||
push!(checkedNodes, nrVec[1])
|
||||
end
|
||||
unlock(checkedNodesLock)
|
||||
continue
|
||||
else
|
||||
push!(checkedNodes, nrVec[1])
|
||||
end
|
||||
unlock(checkedNodesLock)
|
||||
|
||||
push!(generatedReductions[threadid()], NodeReduction(nrVec))
|
||||
end
|
||||
end
|
||||
push!(generatedReductions[threadid()], NodeReduction(nrVec))
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
# launch thread for node reduction insertion
|
||||
# remove duplicates
|
||||
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
|
||||
schedule(nr_task)
|
||||
|
||||
# --- find possible node fusions ---
|
||||
@threads for node in nodeArray
|
||||
if (typeof(node) <: DataTaskNode)
|
||||
if length(node.parents) != 1
|
||||
# data node can only have a single parent
|
||||
continue
|
||||
end
|
||||
parent_node = first(node.parents)
|
||||
# launch thread for node reduction insertion
|
||||
# remove duplicates
|
||||
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
|
||||
schedule(nr_task)
|
||||
|
||||
if length(node.children) != 1
|
||||
# this node is an entry node or has multiple children which should not be possible
|
||||
continue
|
||||
end
|
||||
child_node = first(node.children)
|
||||
if (length(child_node.parents) != 1)
|
||||
continue
|
||||
end
|
||||
# --- find possible node fusions ---
|
||||
@threads for node in nodeArray
|
||||
if (typeof(node) <: DataTaskNode)
|
||||
if length(node.parents) != 1
|
||||
# data node can only have a single parent
|
||||
continue
|
||||
end
|
||||
parent_node = first(node.parents)
|
||||
|
||||
push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
|
||||
end
|
||||
end
|
||||
if length(node.children) != 1
|
||||
# this node is an entry node or has multiple children which should not be possible
|
||||
continue
|
||||
end
|
||||
child_node = first(node.children)
|
||||
if (length(child_node.parents) != 1)
|
||||
continue
|
||||
end
|
||||
|
||||
# launch thread for node fusion insertion
|
||||
nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
|
||||
schedule(nf_task)
|
||||
push!(
|
||||
generatedFusions[threadid()],
|
||||
NodeFusion((child_node, node, parent_node)),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
# find possible node splits
|
||||
@threads for node in nodeArray
|
||||
if (can_split(node))
|
||||
push!(generatedSplits[threadid()], NodeSplit(node))
|
||||
end
|
||||
end
|
||||
# launch thread for node fusion insertion
|
||||
nf_task =
|
||||
@task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
|
||||
schedule(nf_task)
|
||||
|
||||
# launch thread for node split insertion
|
||||
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
|
||||
schedule(ns_task)
|
||||
# find possible node splits
|
||||
@threads for node in nodeArray
|
||||
if (can_split(node))
|
||||
push!(generatedSplits[threadid()], NodeSplit(node))
|
||||
end
|
||||
end
|
||||
|
||||
empty!(graph.dirtyNodes)
|
||||
# launch thread for node split insertion
|
||||
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
|
||||
schedule(ns_task)
|
||||
|
||||
wait(nr_task)
|
||||
wait(nf_task)
|
||||
wait(ns_task)
|
||||
empty!(graph.dirtyNodes)
|
||||
|
||||
return nothing
|
||||
wait(nr_task)
|
||||
wait(nf_task)
|
||||
wait(ns_task)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
@@ -3,16 +3,16 @@
|
||||
using Base.Threads
|
||||
|
||||
function get_operations(graph::DAG)
|
||||
apply_all!(graph)
|
||||
apply_all!(graph)
|
||||
|
||||
if isempty(graph.possibleOperations)
|
||||
generate_options(graph)
|
||||
end
|
||||
if isempty(graph.possibleOperations)
|
||||
generate_options(graph)
|
||||
end
|
||||
|
||||
for node in graph.dirtyNodes
|
||||
clean_node!(graph, node)
|
||||
end
|
||||
empty!(graph.dirtyNodes)
|
||||
for node in graph.dirtyNodes
|
||||
clean_node!(graph, node)
|
||||
end
|
||||
empty!(graph.dirtyNodes)
|
||||
|
||||
return graph.possibleOperations
|
||||
return graph.possibleOperations
|
||||
end
|
||||
|
@@ -20,12 +20,12 @@ function show(io::IO, op::NodeReduction)
|
||||
print(io, "NR: ")
|
||||
print(io, length(op.input))
|
||||
print(io, "x")
|
||||
print(io, op.input[1].task)
|
||||
return print(io, op.input[1].task)
|
||||
end
|
||||
|
||||
function show(io::IO, op::NodeSplit)
|
||||
print(io, "NS: ")
|
||||
print(io, op.input.task)
|
||||
return print(io, op.input.task)
|
||||
end
|
||||
|
||||
function show(io::IO, op::NodeFusion)
|
||||
@@ -34,5 +34,5 @@ function show(io::IO, op::NodeFusion)
|
||||
print(io, "->")
|
||||
print(io, op.input[2].task)
|
||||
print(io, "->")
|
||||
print(io, op.input[3].task)
|
||||
return print(io, op.input[3].task)
|
||||
end
|
||||
|
@@ -7,28 +7,28 @@ abstract type Operation end
|
||||
abstract type AppliedOperation end
|
||||
|
||||
struct NodeFusion <: Operation
|
||||
input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
|
||||
input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
|
||||
end
|
||||
|
||||
struct AppliedNodeFusion <: AppliedOperation
|
||||
operation::NodeFusion
|
||||
diff::Diff
|
||||
operation::NodeFusion
|
||||
diff::Diff
|
||||
end
|
||||
|
||||
struct NodeReduction <: Operation
|
||||
input::Vector{Node}
|
||||
input::Vector{Node}
|
||||
end
|
||||
|
||||
struct AppliedNodeReduction <: AppliedOperation
|
||||
operation::NodeReduction
|
||||
diff::Diff
|
||||
operation::NodeReduction
|
||||
diff::Diff
|
||||
end
|
||||
|
||||
struct NodeSplit <: Operation
|
||||
input::Node
|
||||
input::Node
|
||||
end
|
||||
|
||||
struct AppliedNodeSplit <: AppliedOperation
|
||||
operation::NodeSplit
|
||||
diff::Diff
|
||||
operation::NodeSplit
|
||||
diff::Diff
|
||||
end
|
||||
|
@@ -1,107 +1,111 @@
|
||||
|
||||
function isempty(operations::PossibleOperations)
|
||||
return isempty(operations.nodeFusions) &&
|
||||
isempty(operations.nodeReductions) &&
|
||||
isempty(operations.nodeSplits)
|
||||
return isempty(operations.nodeFusions) &&
|
||||
isempty(operations.nodeReductions) &&
|
||||
isempty(operations.nodeSplits)
|
||||
end
|
||||
|
||||
function length(operations::PossibleOperations)
|
||||
return (nodeFusions = length(operations.nodeFusions),
|
||||
nodeReductions = length(operations.nodeReductions),
|
||||
nodeSplits = length(operations.nodeSplits))
|
||||
return (
|
||||
nodeFusions = length(operations.nodeFusions),
|
||||
nodeReductions = length(operations.nodeReductions),
|
||||
nodeSplits = length(operations.nodeSplits),
|
||||
)
|
||||
end
|
||||
|
||||
function delete!(operations::PossibleOperations, op::NodeFusion)
|
||||
delete!(operations.nodeFusions, op)
|
||||
return operations
|
||||
delete!(operations.nodeFusions, op)
|
||||
return operations
|
||||
end
|
||||
|
||||
function delete!(operations::PossibleOperations, op::NodeReduction)
|
||||
delete!(operations.nodeReductions, op)
|
||||
return operations
|
||||
delete!(operations.nodeReductions, op)
|
||||
return operations
|
||||
end
|
||||
|
||||
function delete!(operations::PossibleOperations, op::NodeSplit)
|
||||
delete!(operations.nodeSplits, op)
|
||||
return operations
|
||||
delete!(operations.nodeSplits, op)
|
||||
return operations
|
||||
end
|
||||
|
||||
|
||||
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
if !is_child(n1, n2) || !is_child(n2, n3)
|
||||
# the checks are redundant but maybe a good sanity check
|
||||
return false
|
||||
end
|
||||
if !is_child(n1, n2) || !is_child(n2, n3)
|
||||
# the checks are redundant but maybe a good sanity check
|
||||
return false
|
||||
end
|
||||
|
||||
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
|
||||
return false
|
||||
end
|
||||
if length(n2.parents) != 1 ||
|
||||
length(n2.children) != 1 ||
|
||||
length(n1.parents) != 1
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
return true
|
||||
end
|
||||
|
||||
function can_reduce(n1::Node, n2::Node)
|
||||
if (n1.task != n2.task)
|
||||
return false
|
||||
end
|
||||
|
||||
n1_length = length(n1.children)
|
||||
n2_length = length(n2.children)
|
||||
if (n1.task != n2.task)
|
||||
return false
|
||||
end
|
||||
|
||||
if (n1_length != n2_length)
|
||||
return false
|
||||
end
|
||||
n1_length = length(n1.children)
|
||||
n2_length = length(n2.children)
|
||||
|
||||
# this seems to be the most common case so do this first
|
||||
# doing it manually is a lot faster than using the sets for a general solution
|
||||
if (n1_length == 2)
|
||||
if (n1.children[1] != n2.children[1])
|
||||
if (n1.children[1] != n2.children[2])
|
||||
if (n1_length != n2_length)
|
||||
return false
|
||||
end
|
||||
|
||||
# this seems to be the most common case so do this first
|
||||
# doing it manually is a lot faster than using the sets for a general solution
|
||||
if (n1_length == 2)
|
||||
if (n1.children[1] != n2.children[1])
|
||||
if (n1.children[1] != n2.children[2])
|
||||
return false
|
||||
end
|
||||
# 1_1 == 2_2
|
||||
if (n1.children[2] != n2.children[1])
|
||||
return false
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
# 1_1 == 2_1
|
||||
if (n1.children[2] != n2.children[2])
|
||||
return false
|
||||
end
|
||||
# 1_1 == 2_2
|
||||
if (n1.children[2] != n2.children[1])
|
||||
return false
|
||||
end
|
||||
return true
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
# 1_1 == 2_1
|
||||
if (n1.children[2] != n2.children[2])
|
||||
return false
|
||||
end
|
||||
return true
|
||||
end
|
||||
# this is simple
|
||||
if (n1_length == 1)
|
||||
return n1.children[1] == n2.children[1]
|
||||
end
|
||||
|
||||
# this is simple
|
||||
if (n1_length == 1)
|
||||
return n1.children[1] == n2.children[1]
|
||||
end
|
||||
|
||||
# this takes a long time
|
||||
return Set(n1.children) == Set(n2.children)
|
||||
# this takes a long time
|
||||
return Set(n1.children) == Set(n2.children)
|
||||
end
|
||||
|
||||
function can_split(n::Node)
|
||||
return length(parents(n)) > 1
|
||||
return length(parents(n)) > 1
|
||||
end
|
||||
|
||||
function ==(op1::Operation, op2::Operation)
|
||||
return false
|
||||
return false
|
||||
end
|
||||
|
||||
function ==(op1::NodeFusion, op2::NodeFusion)
|
||||
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
|
||||
return op1.input[2] == op2.input[2]
|
||||
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
|
||||
return op1.input[2] == op2.input[2]
|
||||
end
|
||||
|
||||
function ==(op1::NodeReduction, op2::NodeReduction)
|
||||
# node reductions are equal exactly if their first input is the same
|
||||
return op1.input[1].id == op2.input[1].id
|
||||
# node reductions are equal exactly if their first input is the same
|
||||
return op1.input[1].id == op2.input[1].id
|
||||
end
|
||||
|
||||
function ==(op1::NodeSplit, op2::NodeSplit)
|
||||
return op1.input == op2.input
|
||||
return op1.input == op2.input
|
||||
end
|
||||
|
||||
copy(id::UUID) = UUID(id.value)
|
||||
|
@@ -2,23 +2,51 @@
|
||||
# should be called with @assert
|
||||
# the functions throw their own errors though, to still have helpful error messages
|
||||
|
||||
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
function is_valid_node_fusion_input(
|
||||
graph::DAG,
|
||||
n1::ComputeTaskNode,
|
||||
n2::DataTaskNode,
|
||||
n3::ComputeTaskNode,
|
||||
)
|
||||
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
|
||||
throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Fusion] The given nodes are not part of the given graph",
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
|
||||
throw(AssertionError("[Node Fusion] The given nodes are not connected by edges which is required for node fusion"))
|
||||
if !is_child(n1, n2) ||
|
||||
!is_child(n2, n3) ||
|
||||
!is_parent(n3, n2) ||
|
||||
!is_parent(n2, n1)
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
if length(n2.parents) > 1
|
||||
throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Fusion] The given data node has more than one parent",
|
||||
),
|
||||
)
|
||||
end
|
||||
if length(n2.children) > 1
|
||||
throw(AssertionError("[Node Fusion] The given data node has more than one child"))
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Fusion] The given data node has more than one child",
|
||||
),
|
||||
)
|
||||
end
|
||||
if length(n1.parents) > 1
|
||||
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Fusion] The given n1 has more than one parent",
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
return true
|
||||
@@ -27,21 +55,33 @@ end
|
||||
function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
|
||||
for n in nodes
|
||||
if n ∉ graph
|
||||
throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Reduction] The given nodes are not part of the given graph",
|
||||
),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
t = typeof(nodes[1].task)
|
||||
for n in nodes
|
||||
if typeof(n.task) != t
|
||||
throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Reduction] The given nodes are not of the same type",
|
||||
),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
n1_children = nodes[1].children
|
||||
for n in nodes
|
||||
if Set(n1_children) != Set(n.children)
|
||||
throw(AssertionError("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction"))
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction",
|
||||
),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -50,11 +90,19 @@ end
|
||||
|
||||
function is_valid_node_split_input(graph::DAG, n1::Node)
|
||||
if n1 ∉ graph
|
||||
throw(AssertionError("[Node Split] The given node is not part of the given graph"))
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Split] The given node is not part of the given graph",
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
if length(n1.parents) <= 1
|
||||
throw(AssertionError("[Node Split] The given node does not have multiple parents which is required for node split"))
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Split] The given node does not have multiple parents which is required for node split",
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
return true
|
||||
@@ -73,7 +121,12 @@ function is_valid(graph::DAG, ns::NodeSplit)
|
||||
end
|
||||
|
||||
function is_valid(graph::DAG, nf::NodeFusion)
|
||||
@assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
|
||||
@assert is_valid_node_fusion_input(
|
||||
graph,
|
||||
nf.input[1],
|
||||
nf.input[2],
|
||||
nf.input[3],
|
||||
)
|
||||
@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
|
||||
return true
|
||||
end
|
||||
|
Reference in New Issue
Block a user