"""
    apply_all!(graph::DAG)

Apply all unapplied operations in the DAG. Is automatically called in all functions that
require the latest state of the [`DAG`](@ref).

Pending operations are taken one at a time from the front of `graph.operationsToApply`,
applied with [`apply_operation!`](@ref), and the resulting applied operation is pushed to
the end of `graph.appliedOperations`. Returns `nothing`.
"""
function apply_all!(graph::DAG)
    while !isempty(graph.operationsToApply)
        # get next operation to apply from front of the deque
        op = popfirst!(graph.operationsToApply)

        # apply it
        appliedOp = apply_operation!(graph, op)

        # push to the end of the appliedOperations deque
        push!(graph.appliedOperations, appliedOp)
    end
    return nothing
end

"""
    apply_operation!(graph::DAG, operation::Operation)

Fallback implementation of `apply_operation!` for unimplemented operation types, throwing
an error.
"""
function apply_operation!(graph::DAG, operation::Operation)
    return error("Unknown operation type!")
end

"""
    apply_operation!(graph::DAG, operation::NodeFusion)

Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around
[`node_fusion!`](@ref).

The graph's `properties` are updated with the [`GraphProperties`] of the resulting diff.
Return an [`AppliedNodeFusion`](@ref) object generated from the graph's [`Diff`](@ref).
"""
function apply_operation!(graph::DAG, operation::NodeFusion)
    diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
    graph.properties += GraphProperties(diff)
    return AppliedNodeFusion(operation, diff)
end

"""
    apply_operation!(graph::DAG, operation::NodeReduction)

Apply the given [`NodeReduction`](@ref) to the graph. Generic wrapper around
[`node_reduction!`](@ref).

The graph's `properties` are updated with the `GraphProperties` of the resulting diff.
Return an [`AppliedNodeReduction`](@ref) object generated from the graph's [`Diff`](@ref).
"""
function apply_operation!(graph::DAG, operation::NodeReduction)
    diff = node_reduction!(graph, operation.input)
    graph.properties += GraphProperties(diff)
    return AppliedNodeReduction(operation, diff)
end

"""
    apply_operation!(graph::DAG, operation::NodeSplit)

Apply the given [`NodeSplit`](@ref) to the graph. Generic wrapper around
[`node_split!`](@ref).

The graph's `properties` are updated with the `GraphProperties` of the resulting diff.
Return an [`AppliedNodeSplit`](@ref) object generated from the graph's [`Diff`](@ref).
"""
function apply_operation!(graph::DAG, operation::NodeSplit)
    diff = node_split!(graph, operation.input)
    graph.properties += GraphProperties(diff)
    return AppliedNodeSplit(operation, diff)
end

"""
    revert_operation!(graph::DAG, operation::AppliedOperation)

Fallback implementation of operation reversion for unimplemented operation types, throwing
an error.
"""
function revert_operation!(graph::DAG, operation::AppliedOperation)
    return error("Unknown operation type!")
end

"""
    revert_operation!(graph::DAG, operation::AppliedNodeFusion)

Revert the applied node fusion on the graph by reverting its recorded diff with
[`revert_diff!`](@ref). Return the original [`NodeFusion`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
    revert_diff!(graph, operation.diff)
    return operation.operation
end

"""
    revert_operation!(graph::DAG, operation::AppliedNodeReduction)

Revert the applied node reduction on the graph by reverting its recorded diff with
[`revert_diff!`](@ref). Return the original [`NodeReduction`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
    revert_diff!(graph, operation.diff)
    return operation.operation
end

"""
    revert_operation!(graph::DAG, operation::AppliedNodeSplit)

Revert the applied node split on the graph by reverting its recorded diff with
[`revert_diff!`](@ref). Return the original [`NodeSplit`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
    revert_diff!(graph, operation.diff)
    return operation.operation
end

"""
    revert_diff!(graph::DAG, diff::Diff)

Revert the given diff on the graph. Used to revert the individual
[`AppliedOperation`](@ref)s with [`revert_operation!`](@ref).

Added edges and nodes are removed, removed nodes and edges are re-inserted, and the tasks
recorded in `diff.updatedChildren` are restored. All graph modifications are made with
`track = false` (presumably so the reversion itself is not recorded as a new diff).
Finally, the `GraphProperties` of the diff are subtracted from `graph.properties`.
"""
function revert_diff!(graph::DAG, diff::Diff)
    # add removed nodes, remove added nodes, same for edges
    # note the order: edges are removed before nodes and nodes are re-inserted before
    # edges (presumably because an edge must refer to nodes present in the graph)
    for edge in diff.addedEdges
        remove_edge!(graph, edge.edge[1], edge.edge[2], track = false)
    end
    for node in diff.addedNodes
        remove_node!(graph, node, track = false)
    end

    for node in diff.removedNodes
        insert_node!(graph, node, track = false)
    end
    for edge in diff.removedEdges
        insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
    end

    # restore the tasks whose child names were changed via update_child!
    for (node, task) in diff.updatedChildren
        # node must be fused compute task at this point
        @assert typeof(node.task) <: FusedComputeTask

        node.task = task
    end

    graph.properties -= GraphProperties(diff)

    return nothing
end

"""
    node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the
graph.

The new node carries a `FusedComputeTask` built from copies of `n1`'s and `n3`'s tasks,
the variable names of their inputs, and the variable name of `n2`. The children of `n1` and
`n3` (excluding `n2`) and the parents of `n3` are rewired to the new node.

For details see [`NodeFusion`](@ref).
"""
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    @assert is_valid_node_fusion_input(graph, n1, n2, n3)

    # clear snapshot
    get_snapshot_diff(graph)

    # save children and parents
    n1Children = children(n1)
    n3Parents = parents(n3)

    n1Task = copy(n1.task)
    n3Task = copy(n3.task)

    # assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
    # (typed comprehension guarantees a Vector{Symbol}, even when there are no children)
    n1Inputs = Symbol[Symbol(to_var_name(child.id)) for child in n1Children]

    # remove the edges and nodes that will be replaced by the fused node
    remove_edge!(graph, n1, n2)
    remove_edge!(graph, n2, n3)
    remove_node!(graph, n1)
    remove_node!(graph, n2)

    # get n3's children now so it automatically excludes n2
    n3Children = children(n3)
    n3Inputs = Symbol[Symbol(to_var_name(child.id)) for child in n3Children]

    remove_node!(graph, n3)

    # create new node with the fused compute task
    fusedTask = FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs)
    newNode = ComputeTaskNode(fusedTask)
    insert_node!(graph, newNode)

    # rewire n1's children to the new node
    # (n1 is already removed from the graph here; the edge removal is still performed,
    # presumably so it is recorded in the snapshot diff)
    for child in n1Children
        remove_edge!(graph, child, n1)
        insert_edge!(graph, child, newNode)
    end

    # rewire n3's children; children shared with n1 were already connected above
    for child in n3Children
        remove_edge!(graph, child, n3)
        if !(child in n1Children)
            insert_edge!(graph, child, newNode)
        end
    end

    # the old and new variable names do not change inside the loop; compute them once
    n3Name = Symbol(to_var_name(n3.id))
    newNodeName = Symbol(to_var_name(newNode.id))
    for parent in n3Parents
        remove_edge!(graph, n3, parent)
        insert_edge!(graph, newNode, parent)

        # important! update the parent node's child names in case they are fused compute tasks
        # needed for compute generation so the fused compute task can correctly match inputs to its component tasks
        update_child!(graph, parent, n3Name, newNodeName)
    end

    return get_snapshot_diff(graph)
end

"""
    node_reduction!(graph::DAG, nodes::Vector{Node})

Reduce the given nodes together into one node, return the applied difference to the graph.

The first node of `nodes` is kept; every other node is removed together with its edges, and
its parents are connected to the first node instead (without duplicating existing edges).

For details see [`NodeReduction`](@ref).
"""
function node_reduction!(graph::DAG, nodes::Vector{Node})
    @assert is_valid_node_reduction_input(graph, nodes)

    # clear snapshot
    get_snapshot_diff(graph)

    n1 = nodes[1]
    n1Children = children(n1)

    n1Parents = Set(n1.parents)

    # set of the new parents of n1
    newParents = Set{Node}()

    # names of the previous children that n1 now replaces per parent
    newParentsChildNames = Dict{Node, Symbol}()

    # remove all of the nodes' parents and children and the nodes themselves (except for first node)
    for n in @view nodes[2:end]
        # n's children are removed via n1's children (presumably identical, as required
        # by is_valid_node_reduction_input)
        for child in n1Children
            remove_edge!(graph, child, n)
        end

        for parent in parents(n)
            remove_edge!(graph, n, parent)

            # collect all parents
            push!(newParents, parent)
            # NOTE(review): if one parent has several of the reduced nodes as children, only
            # the last name survives here, so update_child! below renames only that one
            # child; confirm this case cannot occur.
            newParentsChildNames[parent] = Symbol(to_var_name(n.id))
        end

        remove_node!(graph, n)
    end

    n1Name = Symbol(to_var_name(n1.id))
    for parent in newParents
        # now add parents of all input nodes to n1 without duplicates
        if !(parent in n1Parents)
            # don't double insert edges
            insert_edge!(graph, n1, parent)
        end

        # this has to be done for all parents, even the ones of n1 because they can be duplicate
        prevChild = newParentsChildNames[parent]
        update_child!(graph, parent, prevChild, n1Name)
    end

    return get_snapshot_diff(graph)
end

"""
    node_split!(graph::DAG, n1::Node)

Split the given node into one node per parent, return the applied difference to the graph.

`n1` is removed, and for each of its former parents a `copy` of `n1` is inserted, connected
to that parent and to all of `n1`'s children.

For details see [`NodeSplit`](@ref).
"""
function node_split!(graph::DAG, n1::Node)
    @assert is_valid_node_split_input(graph, n1)

    # clear snapshot
    get_snapshot_diff(graph)

    n1Parents = parents(n1)
    n1Children = children(n1)

    # detach and remove the original node
    for parent in n1Parents
        remove_edge!(graph, n1, parent)
    end
    for child in n1Children
        remove_edge!(graph, child, n1)
    end
    remove_node!(graph, n1)

    # the original node's variable name is loop-invariant; compute it once
    n1Name = Symbol(to_var_name(n1.id))
    for parent in n1Parents
        nCopy = copy(n1)

        insert_node!(graph, nCopy)
        insert_edge!(graph, nCopy, parent)

        for child in n1Children
            insert_edge!(graph, child, nCopy)
        end

        # let the parent refer to its own copy instead of the removed original
        update_child!(graph, parent, n1Name, Symbol(to_var_name(nCopy.id)))
    end

    return get_snapshot_diff(graph)
end