"""
    apply_all!(graph::DAG)

Apply all unapplied operations in the DAG. Is automatically called in all functions that require the latest state of the [`DAG`](@ref).

Operations are taken one by one from the front of `graph.operationsToApply` until it is empty.
Each one is applied with [`apply_operation!`](@ref), and the returned applied operation is
`push!`ed onto `graph.appliedOperations`. Returns `nothing`.
"""
function apply_all!(graph::DAG)
    while true
        isempty(graph.operationsToApply) && break

        # oldest pending operation first
        pendingOp = popfirst!(graph.operationsToApply)

        # apply it and record the result
        appliedForm = apply_operation!(graph, pendingOp)
        push!(graph.appliedOperations, appliedForm)
    end
    return nothing
end
|
|
|
|
|
"""
    apply_operation!(graph::DAG, operation::Operation)

Fallback method for operation types that have no dedicated `apply_operation!` implementation.
Always throws an `ErrorException` with the message `"Unknown operation type!"`.
"""
apply_operation!(graph::DAG, operation::Operation) =
    throw(ErrorException("Unknown operation type!"))
|
|
|
|
|
"""
    apply_operation!(graph::DAG, operation::NodeFusion)

Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node_fusion!`](@ref).

The three input nodes of the operation are handed to [`node_fusion!`](@ref). The resulting
[`Diff`](@ref) is added to `graph.properties` (as `GraphProperties(diff)`), and an
[`AppliedNodeFusion`](@ref) holding the operation and that diff is returned.
"""
function apply_operation!(graph::DAG, operation::NodeFusion)
    n1, n2, n3 = operation.input
    fusionDiff = node_fusion!(graph, n1, n2, n3)

    # keep graph.properties in sync with the changes just made
    graph.properties = graph.properties + GraphProperties(fusionDiff)

    return AppliedNodeFusion(operation, fusionDiff)
end
|
|
|
|
|
"""
    apply_operation!(graph::DAG, operation::NodeReduction)

Apply the given [`NodeReduction`](@ref) to the graph. Generic wrapper around [`node_reduction!`](@ref).

The operation's input nodes are handed to [`node_reduction!`](@ref). The resulting
[`Diff`](@ref) is added to `graph.properties` (as `GraphProperties(diff)`), and an
[`AppliedNodeReduction`](@ref) holding the operation and that diff is returned.
"""
function apply_operation!(graph::DAG, operation::NodeReduction)
    reductionDiff = node_reduction!(graph, operation.input)

    # keep graph.properties in sync with the changes just made
    graph.properties = graph.properties + GraphProperties(reductionDiff)

    return AppliedNodeReduction(operation, reductionDiff)
end
|
|
|
|
|
"""
    apply_operation!(graph::DAG, operation::NodeSplit)

Apply the given [`NodeSplit`](@ref) to the graph. Generic wrapper around [`node_split!`](@ref).

The operation's input node is handed to [`node_split!`](@ref). The resulting [`Diff`](@ref)
is added to `graph.properties` (as `GraphProperties(diff)`), and an [`AppliedNodeSplit`](@ref)
holding the operation and that diff is returned.
"""
function apply_operation!(graph::DAG, operation::NodeSplit)
    splitDiff = node_split!(graph, operation.input)

    # keep graph.properties in sync with the changes just made
    graph.properties = graph.properties + GraphProperties(splitDiff)

    return AppliedNodeSplit(operation, splitDiff)
end
|
|
|
|
|
"""
    revert_operation!(graph::DAG, operation::AppliedOperation)

Fallback method for applied operation types that have no dedicated `revert_operation!`
implementation. Always throws an `ErrorException` with the message `"Unknown operation type!"`.
"""
revert_operation!(graph::DAG, operation::AppliedOperation) =
    throw(ErrorException("Unknown operation type!"))
|
|
|
|
|
"""
    revert_operation!(graph::DAG, operation::AppliedNodeFusion)

Revert the applied node fusion on the graph by reverting its stored diff with
[`revert_diff!`](@ref). Return the original [`NodeFusion`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
    appliedDiff = operation.diff
    revert_diff!(graph, appliedDiff)

    originalOperation = operation.operation
    return originalOperation
end
|
|
|
|
|
"""
    revert_operation!(graph::DAG, operation::AppliedNodeReduction)

Revert the applied node reduction on the graph by reverting its stored diff with
[`revert_diff!`](@ref). Return the original [`NodeReduction`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
    appliedDiff = operation.diff
    revert_diff!(graph, appliedDiff)

    originalOperation = operation.operation
    return originalOperation
end
|
|
|
|
|
"""
    revert_operation!(graph::DAG, operation::AppliedNodeSplit)

Revert the applied node split on the graph by reverting its stored diff with
[`revert_diff!`](@ref). Return the original [`NodeSplit`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
    appliedDiff = operation.diff
    revert_diff!(graph, appliedDiff)

    originalOperation = operation.operation
    return originalOperation
end
|
|
|
|
|
"""
    revert_diff!(graph::DAG, diff::Diff)

Revert the given diff on the graph. Used to revert the individual [`AppliedOperation`](@ref)s with [`revert_operation!`](@ref).

All graph changes are made with `track = false`, in this order:
1. remove the edges the diff added,
2. remove the nodes the diff added,
3. re-insert the nodes the diff removed,
4. re-insert the edges the diff removed,
5. restore the tasks saved in `diff.updatedChildren`, newest entry first.

Finally `GraphProperties(diff)` is subtracted from `graph.properties`. Returns `nothing`.
"""
function revert_diff!(graph::DAG, diff::Diff)
    # add removed nodes, remove added nodes, same for edges
    # note the order: added edges go before added nodes, and removed nodes are re-inserted
    # before the removed edges that connect them
    for edge in diff.addedEdges
        remove_edge!(graph, edge.edge[1], edge.edge[2], track = false)
    end
    for node in diff.addedNodes
        remove_node!(graph, node, track = false)
    end

    for node in diff.removedNodes
        insert_node!(graph, node, track = false)
    end
    for edge in diff.removedEdges
        insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
    end

    # Undo the child-name updates in reverse (LIFO) order. If the same node was updated more than
    # once within this diff, the earliest saved task must be assigned last. Otherwise the node
    # would be left with an intermediate task instead of its original one.
    # (assumes each entry stores the node's task from *before* that update — TODO confirm in update_child!)
    for (node, task) in Iterators.reverse(diff.updatedChildren)
        # node must be fused compute task at this point
        @assert node.task isa FusedComputeTask

        node.task = task
    end

    graph.properties -= GraphProperties(diff)

    return nothing
end
|
|
|
|
|
"""
    node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph.

For details see [`NodeFusion`](@ref).

The three nodes are removed and replaced by a single `ComputeTaskNode` holding a
`FusedComputeTask`. That task is built from copies of `n1`'s and `n3`'s tasks, the variable names
of their inputs and the variable name of `n2`. The new node takes over `n1`'s children, `n3`'s
remaining children (without duplicating an edge that already comes from `n1`'s children) and
`n3`'s parents. Each of `n3`'s parents gets `n3`'s name replaced by the new node's name via
`update_child!`.
"""
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    @assert is_valid_node_fusion_input(graph, n1, n2, n3)

    # clear snapshot (the returned diff is intentionally discarded)
    get_snapshot_diff(graph)

    # record neighbourhoods and tasks before the graph is modified
    firstChildren = children(n1)
    lastParents = parents(n3)
    firstTask = copy(n1.task)
    lastTask = copy(n3.task)

    # input variable names of n1, stored in the fused task
    firstInputs = Symbol[Symbol(to_var_name(child.id)) for child in firstChildren]

    # take out the n1 -> n2 -> n3 chain, except for n3 itself
    remove_edge!(graph, n1, n2)
    remove_edge!(graph, n2, n3)
    remove_node!(graph, n1)
    remove_node!(graph, n2)

    # queried only now, so n2 is no longer among n3's children
    lastChildren = children(n3)
    lastInputs = Symbol[Symbol(to_var_name(child.id)) for child in lastChildren]

    remove_node!(graph, n3)

    # the single node replacing the chain
    fusedTask = FusedComputeTask(firstTask, lastTask, firstInputs, Symbol(to_var_name(n2.id)), lastInputs)
    fusedNode = ComputeTaskNode(fusedTask)
    insert_node!(graph, fusedNode)

    # move n1's child edges over to the fused node
    for child in firstChildren
        remove_edge!(graph, child, n1)
        insert_edge!(graph, child, fusedNode)
    end

    # move n3's child edges over, skipping children already connected through n1
    for child in lastChildren
        remove_edge!(graph, child, n3)
        if child ∉ firstChildren
            insert_edge!(graph, child, fusedNode)
        end
    end

    # move n3's parent edges over
    oldName = Symbol(to_var_name(n3.id))
    fusedName = Symbol(to_var_name(fusedNode.id))
    for parent in lastParents
        remove_edge!(graph, n3, parent)
        insert_edge!(graph, fusedNode, parent)

        # important! update the parent node's child names in case they are fused compute tasks
        # needed for compute generation so the fused compute task can correctly match inputs to its component tasks
        update_child!(graph, parent, oldName, fusedName)
    end

    return get_snapshot_diff(graph)
end
|
|
|
|
|
"""
    node_reduction!(graph::DAG, nodes::Vector{Node})

Reduce the given nodes together into one node, return the applied difference to the graph.

For details see [`NodeReduction`](@ref).

The first node of `nodes` is kept. Every other node is disconnected from the first node's
children and from its own parents, and is then removed. The parents of the removed nodes become
parents of the first node; no edge is inserted if the parent already was a parent of the first
node. Each of these parents gets the removed child's name replaced by the first node's name via
`update_child!`.
"""
function node_reduction!(graph::DAG, nodes::Vector{Node})
    @assert is_valid_node_reduction_input(graph, nodes)

    # clear snapshot
    get_snapshot_diff(graph)

    n1 = nodes[1]
    n1Children = children(n1)

    # Use the same accessor as for the other nodes' parents below, so the membership test
    # `parent in n1Parents` compares values of the same representation.
    n1Parents = Set(parents(n1))

    # set of the new parents of n1
    newParents = Set{Node}()

    # names of the previous children that n1 now replaces per parent
    # NOTE(review): if one parent has several of the reduced nodes as children, only the last
    # name is kept here and renamed below. Confirm that is_valid_node_reduction_input rules this out.
    newParentsChildNames = Dict{Node, Symbol}()

    # remove all of the nodes' parents and children and the nodes themselves (except for first node)
    for n in @view nodes[2:end]
        # the reduced nodes are expected to share n1's children (TODO confirm in is_valid_node_reduction_input)
        for child in n1Children
            remove_edge!(graph, child, n)
        end

        for parent in parents(n)
            remove_edge!(graph, n, parent)

            # collect all parents
            push!(newParents, parent)
            newParentsChildNames[parent] = Symbol(to_var_name(n.id))
        end

        remove_node!(graph, n)
    end

    n1Name = Symbol(to_var_name(n1.id))
    for parent in newParents
        # now add parents of all input nodes to n1 without duplicates
        if !(parent in n1Parents)
            # don't double insert edges
            insert_edge!(graph, n1, parent)
        end

        # this has to be done for all parents, even the ones of n1 because they can be duplicate
        update_child!(graph, parent, newParentsChildNames[parent], n1Name)
    end

    return get_snapshot_diff(graph)
end
|
|
|
|
|
"""
    node_split!(graph::DAG, n1::Node)

Split the given node into one node per parent, return the applied difference to the graph.

For details see [`NodeSplit`](@ref).

`n1` is disconnected from its parents and children and removed. Then, for every former parent,
a `copy` of `n1` is inserted and connected to that parent and to all of `n1`'s former children.
The parent's reference to `n1`'s name is replaced by the copy's name via `update_child!`.
"""
function node_split!(graph::DAG, n1::Node)
    @assert is_valid_node_split_input(graph, n1)

    # clear snapshot (the returned diff is intentionally discarded)
    get_snapshot_diff(graph)

    # remember the neighbourhood before detaching n1
    formerParents = parents(n1)
    formerChildren = children(n1)

    foreach(p -> remove_edge!(graph, n1, p), formerParents)
    foreach(c -> remove_edge!(graph, c, n1), formerChildren)
    remove_node!(graph, n1)

    splitName = Symbol(to_var_name(n1.id))
    for parent in formerParents
        # one copy of n1 per parent, wired to all of n1's former children
        replica = copy(n1)
        insert_node!(graph, replica)
        insert_edge!(graph, replica, parent)

        for child in formerChildren
            insert_edge!(graph, child, replica)
        end

        update_child!(graph, parent, splitName, Symbol(to_var_name(replica.id)))
    end

    return get_snapshot_diff(graph)
end
|