Add node reduction unit test and fix bugs
This commit is contained in:
@ -3,6 +3,13 @@
|
||||
# function to find node fusions involving the given node if it's a data node
|
||||
# pushes the found fusion everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
# if there is already a fusion here, skip
|
||||
for op in node.operations
|
||||
if typeof(op) <: NodeFusion
|
||||
return nothing
|
||||
end
|
||||
end
|
||||
|
||||
if length(node.parents) != 1 || length(node.children) != 1
|
||||
return nothing
|
||||
end
|
||||
@ -27,61 +34,29 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to find node fusions involving the given node if it's a compute node
|
||||
# pushes the found fusion(s) everywhere it needs to be and returns nothing
|
||||
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
# for loop that always runs once for a scoped block we can break out of
|
||||
for _ in 1:1
|
||||
# assume this node as child of the chain
|
||||
if length(node.parents) != 1
|
||||
break
|
||||
end
|
||||
node2 = first(node.parents)
|
||||
if length(node2.parents) != 1 || length(node2.children) != 1
|
||||
break
|
||||
end
|
||||
node3 = first(node2.parents)
|
||||
# just find fusions in neighbouring DataTaskNodes
|
||||
|
||||
#=if !(node2 in graph) || !(node3 in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end=#
|
||||
|
||||
nf = NodeFusion((node, node2, node3))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node.operations, nf)
|
||||
push!(node2.operations, nf)
|
||||
push!(node3.operations, nf)
|
||||
for child in node.children
|
||||
find_fusions!(graph, child)
|
||||
end
|
||||
|
||||
for _ in 1:1
|
||||
# assume this node as parent of the chain
|
||||
if length(node.children) < 1
|
||||
break
|
||||
end
|
||||
node2 = first(node.children)
|
||||
if length(node2.parents) != 1 || length(node2.children) != 1
|
||||
break
|
||||
end
|
||||
node1 = first(node2.children)
|
||||
if (length(node1.parents) > 1)
|
||||
break
|
||||
end
|
||||
|
||||
#=if !(node2 in graph) || !(node1 in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end=#
|
||||
|
||||
nf = NodeFusion((node1, node2, node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node1.operations, nf)
|
||||
push!(node2.operations, nf)
|
||||
push!(node.operations, nf)
|
||||
for parent in node.parents
|
||||
find_fusions!(graph, parent)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function find_reductions!(graph::DAG, node::Node)
|
||||
# there can only be one reduction per node, avoid adding duplicates
|
||||
for op in node.operations
|
||||
if typeof(op) <: NodeReduction
|
||||
return nothing
|
||||
end
|
||||
end
|
||||
|
||||
reductionVector = nothing
|
||||
# possible reductions are with nodes that are partners, i.e. parents of children
|
||||
partners_ = partners(node)
|
||||
@ -121,6 +96,8 @@ end
|
||||
|
||||
# "clean" the operations on a dirty node
|
||||
function clean_node!(graph::DAG, node::Node)
|
||||
sort_node!(node)
|
||||
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
find_splits!(graph, node)
|
||||
|
Reference in New Issue
Block a user