Add node reduction unit test and fix bugs

This commit is contained in:
2023-08-23 12:51:25 +02:00
parent 569949d5c7
commit 92f59110ed
13 changed files with 453 additions and 323 deletions

View File

@ -25,7 +25,7 @@ function apply_operation!(graph::DAG, operation::NodeFusion)
end
function apply_operation!(graph::DAG, operation::NodeReduction)
diff = node_reduction!(graph, operation.input[1], operation.input[2])
diff = node_reduction!(graph, operation.input)
return AppliedNodeReduction(operation, diff)
end
@ -147,48 +147,57 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
return get_snapshot_diff(graph)
end
function node_reduction!(graph::DAG, n1::Node, n2::Node)
function node_reduction!(graph::DAG, nodes::Vector{Node})
# clear snapshot
get_snapshot_diff(graph)
#=if !(n1 in graph) || !(n2 in graph)
error("[Node Reduction] The given nodes are not part of the given graph")
end=#
t = typeof(nodes[1].task)
for n in nodes
if n graph
error("[Node Reduction] The given nodes are not part of the given graph")
end
#=if typeof(n1) != typeof(n2)
error("[Node Reduction] The given nodes are not of the same type")
end=#
# save n2 parents and children
n2_children = children(n2)
n2_parents = Set(n2.parents)
#=if Set(n2_children) != Set(n1.children)
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
end=#
# remove n2 and all its parents and children
for child in n2_children
remove_edge!(graph, make_edge(child, n2))
if typeof(n.task) != t
error("[Node Reduction] The given nodes are not of the same type")
end
end
for parent in n2_parents
remove_edge!(graph, make_edge(n2, parent))
n1 = nodes[1]
n1_children = children(n1)
for n in nodes
if Set(n1_children) != Set(n.children)
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
end
end
for parent in n1.parents
# delete parents in n1 that already exist in n2
delete!(n2_parents, parent)
n1_parents = Set(n1.parents)
new_parents = Set{Node}()
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
n = nodes[i]
for child in n1_children
remove_edge!(graph, make_edge(child, n))
end
for parent in parents(n)
remove_edge!(graph, make_edge(n, parent))
# collect all parents
push!(new_parents, parent)
end
remove_node!(graph, n)
end
for parent in n2_parents
setdiff!(new_parents, n1_parents)
for parent in new_parents
# now add parents of n2 to n1 without duplicates
insert_edge!(graph, make_edge(n1, parent))
end
remove_node!(graph, n2)
return get_snapshot_diff(graph)
end

View File

@ -3,6 +3,13 @@
# function to find node fusions involving the given node if it's a data node
# pushes the found fusion everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::DataTaskNode)
# if there is already a fusion here, skip
for op in node.operations
if typeof(op) <: NodeFusion
return nothing
end
end
if length(node.parents) != 1 || length(node.children) != 1
return nothing
end
@ -27,61 +34,29 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
return nothing
end
# function to find node fusions involving the given node if it's a compute node
# pushes the found fusion(s) everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# for loop that always runs once for a scoped block we can break out of
for _ in 1:1
# assume this node as child of the chain
if length(node.parents) != 1
break
end
node2 = first(node.parents)
if length(node2.parents) != 1 || length(node2.children) != 1
break
end
node3 = first(node2.parents)
# just find fusions in neighbouring DataTaskNodes
#=if !(node2 in graph) || !(node3 in graph)
error("Parents/Children that are not in the graph!!!")
end=#
nf = NodeFusion((node, node2, node3))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node.operations, nf)
push!(node2.operations, nf)
push!(node3.operations, nf)
for child in node.children
find_fusions!(graph, child)
end
for _ in 1:1
# assume this node as parent of the chain
if length(node.children) < 1
break
end
node2 = first(node.children)
if length(node2.parents) != 1 || length(node2.children) != 1
break
end
node1 = first(node2.children)
if (length(node1.parents) > 1)
break
end
#=if !(node2 in graph) || !(node1 in graph)
error("Parents/Children that are not in the graph!!!")
end=#
nf = NodeFusion((node1, node2, node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node1.operations, nf)
push!(node2.operations, nf)
push!(node.operations, nf)
for parent in node.parents
find_fusions!(graph, parent)
end
return nothing
end
function find_reductions!(graph::DAG, node::Node)
# there can only be one reduction per node, avoid adding duplicates
for op in node.operations
if typeof(op) <: NodeReduction
return nothing
end
end
reductionVector = nothing
# possible reductions are with nodes that are partners, i.e. parents of children
partners_ = partners(node)
@ -121,6 +96,8 @@ end
# "clean" the operations on a dirty node
function clean_node!(graph::DAG, node::Node)
sort_node!(node)
find_fusions!(graph, node)
find_reductions!(graph, node)
find_splits!(graph, node)

38
src/operations/print.jl Normal file
View File

@ -0,0 +1,38 @@
function show(io::IO, ops::PossibleOperations)
print(io, length(ops.nodeFusions))
println(io, " Node Fusions: ")
for nf in ops.nodeFusions
println(io, " - ", nf)
end
print(io, length(ops.nodeReductions))
println(io, " Node Reductions: ")
for nr in ops.nodeReductions
println(io, " - ", nr)
end
print(io, length(ops.nodeSplits))
println(io, " Node Splits: ")
for ns in ops.nodeSplits
println(io, " - ", ns)
end
end
function show(io::IO, op::NodeReduction)
print(io, "NR: ")
print(io, length(op.input))
print(io, "x")
print(io, op.input[1].task)
end
function show(io::IO, op::NodeSplit)
print(io, "NS: ")
print(io, op.input.task)
end
function show(io::IO, op::NodeFusion)
print(io, "NF: ")
print(io, op.input[1].task)
print(io, "->")
print(io, op.input[2].task)
print(io, "->")
print(io, op.input[3].task)
end

View File

@ -96,14 +96,12 @@ function ==(op1::NodeFusion, op2::NodeFusion)
end
function ==(op1::NodeReduction, op2::NodeReduction)
# only test the ids against each other
return op1.id == op2.id
# node reductions are equal exactly if their first input is the same
return op1.input[1].id == op2.input[1].id
end
function ==(op1::NodeSplit, op2::NodeSplit)
return op1.input == op2.input
end
NodeReduction(input::Vector{Node}) = NodeReduction(input, UUIDs.uuid1(rng[threadid()]))
copy(id::UUID) = UUID(id.value)