Add node reduction unit test and fix bugs
This commit is contained in:
@ -25,7 +25,7 @@ function apply_operation!(graph::DAG, operation::NodeFusion)
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::NodeReduction)
|
||||
diff = node_reduction!(graph, operation.input[1], operation.input[2])
|
||||
diff = node_reduction!(graph, operation.input)
|
||||
return AppliedNodeReduction(operation, diff)
|
||||
end
|
||||
|
||||
@ -147,48 +147,57 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
|
||||
return get_snapshot_diff(graph)
|
||||
end
|
||||
|
||||
function node_reduction!(graph::DAG, n1::Node, n2::Node)
|
||||
function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
#=if !(n1 in graph) || !(n2 in graph)
|
||||
error("[Node Reduction] The given nodes are not part of the given graph")
|
||||
end=#
|
||||
t = typeof(nodes[1].task)
|
||||
for n in nodes
|
||||
if n ∉ graph
|
||||
error("[Node Reduction] The given nodes are not part of the given graph")
|
||||
end
|
||||
|
||||
#=if typeof(n1) != typeof(n2)
|
||||
error("[Node Reduction] The given nodes are not of the same type")
|
||||
end=#
|
||||
|
||||
# save n2 parents and children
|
||||
n2_children = children(n2)
|
||||
n2_parents = Set(n2.parents)
|
||||
|
||||
#=if Set(n2_children) != Set(n1.children)
|
||||
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
|
||||
end=#
|
||||
|
||||
# remove n2 and all its parents and children
|
||||
for child in n2_children
|
||||
remove_edge!(graph, make_edge(child, n2))
|
||||
if typeof(n.task) != t
|
||||
error("[Node Reduction] The given nodes are not of the same type")
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
for parent in n2_parents
|
||||
remove_edge!(graph, make_edge(n2, parent))
|
||||
n1 = nodes[1]
|
||||
n1_children = children(n1)
|
||||
for n in nodes
|
||||
if Set(n1_children) != Set(n.children)
|
||||
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
|
||||
end
|
||||
end
|
||||
|
||||
for parent in n1.parents
|
||||
# delete parents in n1 that already exist in n2
|
||||
delete!(n2_parents, parent)
|
||||
n1_parents = Set(n1.parents)
|
||||
new_parents = Set{Node}()
|
||||
|
||||
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
|
||||
for i in 2:length(nodes)
|
||||
n = nodes[i]
|
||||
for child in n1_children
|
||||
remove_edge!(graph, make_edge(child, n))
|
||||
end
|
||||
|
||||
for parent in parents(n)
|
||||
remove_edge!(graph, make_edge(n, parent))
|
||||
|
||||
# collect all parents
|
||||
push!(new_parents, parent)
|
||||
end
|
||||
|
||||
remove_node!(graph, n)
|
||||
end
|
||||
|
||||
for parent in n2_parents
|
||||
setdiff!(new_parents, n1_parents)
|
||||
|
||||
for parent in new_parents
|
||||
# now add parents of n2 to n1 without duplicates
|
||||
insert_edge!(graph, make_edge(n1, parent))
|
||||
end
|
||||
|
||||
remove_node!(graph, n2)
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
end
|
||||
|
||||
|
@ -3,6 +3,13 @@
|
||||
# function to find node fusions involving the given node if it's a data node
|
||||
# pushes the found fusion everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
# if there is already a fusion here, skip
|
||||
for op in node.operations
|
||||
if typeof(op) <: NodeFusion
|
||||
return nothing
|
||||
end
|
||||
end
|
||||
|
||||
if length(node.parents) != 1 || length(node.children) != 1
|
||||
return nothing
|
||||
end
|
||||
@ -27,61 +34,29 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to find node fusions involving the given node if it's a compute node
|
||||
# pushes the found fusion(s) everywhere it needs to be and returns nothing
|
||||
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
# for loop that always runs once for a scoped block we can break out of
|
||||
for _ in 1:1
|
||||
# assume this node as child of the chain
|
||||
if length(node.parents) != 1
|
||||
break
|
||||
end
|
||||
node2 = first(node.parents)
|
||||
if length(node2.parents) != 1 || length(node2.children) != 1
|
||||
break
|
||||
end
|
||||
node3 = first(node2.parents)
|
||||
# just find fusions in neighbouring DataTaskNodes
|
||||
|
||||
#=if !(node2 in graph) || !(node3 in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end=#
|
||||
|
||||
nf = NodeFusion((node, node2, node3))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node.operations, nf)
|
||||
push!(node2.operations, nf)
|
||||
push!(node3.operations, nf)
|
||||
for child in node.children
|
||||
find_fusions!(graph, child)
|
||||
end
|
||||
|
||||
for _ in 1:1
|
||||
# assume this node as parent of the chain
|
||||
if length(node.children) < 1
|
||||
break
|
||||
end
|
||||
node2 = first(node.children)
|
||||
if length(node2.parents) != 1 || length(node2.children) != 1
|
||||
break
|
||||
end
|
||||
node1 = first(node2.children)
|
||||
if (length(node1.parents) > 1)
|
||||
break
|
||||
end
|
||||
|
||||
#=if !(node2 in graph) || !(node1 in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end=#
|
||||
|
||||
nf = NodeFusion((node1, node2, node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node1.operations, nf)
|
||||
push!(node2.operations, nf)
|
||||
push!(node.operations, nf)
|
||||
for parent in node.parents
|
||||
find_fusions!(graph, parent)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function find_reductions!(graph::DAG, node::Node)
|
||||
# there can only be one reduction per node, avoid adding duplicates
|
||||
for op in node.operations
|
||||
if typeof(op) <: NodeReduction
|
||||
return nothing
|
||||
end
|
||||
end
|
||||
|
||||
reductionVector = nothing
|
||||
# possible reductions are with nodes that are partners, i.e. parents of children
|
||||
partners_ = partners(node)
|
||||
@ -121,6 +96,8 @@ end
|
||||
|
||||
# "clean" the operations on a dirty node
|
||||
function clean_node!(graph::DAG, node::Node)
|
||||
sort_node!(node)
|
||||
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
find_splits!(graph, node)
|
||||
|
38
src/operations/print.jl
Normal file
38
src/operations/print.jl
Normal file
@ -0,0 +1,38 @@
|
||||
function show(io::IO, ops::PossibleOperations)
|
||||
print(io, length(ops.nodeFusions))
|
||||
println(io, " Node Fusions: ")
|
||||
for nf in ops.nodeFusions
|
||||
println(io, " - ", nf)
|
||||
end
|
||||
print(io, length(ops.nodeReductions))
|
||||
println(io, " Node Reductions: ")
|
||||
for nr in ops.nodeReductions
|
||||
println(io, " - ", nr)
|
||||
end
|
||||
print(io, length(ops.nodeSplits))
|
||||
println(io, " Node Splits: ")
|
||||
for ns in ops.nodeSplits
|
||||
println(io, " - ", ns)
|
||||
end
|
||||
end
|
||||
|
||||
function show(io::IO, op::NodeReduction)
|
||||
print(io, "NR: ")
|
||||
print(io, length(op.input))
|
||||
print(io, "x")
|
||||
print(io, op.input[1].task)
|
||||
end
|
||||
|
||||
function show(io::IO, op::NodeSplit)
|
||||
print(io, "NS: ")
|
||||
print(io, op.input.task)
|
||||
end
|
||||
|
||||
function show(io::IO, op::NodeFusion)
|
||||
print(io, "NF: ")
|
||||
print(io, op.input[1].task)
|
||||
print(io, "->")
|
||||
print(io, op.input[2].task)
|
||||
print(io, "->")
|
||||
print(io, op.input[3].task)
|
||||
end
|
@ -96,14 +96,12 @@ function ==(op1::NodeFusion, op2::NodeFusion)
|
||||
end
|
||||
|
||||
function ==(op1::NodeReduction, op2::NodeReduction)
|
||||
# only test the ids against each other
|
||||
return op1.id == op2.id
|
||||
# node reductions are equal exactly if their first input is the same
|
||||
return op1.input[1].id == op2.input[1].id
|
||||
end
|
||||
|
||||
function ==(op1::NodeSplit, op2::NodeSplit)
|
||||
return op1.input == op2.input
|
||||
end
|
||||
|
||||
NodeReduction(input::Vector{Node}) = NodeReduction(input, UUIDs.uuid1(rng[threadid()]))
|
||||
|
||||
copy(id::UUID) = UUID(id.value)
|
||||
|
Reference in New Issue
Block a user