2023-06-13 00:19:12 +02:00

149 lines
4.4 KiB
Julia

"""
    node_fusion(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Fuse the chain `n1 -> n2 -> n3` into one compute node and return the resulting new node.

The chain is defined by the edges `make_edge(n1, n2)` and `make_edge(n2, n3)`. Throughout
this function edges are built as `make_edge(child, parent)`, so `n1` is a child of `n2`
and `n2` is a child of `n3`.

The new node is a `ComputeTaskNode` wrapping
`FusedComputeTask{typeof(n1.task), typeof(n3.task)}()`. `n1`, `n2` and `n3` are removed
from `graph`. The children of `n1`, the remaining children of `n3` and the parents of `n3`
are reconnected to the new node.

Raises an `ErrorException` (via `error`) in three cases:

- any of the three nodes is not in `graph`;
- either of the two required edges is missing;
- `n2` has more than one parent.

NOTE(review): parents of `n1` other than `n2` are not reconnected. Presumably a compute
node's only parent is its single data output — confirm.
"""
function node_fusion(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
# all three nodes must belong to the graph
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
error("[Node Fusion] The given nodes are not part of the given graph")
end
# the nodes must actually form the chain n1 -> n2 -> n3
required_edge1 = make_edge(n1, n2)
required_edge2 = make_edge(n2, n3)
if !(required_edge1 in graph) || !(required_edge2 in graph)
error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion")
end
# save children and parents before the graph is mutated below
n1_children = children(graph, n1)
n2_parents = parents(graph, n2)
n3_parents = parents(graph, n3)
# n3 is already a parent of n2 (via required_edge2). Any additional parent would lose its
# input, because n2 is removed below and its other parents are never reconnected.
if length(n2_parents) > 1
error("[Node Fusion] The given data node has more than one parent")
end
# remove the edges and nodes that will be replaced by the fused node
remove_edge(graph, required_edge1)
remove_edge(graph, required_edge2)
# NOTE(review): n1 is removed here, before the edges from its children are removed in the
# loop below. This assumes remove_node leaves incident edges in place and remove_edge
# accepts an edge whose endpoint was already removed. Confirm against the DAG implementation.
remove_node(graph, n1)
remove_node(graph, n2)
# get n3's children now so it automatically excludes n2
n3_children = children(graph, n3)
remove_node(graph, n3)
# create new node with the fused compute task
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
insert_node(graph, new_node)
# "repoint" children of n1 to the new node
for child in n1_children
remove_edge(graph, make_edge(child, n1))
insert_edge(graph, make_edge(child, new_node))
end
# "repoint" children of n3 to the new node
for child in n3_children
remove_edge(graph, make_edge(child, n3))
insert_edge(graph, make_edge(child, new_node))
end
# "repoint" parents of n3 so they now consume the new node instead
for parent in n3_parents
remove_edge(graph, make_edge(n3, parent))
insert_edge(graph, make_edge(new_node, parent))
end
return new_node
end
"""
    node_reduction(graph::DAG, n1::Node, n2::Node)

Merge `n2` into `n1` and return `n1`.

Both nodes must be distinct, be part of `graph`, have the same type and have the same set
of children (their prerequisite nodes). `n2` is detached from its children, every parent
of `n2` is reconnected to `n1`, and `n2` is removed from `graph`.

Raises an `ErrorException` (via `error`) in four cases:

- a node is not in `graph`;
- both arguments are the same node;
- the node types differ;
- the prerequisite sets differ.
"""
function node_reduction(graph::DAG, n1::Node, n2::Node)
    if !(n1 in graph) || !(n2 in graph)
        error("[Node Reduction] The given nodes are not part of the given graph")
    end
    # reducing a node with itself would end with remove_node deleting it from the graph
    if n1 === n2
        error("[Node Reduction] The given nodes are identical, a node cannot be reduced with itself")
    end
    if typeof(n1) != typeof(n2)
        error("[Node Reduction] The given nodes are not of the same type")
    end

    # snapshot n2's neighbourhood; the edges around n2 are modified in the loops below
    n2_children = collect(children(graph, n2))
    n2_parents = collect(parents(graph, n2))

    # compare as sets: the storage order of the prerequisites is irrelevant
    if !issetequal(n2_children, children(graph, n1))
        error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
    end

    # detach n2 from its prerequisites (n1 already has the same ones)
    for child in n2_children
        remove_edge(graph, make_edge(child, n2))
    end
    # move every consumer of n2 over to n1; a parent shared by both must not get a second edge
    for parent in n2_parents
        remove_edge(graph, make_edge(n2, parent))
        new_edge = make_edge(n1, parent)
        if !(new_edge in graph)
            insert_edge(graph, new_edge)
        end
    end
    remove_node(graph, n2)
    return n1
end
"""
    node_split(graph::DAG, n1::Node)

Split `n1`, which must have more than one parent, so that every parent gets its own node.

`n1` itself stays connected to its first parent. For every further parent, a `copy(n1)`
is inserted, connected to that parent in place of `n1`, and given the same children as
`n1`. Returns `nothing`.

Raises an `ErrorException` (via `error`) if `n1` is not in `graph` or has at most one
parent.
"""
function node_split(graph::DAG, n1::Node)
    if !(n1 in graph)
        error("[Node Split] The given node is not part of the given graph")
    end
    # snapshot the neighbourhood; edges around n1 are modified inside the loop
    n1_parents = collect(parents(graph, n1))
    n1_children = collect(children(graph, n1))
    if length(n1_parents) <= 1
        error("[Node Split] The given node does not have multiple parents which is required for node split")
    end
    # n1 keeps serving its first parent. Copying for *every* parent would leave n1 in the
    # graph with no parents but all of its child edges, i.e. an orphaned dead node.
    for parent in Iterators.drop(n1_parents, 1)
        n_copy = copy(n1)
        insert_node(graph, n_copy)
        insert_edge(graph, make_edge(n_copy, parent))
        remove_edge(graph, make_edge(n1, parent))
        # the copy needs the same prerequisites as the original
        for child in n1_children
            insert_edge(graph, make_edge(child, n_copy))
        end
    end
    return nothing
end
"""
    generate_options(graph::DAG)

Collect the optimizations that are currently applicable to `graph`.

Returns a named tuple with three vector fields:

- `fusions::Vector{Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}}`: each entry is a
  `(child, data_node, parent)` triple in the argument order of `node_fusion`.
- `reductions::Vector{Vector{Node}}`: not yet detected, always empty.
- `splits::Vector{Tuple{Node}}`: not yet detected, always empty.
"""
function generate_options(graph::DAG)
    options = (
        fusions = Vector{Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}}(),
        reductions = Vector{Vector{Node}}(),
        splits = Vector{Tuple{Node}}(),
    )

    # find possible node fusions: data nodes with exactly one parent and one child,
    # where both neighbours are compute nodes
    for node in graph.nodes
        node isa DataTaskNode || continue

        node_parents = parents(graph, node)
        # data node can only have a single parent (node_fusion rejects more)
        length(node_parents) == 1 || continue

        node_children = children(graph, node)
        # zero children means an entry node; multiple children should not be possible
        length(node_children) == 1 || continue

        child_node = first(node_children)
        parent_node = first(node_parents)
        # without this check, push! below would throw while converting to the tuple eltype
        (child_node isa ComputeTaskNode && parent_node isa ComputeTaskNode) || continue

        push!(options.fusions, (child_node, node, parent_node))
    end

    # TODO: find possible node reductions
    # TODO: find possible node splits
    return options
end