Add formatter
Some checks failed
Test / test (push) Has been cancelled

This commit is contained in:
2023-08-25 10:48:22 +02:00
parent dbcd569967
commit ae1345d547
44 changed files with 1191 additions and 865 deletions

View File

@@ -16,11 +16,16 @@ function apply_all!(graph::DAG)
end
function apply_operation!(graph::DAG, operation::Operation)
error("Unknown operation type!")
return error("Unknown operation type!")
end
function apply_operation!(graph::DAG, operation::NodeFusion)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
diff = node_fusion!(
graph,
operation.input[1],
operation.input[2],
operation.input[3],
)
return AppliedNodeFusion(operation, diff)
end
@@ -36,7 +41,7 @@ end
function revert_operation!(graph::DAG, operation::AppliedOperation)
error("Unknown operation type!")
return error("Unknown operation type!")
end
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
@@ -74,7 +79,12 @@ function revert_diff!(graph::DAG, diff::Diff)
end
# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
function node_fusion!(
graph::DAG,
n1::ComputeTaskNode,
n2::DataTaskNode,
n3::ComputeTaskNode,
)
# @assert is_valid_node_fusion_input(graph, n1, n2, n3)
# clear snapshot
@@ -97,7 +107,8 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
remove_node!(graph, n3)
# create new node with the fused compute task
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
new_node =
ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
insert_node!(graph, new_node)
# use a set for combined children of n1 and n3 to not get duplicates
@@ -136,7 +147,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
n1 = nodes[1]
n1_children = children(n1)
n1_parents = Set(n1.parents)
new_parents = Set{Node}()

View File

@@ -3,113 +3,113 @@
# function to find node fusions involving the given node if it's a data node
# pushes the found fusion everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::DataTaskNode)
# if there is already a fusion here, skip
if !ismissing(node.nodeFusion)
return nothing
end
# if there is already a fusion here, skip
if !ismissing(node.nodeFusion)
return nothing
end
if length(node.parents) != 1 || length(node.children) != 1
return nothing
end
if length(node.parents) != 1 || length(node.children) != 1
return nothing
end
child_node = first(node.children)
parent_node = first(node.parents)
child_node = first(node.children)
parent_node = first(node.parents)
if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end
if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end
if length(child_node.parents) != 1
return nothing
end
if length(child_node.parents) != 1
return nothing
end
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.nodeFusions, nf)
node.nodeFusion = nf
push!(parent_node.nodeFusions, nf)
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.nodeFusions, nf)
node.nodeFusion = nf
push!(parent_node.nodeFusions, nf)
return nothing
return nothing
end
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# just find fusions in neighbouring DataTaskNodes
for child in node.children
find_fusions!(graph, child)
end
# just find fusions in neighbouring DataTaskNodes
for child in node.children
find_fusions!(graph, child)
end
for parent in node.parents
find_fusions!(graph, parent)
end
for parent in node.parents
find_fusions!(graph, parent)
end
return nothing
return nothing
end
function find_reductions!(graph::DAG, node::Node)
# there can only be one reduction per node, avoid adding duplicates
if !ismissing(node.nodeReduction)
return nothing
end
# there can only be one reduction per node, avoid adding duplicates
if !ismissing(node.nodeReduction)
return nothing
end
reductionVector = nothing
# possible reductions are with nodes that are partners, i.e. parents of children
partners_ = partners(node)
delete!(partners_, node)
for partner in partners_
if partner graph.nodes
error("Partner is not part of the graph")
end
reductionVector = nothing
# possible reductions are with nodes that are partners, i.e. parents of children
partners_ = partners(node)
delete!(partners_, node)
for partner in partners_
if partner graph.nodes
error("Partner is not part of the graph")
end
if can_reduce(node, partner)
if Set(node.children) != Set(partner.children)
error("Not equal children")
end
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
push!(reductionVector, node)
end
if can_reduce(node, partner)
if Set(node.children) != Set(partner.children)
error("Not equal children")
end
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
push!(reductionVector, node)
end
push!(reductionVector, partner)
end
end
push!(reductionVector, partner)
end
end
if reductionVector !== nothing
nr = NodeReduction(reductionVector)
push!(graph.possibleOperations.nodeReductions, nr)
for node in reductionVector
if !ismissing(node.nodeReduction)
# it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now
# this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations
invalidate_caches!(graph, node.nodeReduction)
end
node.nodeReduction = nr
end
end
if reductionVector !== nothing
nr = NodeReduction(reductionVector)
push!(graph.possibleOperations.nodeReductions, nr)
for node in reductionVector
if !ismissing(node.nodeReduction)
# it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now
# this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations
invalidate_caches!(graph, node.nodeReduction)
end
node.nodeReduction = nr
end
end
return nothing
return nothing
end
function find_splits!(graph::DAG, node::Node)
if !ismissing(node.nodeSplit)
return nothing
end
if !ismissing(node.nodeSplit)
return nothing
end
if (can_split(node))
ns = NodeSplit(node)
push!(graph.possibleOperations.nodeSplits, ns)
node.nodeSplit = ns
end
if (can_split(node))
ns = NodeSplit(node)
push!(graph.possibleOperations.nodeSplits, ns)
node.nodeSplit = ns
end
return nothing
return nothing
end
# "clean" the operations on a dirty node
function clean_node!(graph::DAG, node::Node)
sort_node!(node)
find_fusions!(graph, node)
find_reductions!(graph, node)
find_splits!(graph, node)
sort_node!(node)
find_fusions!(graph, node)
find_reductions!(graph, node)
return find_splits!(graph, node)
end

View File

@@ -2,204 +2,227 @@
using Base.Threads
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
n1 = nf.input[1]; n2 = nf.input[2]; n3 = nf.input[3]
function insert_operation!(
nf::NodeFusion,
locks::Dict{ComputeTaskNode, SpinLock},
)
n1 = nf.input[1]
n2 = nf.input[2]
n3 = nf.input[3]
lock(locks[n1]) do; push!(nf.input[1].nodeFusions, nf); end
nf.input[2].nodeFusion = nf
lock(locks[n3]) do; push!(nf.input[3].nodeFusions, nf); end
return nothing
lock(locks[n1]) do
return push!(nf.input[1].nodeFusions, nf)
end
n2.nodeFusion = nf
lock(locks[n3]) do
return push!(nf.input[3].nodeFusions, nf)
end
return nothing
end
function insert_operation!(nr::NodeReduction)
for n in nr.input
n.nodeReduction = nr
end
return nothing
for n in nr.input
n.nodeReduction = nr
end
return nothing
end
function insert_operation!(ns::NodeSplit)
ns.input.nodeSplit = ns
return nothing
ns.input.nodeSplit = ns
return nothing
end
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
total_len = 0
for vec in nodeReductions
total_len += length(vec)
end
sizehint!(operations.nodeReductions, total_len)
function nr_insertion!(
operations::PossibleOperations,
nodeReductions::Vector{Vector{NodeReduction}},
)
total_len = 0
for vec in nodeReductions
total_len += length(vec)
end
sizehint!(operations.nodeReductions, total_len)
t = @task for vec in nodeReductions
union!(operations.nodeReductions, Set(vec))
end
schedule(t)
t = @task for vec in nodeReductions
union!(operations.nodeReductions, Set(vec))
end
schedule(t)
@threads for vec in nodeReductions
for op in vec
insert_operation!(op)
end
end
@threads for vec in nodeReductions
for op in vec
insert_operation!(op)
end
end
wait(t)
wait(t)
return nothing
return nothing
end
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
total_len = 0
for vec in nodeFusions
total_len += length(vec)
end
sizehint!(operations.nodeFusions, total_len)
t = @task for vec in nodeFusions
union!(operations.nodeFusions, Set(vec))
end
schedule(t)
function nf_insertion!(
graph::DAG,
operations::PossibleOperations,
nodeFusions::Vector{Vector{NodeFusion}},
)
total_len = 0
for vec in nodeFusions
total_len += length(vec)
end
sizehint!(operations.nodeFusions, total_len)
locks = Dict{ComputeTaskNode, SpinLock}()
for n in graph.nodes
if (typeof(n) <: ComputeTaskNode)
locks[n] = SpinLock()
end
end
t = @task for vec in nodeFusions
union!(operations.nodeFusions, Set(vec))
end
schedule(t)
@threads for vec in nodeFusions
for op in vec
insert_operation!(op, locks)
end
end
locks = Dict{ComputeTaskNode, SpinLock}()
for n in graph.nodes
if (typeof(n) <: ComputeTaskNode)
locks[n] = SpinLock()
end
end
wait(t)
@threads for vec in nodeFusions
for op in vec
insert_operation!(op, locks)
end
end
return nothing
wait(t)
return nothing
end
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
total_len = 0
for vec in nodeSplits
total_len += length(vec)
end
sizehint!(operations.nodeSplits, total_len)
function ns_insertion!(
operations::PossibleOperations,
nodeSplits::Vector{Vector{NodeSplit}},
)
total_len = 0
for vec in nodeSplits
total_len += length(vec)
end
sizehint!(operations.nodeSplits, total_len)
t = @task for vec in nodeSplits
union!(operations.nodeSplits, Set(vec))
end
schedule(t)
t = @task for vec in nodeSplits
union!(operations.nodeSplits, Set(vec))
end
schedule(t)
@threads for vec in nodeSplits
for op in vec
insert_operation!(op)
end
end
@threads for vec in nodeSplits
for op in vec
insert_operation!(op)
end
end
wait(t)
wait(t)
return nothing
return nothing
end
# function to generate all possible operations on the graph
function generate_options(graph::DAG)
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
# make sure the graph is fully generated through
apply_all!(graph)
# make sure the graph is fully generated through
apply_all!(graph)
nodeArray = collect(graph.nodes)
nodeArray = collect(graph.nodes)
# sort all nodes
@threads for node in nodeArray
sort_node!(node)
end
# sort all nodes
@threads for node in nodeArray
sort_node!(node)
end
checkedNodes = Set{Node}()
checkedNodesLock = SpinLock()
# --- find possible node reductions ---
@threads for node in nodeArray
# we're looking for nodes with multiple parents, those parents can then potentially reduce with one another
if (length(node.parents) <= 1)
continue
end
checkedNodes = Set{Node}()
checkedNodesLock = SpinLock()
# --- find possible node reductions ---
@threads for node in nodeArray
# we're looking for nodes with multiple parents, those parents can then potentially reduce with one another
if (length(node.parents) <= 1)
continue
end
candidates = node.parents
candidates = node.parents
# sort into equivalence classes
trie = NodeTrie()
# sort into equivalence classes
trie = NodeTrie()
for candidate in candidates
# insert into trie
insert!(trie, candidate)
end
for candidate in candidates
# insert into trie
insert!(trie, candidate)
end
nodeReductions = collect(trie)
nodeReductions = collect(trie)
for nrVec in nodeReductions
# parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is uniquely identifiable by its first element
# this prevents duplicate nodeReductions being generated
lock(checkedNodesLock)
if (nrVec[1] in checkedNodes)
for nrVec in nodeReductions
# parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is uniquely identifiable by its first element
# this prevents duplicate nodeReductions being generated
lock(checkedNodesLock)
if (nrVec[1] in checkedNodes)
unlock(checkedNodesLock)
continue
else
push!(checkedNodes, nrVec[1])
end
unlock(checkedNodesLock)
continue
else
push!(checkedNodes, nrVec[1])
end
unlock(checkedNodesLock)
push!(generatedReductions[threadid()], NodeReduction(nrVec))
end
end
push!(generatedReductions[threadid()], NodeReduction(nrVec))
end
end
# launch thread for node reduction insertion
# remove duplicates
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
schedule(nr_task)
# --- find possible node fusions ---
@threads for node in nodeArray
if (typeof(node) <: DataTaskNode)
if length(node.parents) != 1
# data node can only have a single parent
continue
end
parent_node = first(node.parents)
# launch thread for node reduction insertion
# remove duplicates
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
schedule(nr_task)
if length(node.children) != 1
# this node is an entry node or has multiple children which should not be possible
continue
end
child_node = first(node.children)
if (length(child_node.parents) != 1)
continue
end
# --- find possible node fusions ---
@threads for node in nodeArray
if (typeof(node) <: DataTaskNode)
if length(node.parents) != 1
# data node can only have a single parent
continue
end
parent_node = first(node.parents)
push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
end
end
if length(node.children) != 1
# this node is an entry node or has multiple children which should not be possible
continue
end
child_node = first(node.children)
if (length(child_node.parents) != 1)
continue
end
# launch thread for node fusion insertion
nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
schedule(nf_task)
push!(
generatedFusions[threadid()],
NodeFusion((child_node, node, parent_node)),
)
end
end
# find possible node splits
@threads for node in nodeArray
if (can_split(node))
push!(generatedSplits[threadid()], NodeSplit(node))
end
end
# launch thread for node fusion insertion
nf_task =
@task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
schedule(nf_task)
# launch thread for node split insertion
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
schedule(ns_task)
# find possible node splits
@threads for node in nodeArray
if (can_split(node))
push!(generatedSplits[threadid()], NodeSplit(node))
end
end
empty!(graph.dirtyNodes)
# launch thread for node split insertion
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
schedule(ns_task)
wait(nr_task)
wait(nf_task)
wait(ns_task)
empty!(graph.dirtyNodes)
return nothing
wait(nr_task)
wait(nf_task)
wait(ns_task)
return nothing
end

View File

@@ -3,16 +3,16 @@
using Base.Threads
function get_operations(graph::DAG)
apply_all!(graph)
apply_all!(graph)
if isempty(graph.possibleOperations)
generate_options(graph)
end
if isempty(graph.possibleOperations)
generate_options(graph)
end
for node in graph.dirtyNodes
clean_node!(graph, node)
end
empty!(graph.dirtyNodes)
for node in graph.dirtyNodes
clean_node!(graph, node)
end
empty!(graph.dirtyNodes)
return graph.possibleOperations
return graph.possibleOperations
end

View File

@@ -20,12 +20,12 @@ function show(io::IO, op::NodeReduction)
print(io, "NR: ")
print(io, length(op.input))
print(io, "x")
print(io, op.input[1].task)
return print(io, op.input[1].task)
end
function show(io::IO, op::NodeSplit)
print(io, "NS: ")
print(io, op.input.task)
return print(io, op.input.task)
end
function show(io::IO, op::NodeFusion)
@@ -34,5 +34,5 @@ function show(io::IO, op::NodeFusion)
print(io, "->")
print(io, op.input[2].task)
print(io, "->")
print(io, op.input[3].task)
return print(io, op.input[3].task)
end

View File

@@ -7,28 +7,28 @@ abstract type Operation end
abstract type AppliedOperation end
struct NodeFusion <: Operation
input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
end
struct AppliedNodeFusion <: AppliedOperation
operation::NodeFusion
diff::Diff
operation::NodeFusion
diff::Diff
end
struct NodeReduction <: Operation
input::Vector{Node}
input::Vector{Node}
end
struct AppliedNodeReduction <: AppliedOperation
operation::NodeReduction
diff::Diff
operation::NodeReduction
diff::Diff
end
struct NodeSplit <: Operation
input::Node
input::Node
end
struct AppliedNodeSplit <: AppliedOperation
operation::NodeSplit
diff::Diff
operation::NodeSplit
diff::Diff
end

View File

@@ -1,107 +1,111 @@
function isempty(operations::PossibleOperations)
return isempty(operations.nodeFusions) &&
isempty(operations.nodeReductions) &&
isempty(operations.nodeSplits)
return isempty(operations.nodeFusions) &&
isempty(operations.nodeReductions) &&
isempty(operations.nodeSplits)
end
function length(operations::PossibleOperations)
return (nodeFusions = length(operations.nodeFusions),
nodeReductions = length(operations.nodeReductions),
nodeSplits = length(operations.nodeSplits))
return (
nodeFusions = length(operations.nodeFusions),
nodeReductions = length(operations.nodeReductions),
nodeSplits = length(operations.nodeSplits),
)
end
function delete!(operations::PossibleOperations, op::NodeFusion)
delete!(operations.nodeFusions, op)
return operations
delete!(operations.nodeFusions, op)
return operations
end
function delete!(operations::PossibleOperations, op::NodeReduction)
delete!(operations.nodeReductions, op)
return operations
delete!(operations.nodeReductions, op)
return operations
end
function delete!(operations::PossibleOperations, op::NodeSplit)
delete!(operations.nodeSplits, op)
return operations
delete!(operations.nodeSplits, op)
return operations
end
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !is_child(n1, n2) || !is_child(n2, n3)
# the checks are redundant but maybe a good sanity check
return false
end
if !is_child(n1, n2) || !is_child(n2, n3)
# the checks are redundant but maybe a good sanity check
return false
end
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
return false
end
if length(n2.parents) != 1 ||
length(n2.children) != 1 ||
length(n1.parents) != 1
return false
end
return true
return true
end
function can_reduce(n1::Node, n2::Node)
if (n1.task != n2.task)
return false
end
n1_length = length(n1.children)
n2_length = length(n2.children)
if (n1.task != n2.task)
return false
end
if (n1_length != n2_length)
return false
end
n1_length = length(n1.children)
n2_length = length(n2.children)
# this seems to be the most common case so do this first
# doing it manually is a lot faster than using the sets for a general solution
if (n1_length == 2)
if (n1.children[1] != n2.children[1])
if (n1.children[1] != n2.children[2])
if (n1_length != n2_length)
return false
end
# this seems to be the most common case so do this first
# doing it manually is a lot faster than using the sets for a general solution
if (n1_length == 2)
if (n1.children[1] != n2.children[1])
if (n1.children[1] != n2.children[2])
return false
end
# 1_1 == 2_2
if (n1.children[2] != n2.children[1])
return false
end
return true
end
# 1_1 == 2_1
if (n1.children[2] != n2.children[2])
return false
end
# 1_1 == 2_2
if (n1.children[2] != n2.children[1])
return false
end
return true
end
end
return true
end
# 1_1 == 2_1
if (n1.children[2] != n2.children[2])
return false
end
return true
end
# this is simple
if (n1_length == 1)
return n1.children[1] == n2.children[1]
end
# this is simple
if (n1_length == 1)
return n1.children[1] == n2.children[1]
end
# this takes a long time
return Set(n1.children) == Set(n2.children)
# this takes a long time
return Set(n1.children) == Set(n2.children)
end
function can_split(n::Node)
return length(parents(n)) > 1
return length(parents(n)) > 1
end
function ==(op1::Operation, op2::Operation)
return false
return false
end
function ==(op1::NodeFusion, op2::NodeFusion)
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
return op1.input[2] == op2.input[2]
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
return op1.input[2] == op2.input[2]
end
function ==(op1::NodeReduction, op2::NodeReduction)
# node reductions are equal exactly if their first input is the same
return op1.input[1].id == op2.input[1].id
# node reductions are equal exactly if their first input is the same
return op1.input[1].id == op2.input[1].id
end
function ==(op1::NodeSplit, op2::NodeSplit)
return op1.input == op2.input
return op1.input == op2.input
end
copy(id::UUID) = UUID(id.value)

View File

@@ -2,23 +2,51 @@
# should be called with @assert
# the functions throw their own errors though, to still have helpful error messages
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
function is_valid_node_fusion_input(
graph::DAG,
n1::ComputeTaskNode,
n2::DataTaskNode,
n3::ComputeTaskNode,
)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
throw(
AssertionError(
"[Node Fusion] The given nodes are not part of the given graph",
),
)
end
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
throw(AssertionError("[Node Fusion] The given nodes are not connected by edges which is required for node fusion"))
if !is_child(n1, n2) ||
!is_child(n2, n3) ||
!is_parent(n3, n2) ||
!is_parent(n2, n1)
throw(
AssertionError(
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
),
)
end
if length(n2.parents) > 1
throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
throw(
AssertionError(
"[Node Fusion] The given data node has more than one parent",
),
)
end
if length(n2.children) > 1
throw(AssertionError("[Node Fusion] The given data node has more than one child"))
throw(
AssertionError(
"[Node Fusion] The given data node has more than one child",
),
)
end
if length(n1.parents) > 1
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
throw(
AssertionError(
"[Node Fusion] The given n1 has more than one parent",
),
)
end
return true
@@ -27,21 +55,33 @@ end
function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
for n in nodes
if n graph
throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
throw(
AssertionError(
"[Node Reduction] The given nodes are not part of the given graph",
),
)
end
end
t = typeof(nodes[1].task)
for n in nodes
if typeof(n.task) != t
throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
throw(
AssertionError(
"[Node Reduction] The given nodes are not of the same type",
),
)
end
end
n1_children = nodes[1].children
for n in nodes
if Set(n1_children) != Set(n.children)
throw(AssertionError("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction"))
throw(
AssertionError(
"[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction",
),
)
end
end
@@ -50,11 +90,19 @@ end
function is_valid_node_split_input(graph::DAG, n1::Node)
if n1 graph
throw(AssertionError("[Node Split] The given node is not part of the given graph"))
throw(
AssertionError(
"[Node Split] The given node is not part of the given graph",
),
)
end
if length(n1.parents) <= 1
throw(AssertionError("[Node Split] The given node does not have multiple parents which is required for node split"))
throw(
AssertionError(
"[Node Split] The given node does not have multiple parents which is required for node split",
),
)
end
return true
@@ -73,7 +121,12 @@ function is_valid(graph::DAG, ns::NodeSplit)
end
function is_valid(graph::DAG, nf::NodeFusion)
@assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
@assert is_valid_node_fusion_input(
graph,
nf.input[1],
nf.input[2],
nf.input[3],
)
@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
return true
end