105 lines
2.7 KiB
Julia

# functions for "cleaning" nodes, i.e. regenerating the possible operations for a node
"""
    find_fusions!(graph::DAG, node::DataTaskNode)

Register a `NodeFusion` centred on the data node `node`, if one is possible.

A fusion is only created when all of the following hold:
- `node.operations` does not already contain a `NodeFusion`,
- `node` has exactly one parent and exactly one child,
- that child has exactly one parent.

The new fusion is pushed to `graph.possibleOperations.nodeFusions` and to the
`operations` of the child, of `node` and of the parent (in that order).
Always returns `nothing`.
"""
function find_fusions!(graph::DAG, node::DataTaskNode)
    # a fusion is already registered on this node, so there is nothing to add
    any(op -> op isa NodeFusion, node.operations) && return nothing

    # the data node must sit between exactly one parent and exactly one child
    (length(node.parents) == 1 && length(node.children) == 1) || return nothing

    child_node = first(node.children)
    parent_node = first(node.parents)
    # NOTE: membership of child_node/parent_node in `graph` is not verified here

    # the child must not have any parent other than `node`
    length(child_node.parents) == 1 || return nothing

    fusion = NodeFusion((child_node, node, parent_node))
    push!(graph.possibleOperations.nodeFusions, fusion)
    for member in (child_node, node, parent_node)
        push!(member.operations, fusion)
    end
    return nothing
end
"""
    find_fusions!(graph::DAG, node::ComputeTaskNode)

Look for node fusions around a compute node by calling `find_fusions!` on each
of its children first and then on each of its parents.

The neighbours are expected to be `DataTaskNode`s; this is not checked here,
dispatch selects the method matching each neighbour's type.
Always returns `nothing`.
"""
function find_fusions!(graph::DAG, node::ComputeTaskNode)
    # children are processed before parents
    for neighbours in (node.children, node.parents), neighbour in neighbours
        find_fusions!(graph, neighbour)
    end
    return nothing
end
"""
    find_reductions!(graph::DAG, node::Node)

Register a `NodeReduction` for `node` together with every partner it can be
reduced with.

Nothing is done when `node.operations` already holds a `NodeReduction`, or when
no partner satisfies `can_reduce(node, partner)`. Otherwise a single
`NodeReduction` is built from `node` followed by all reducible partners, pushed
to `graph.possibleOperations.nodeReductions` and to the `operations` of every
node taking part in it. Always returns `nothing`.
"""
function find_reductions!(graph::DAG, node::Node)
    # there can only be one reduction per node, so never add a second one
    any(op -> op isa NodeReduction, node.operations) && return nothing

    # candidates are the partners of the node (parents of its children),
    # with the node itself removed from that collection
    candidates = partners(node)
    delete!(candidates, node)

    # the node itself always comes first, followed by its reducible partners
    group = Node[node]
    for candidate in candidates
        can_reduce(node, candidate) && push!(group, candidate)
    end

    # only the node itself in the group means there is no reduction partner
    length(group) > 1 || return nothing

    reduction = NodeReduction(group)
    push!(graph.possibleOperations.nodeReductions, reduction)
    for member in group
        push!(member.operations, reduction)
    end
    return nothing
end
"""
    find_splits!(graph::DAG, node::Node)

Register a `NodeSplit` for `node` when `can_split(node)` holds.

The split is pushed to `graph.possibleOperations.nodeSplits` and to
`node.operations`. Always returns `nothing`.
"""
function find_splits!(graph::DAG, node::Node)
    # nodes that cannot be split get no operation
    can_split(node) || return nothing

    split_op = NodeSplit(node)
    push!(graph.possibleOperations.nodeSplits, split_op)
    push!(node.operations, split_op)
    return nothing
end
# "clean" the operations on a dirty node
"""
    clean_node!(graph::DAG, node::Node)

Regenerate the possible operations of `node`: call `sort_node!` on it, then
search for fusions, reductions and splits, in that order.

Returns the result of the final `find_splits!` call.
"""
function clean_node!(graph::DAG, node::Node)
    sort_node!(node)

    find_fusions!(graph, node)
    find_reductions!(graph, node)
    # the value of the last search is what the function evaluates to
    return find_splits!(graph, node)
end