diff --git a/src/operation/apply.jl b/src/operation/apply.jl index f61953a..71bf183 100644 --- a/src/operation/apply.jl +++ b/src/operation/apply.jl @@ -1,6 +1,8 @@ -# functions that apply graph operations +""" + apply_all!(graph::DAG) -# applies all unapplied operations in the DAG +Apply all unapplied operations in the DAG. Is automatically called in all functions that require the latest state of the [`DAG`](@ref). +""" function apply_all!(graph::DAG) while !isempty(graph.operationsToApply) # get next operation to apply from front of the deque @@ -15,10 +17,22 @@ function apply_all!(graph::DAG) return nothing end +""" + apply_operation!(graph::DAG, operation::Operation) + +Fallback implementation of apply_operation! for unimplemented operation types, throwing an error. +""" function apply_operation!(graph::DAG, operation::Operation) return error("Unknown operation type!") end +""" + apply_operation!(graph::DAG, operation::NodeFusion) + +Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node_fusion!`](@ref). + +Return an [`AppliedNodeFusion`](@ref) object generated from the graph's [`Diff`](@ref). +""" function apply_operation!(graph::DAG, operation::NodeFusion) diff = node_fusion!( graph, @@ -29,37 +43,74 @@ function apply_operation!(graph::DAG, operation::NodeFusion) return AppliedNodeFusion(operation, diff) end +""" + apply_operation!(graph::DAG, operation::NodeReduction) + +Apply the given [`NodeReduction`](@ref) to the graph. Generic wrapper around [`node_reduction!`](@ref). + +Return an [`AppliedNodeReduction`](@ref) object generated from the graph's [`Diff`](@ref). +""" function apply_operation!(graph::DAG, operation::NodeReduction) diff = node_reduction!(graph, operation.input) return AppliedNodeReduction(operation, diff) end +""" + apply_operation!(graph::DAG, operation::NodeSplit) + +Apply the given [`NodeSplit`](@ref) to the graph. Generic wrapper around [`node_split!`](@ref). 
+ +Return an [`AppliedNodeSplit`](@ref) object generated from the graph's [`Diff`](@ref). +""" function apply_operation!(graph::DAG, operation::NodeSplit) diff = node_split!(graph, operation.input) return AppliedNodeSplit(operation, diff) end +""" + revert_operation!(graph::DAG, operation::AppliedOperation) +Fallback implementation of operation reversion for unimplemented operation types, throwing an error. +""" function revert_operation!(graph::DAG, operation::AppliedOperation) return error("Unknown operation type!") end +""" + revert_operation!(graph::DAG, operation::AppliedNodeFusion) + +Revert the applied node fusion on the graph. Return the original [`NodeFusion`](@ref) operation. +""" function revert_operation!(graph::DAG, operation::AppliedNodeFusion) revert_diff!(graph, operation.diff) return operation.operation end +""" + revert_operation!(graph::DAG, operation::AppliedNodeReduction) + +Revert the applied node reduction on the graph. Return the original [`NodeReduction`](@ref) operation. +""" function revert_operation!(graph::DAG, operation::AppliedNodeReduction) revert_diff!(graph, operation.diff) return operation.operation end +""" + revert_operation!(graph::DAG, operation::AppliedNodeSplit) + +Revert the applied node split on the graph. Return the original [`NodeSplit`](@ref) operation. +""" function revert_operation!(graph::DAG, operation::AppliedNodeSplit) revert_diff!(graph, operation.diff) return operation.operation end +""" + revert_diff!(graph::DAG, diff::Diff) +Revert the given diff on the graph. Used to revert the individual [`AppliedOperation`](@ref)s with [`revert_operation!`](@ref). 
+""" function revert_diff!(graph::DAG, diff::Diff) # add removed nodes, remove added nodes, same for edges # note the order @@ -76,9 +127,16 @@ function revert_diff!(graph::DAG, diff::Diff) for edge in diff.removedEdges insert_edge!(graph, edge.edge[1], edge.edge[2], false) end + return nothing end -# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph +""" + node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) + +Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph. + +For details see [`NodeFusion`](@ref). +""" function node_fusion!( graph::DAG, n1::ComputeTaskNode, @@ -139,6 +197,13 @@ function node_fusion!( return get_snapshot_diff(graph) end +""" + node_reduction!(graph::DAG, nodes::Vector{Node}) + +Reduce the given nodes together into one node, return the applied difference to the graph. + +For details see [`NodeReduction`](@ref). +""" function node_reduction!(graph::DAG, nodes::Vector{Node}) # @assert is_valid_node_reduction_input(graph, nodes) @@ -178,6 +243,13 @@ function node_reduction!(graph::DAG, nodes::Vector{Node}) return get_snapshot_diff(graph) end +""" + node_split!(graph::DAG, n1::Node) + +Split the given node into one node per parent, return the applied difference to the graph. + +For details see [`NodeSplit`](@ref). +""" function node_split!(graph::DAG, n1::Node) # @assert is_valid_node_split_input(graph, n1) diff --git a/src/operation/clean.jl b/src/operation/clean.jl index 142e21a..e42a2a3 100644 --- a/src/operation/clean.jl +++ b/src/operation/clean.jl @@ -1,9 +1,14 @@ -# functions for "cleaning" nodes, i.e. regenerating the possible operations for a node +# These are functions for "cleaning" nodes, i.e. 
regenerating the possible operations for a node -# function to find node fusions involving the given node if it's a data node -# pushes the found fusion everywhere it needs to be and returns nothing +""" + find_fusions!(graph::DAG, node::DataTaskNode) + +Find node fusions involving the given data node. The function pushes the found [`NodeFusion`](@ref) (if any) everywhere it needs to be and returns nothing. + +Does nothing if the node already has a node fusion set. Since it's a data node, only one node fusion can be possible with it. +""" function find_fusions!(graph::DAG, node::DataTaskNode) - # if there is already a fusion here, skip + # if there is already a fusion here, skip to avoid duplicates if !ismissing(node.nodeFusion) return nothing end @@ -32,7 +37,11 @@ function find_fusions!(graph::DAG, node::DataTaskNode) return nothing end +""" + find_fusions!(graph::DAG, node::ComputeTaskNode) +Find node fusions involving the given compute node. The function pushes the found [`NodeFusion`](@ref)s (if any) everywhere they need to be and returns nothing. +""" function find_fusions!(graph::DAG, node::ComputeTaskNode) # just find fusions in neighbouring DataTaskNodes for child in node.children @@ -46,6 +55,11 @@ function find_fusions!(graph::DAG, node::ComputeTaskNode) return nothing end +""" + find_reductions!(graph::DAG, node::Node) + +Find node reductions involving the given node. The function pushes the found [`NodeReduction`](@ref) (if any) everywhere it needs to be and returns nothing. +""" function find_reductions!(graph::DAG, node::Node) # there can only be one reduction per node, avoid adding duplicates if !ismissing(node.nodeReduction) @@ -91,6 +105,11 @@ function find_reductions!(graph::DAG, node::Node) return nothing end +""" + find_splits!(graph::DAG, node::Node) + +Find the node split of the given node. The function pushes the found [`NodeSplit`](@ref) (if any) everywhere it needs to be and returns nothing. 
+""" function find_splits!(graph::DAG, node::Node) if !ismissing(node.nodeSplit) return nothing @@ -105,11 +124,17 @@ function find_splits!(graph::DAG, node::Node) return nothing end -# "clean" the operations on a dirty node +""" + clean_node!(graph::DAG, node::Node) + +Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way. +""" function clean_node!(graph::DAG, node::Node) sort_node!(node) find_fusions!(graph, node) find_reductions!(graph, node) - return find_splits!(graph, node) + find_splits!(graph, node) + + return nothing end diff --git a/src/operation/find.jl b/src/operation/find.jl index e851e21..89acc3a 100644 --- a/src/operation/find.jl +++ b/src/operation/find.jl @@ -2,6 +2,11 @@ using Base.Threads +""" + insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock}) + +Insert the given node fusion into its input nodes' operation caches. For the compute nodes, locking via the given `locks` is employed to have safe multi-threading. For a large set of nodes, contention on the locks should be very small. +""" function insert_operation!( nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock}, @@ -20,6 +25,11 @@ function insert_operation!( return nothing end +""" + insert_operation!(nr::NodeReduction) + +Insert the given node reduction into its input nodes' operation caches. This is thread-safe. +""" function insert_operation!(nr::NodeReduction) for n in nr.input n.nodeReduction = nr @@ -27,11 +37,21 @@ function insert_operation!(nr::NodeReduction) return nothing end +""" + insert_operation!(ns::NodeSplit) + +Insert the given node split into its input node's operation cache. This is thread-safe. +""" function insert_operation!(ns::NodeSplit) ns.input.nodeSplit = ns return nothing end +""" + nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}) + +Insert the node reductions into the graph and the nodes' caches. 
Employs multithreading for speedup. +""" function nr_insertion!( operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, @@ -58,6 +78,11 @@ function nr_insertion!( return nothing end +""" + nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}) + +Insert the node fusions into the graph and the nodes' caches. Employs multithreading for speedup. +""" function nf_insertion!( graph::DAG, operations::PossibleOperations, @@ -92,6 +117,11 @@ function nf_insertion!( return nothing end +""" + ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}) + +Insert the node splits into the graph and the nodes' caches. Employs multithreading for speedup. +""" function ns_insertion!( operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, @@ -118,8 +148,14 @@ function ns_insertion!( return nothing end -# function to generate all possible operations on the graph -function generate_options(graph::DAG) +""" + generate_operations(graph::DAG) + +Generate all possible operations on the graph. Used initially when the graph is freshly assembled or parsed. Uses multithreading for speedup. + +Safely inserts all the found operations into the graph and its nodes. 
+""" +function generate_operations(graph::DAG) generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()] generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()] generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()] diff --git a/src/operation/get.jl b/src/operation/get.jl index b527fb1..764bb7f 100644 --- a/src/operation/get.jl +++ b/src/operation/get.jl @@ -6,7 +6,7 @@ function get_operations(graph::DAG) apply_all!(graph) if isempty(graph.possibleOperations) - generate_options(graph) + generate_operations(graph) end for node in graph.dirtyNodes diff --git a/src/operation/print.jl b/src/operation/print.jl index 39f5e8c..61239be 100644 --- a/src/operation/print.jl +++ b/src/operation/print.jl @@ -1,3 +1,8 @@ +""" + show(io::IO, ops::PossibleOperations) + +Print a string representation of the set of possible operations to io. +""" function show(io::IO, ops::PossibleOperations) print(io, length(ops.nodeFusions)) println(io, " Node Fusions: ") @@ -16,6 +21,11 @@ function show(io::IO, ops::PossibleOperations) end end +""" + show(io::IO, op::NodeReduction) + +Print a string representation of the node reduction to io. +""" function show(io::IO, op::NodeReduction) print(io, "NR: ") print(io, length(op.input)) @@ -23,11 +33,21 @@ function show(io::IO, op::NodeReduction) return print(io, op.input[1].task) end +""" + show(io::IO, op::NodeSplit) + +Print a string representation of the node split to io. +""" function show(io::IO, op::NodeSplit) print(io, "NS: ") return print(io, op.input.task) end +""" + show(io::IO, op::NodeFusion) + +Print a string representation of the node fusion to io. 
+""" function show(io::IO, op::NodeFusion) print(io, "NF: ") print(io, op.input[1].task) diff --git a/src/operation/type.jl b/src/operation/type.jl index c9c68b7..c609b3b 100644 --- a/src/operation/type.jl +++ b/src/operation/type.jl @@ -1,33 +1,116 @@ -# An abstract base class for operations -# an operation can be applied to a DAG +""" + Operation + +An abstract base class for operations. An operation can be applied to a [`DAG`](@ref), changing its nodes and edges. + +Possible operations on a [`DAG`](@ref) can be retrieved using [`get_operations`](@ref). + +See also: [`push_operation`](@ref), [`pop_operation`](@ref) +""" abstract type Operation end -# An abstract base class for already applied operations -# an applied operation can be reversed iff it is the last applied operation on the DAG +""" + AppliedOperation + +An abstract base class for already applied operations. +An applied operation can be reversed iff it is the last applied operation on the DAG. +Every applied operation stores a [`Diff`](@ref) from when it was initially applied to be able to revert the operation. + +See also: [`revert_operation`](@ref). +""" abstract type AppliedOperation end +""" + NodeFusion <: Operation + +The NodeFusion operation. Represents the fusing of a chain of compute node -> data node -> compute node. + +After the node fusion is applied, the graph has 2 fewer nodes and edges, and a new [`FusedComputeTask`](@ref) with the two input compute nodes as parts. + +# Requirements for successful application + +A chain of (n1, n2, n3) can be fused if: +- All nodes are in the graph. +- (n1, n2) is an edge in the graph. +- (n2, n3) is an edge in the graph. +- n2 has exactly one parent (n3) and exactly one child (n1). +- n1 has exactly one parent (n2). + +[`is_valid_node_fusion_input`](@ref) can be used to `@assert` these requirements. 
+ +See also: [`can_fuse`](@ref) +""" struct NodeFusion <: Operation input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode} end +""" + AppliedNodeFusion <: AppliedOperation + +The applied version of the [`NodeFusion`](@ref). +""" struct AppliedNodeFusion <: AppliedOperation operation::NodeFusion diff::Diff end +""" + NodeReduction <: Operation + +The NodeReduction operation. Represents the reduction of two or more nodes with one another. +Only one of the input nodes is kept, while all others are deleted and their parents are accumulated in the kept node's parents instead. + +After the node reduction is applied, the graph has `length(nr.input) - 1` fewer nodes. + +# Requirements for successful application + +A vector of nodes can be reduced if: +- All nodes are in the graph. +- All nodes have the same task type. +- All nodes have the same set of children. + +[`is_valid_node_reduction_input`](@ref) can be used to `@assert` these requirements. + +See also: [`can_reduce`](@ref) +""" struct NodeReduction <: Operation input::Vector{Node} end +""" + AppliedNodeReduction <: AppliedOperation + +The applied version of the [`NodeReduction`](@ref). +""" struct AppliedNodeReduction <: AppliedOperation operation::NodeReduction diff::Diff end +""" + NodeSplit <: Operation + +The NodeSplit operation. Represents the split of its input node into one node for each of its parents. It is the reverse operation to the [`NodeReduction`](@ref). + +# Requirements for successful application + +A node can be split if: +- It is in the graph. +- It has at least 2 parents. + +[`is_valid_node_split_input`](@ref) can be used to `@assert` these requirements. + +See also: [`can_split`](@ref) +""" struct NodeSplit <: Operation input::Node end +""" + AppliedNodeSplit <: AppliedOperation + +The applied version of the [`NodeSplit`](@ref). 
+""" struct AppliedNodeSplit <: AppliedOperation operation::NodeSplit diff::Diff diff --git a/src/operation/utility.jl b/src/operation/utility.jl index 97297b0..2c1bae5 100644 --- a/src/operation/utility.jl +++ b/src/operation/utility.jl @@ -1,10 +1,19 @@ +""" + isempty(operations::PossibleOperations) +Return whether `operations` is empty, i.e. all of its fields are empty. +""" function isempty(operations::PossibleOperations) return isempty(operations.nodeFusions) && isempty(operations.nodeReductions) && isempty(operations.nodeSplits) end +""" + length(operations::PossibleOperations) + +Return a named tuple with the number of each of the operation types as a named tuple. The fields are named the same as the [`PossibleOperations`](@ref)'. +""" function length(operations::PossibleOperations) return ( nodeFusions = length(operations.nodeFusions), @@ -13,22 +22,41 @@ function length(operations::PossibleOperations) ) end +""" + delete!(operations::PossibleOperations, op::NodeFusion) + +Delete the given node fusion from the possible operations. +""" function delete!(operations::PossibleOperations, op::NodeFusion) delete!(operations.nodeFusions, op) return operations end +""" + delete!(operations::PossibleOperations, op::NodeReduction) + +Delete the given node reduction from the possible operations. +""" function delete!(operations::PossibleOperations, op::NodeReduction) delete!(operations.nodeReductions, op) return operations end +""" + delete!(operations::PossibleOperations, op::NodeSplit) + +Delete the given node split from the possible operations. +""" function delete!(operations::PossibleOperations, op::NodeSplit) delete!(operations.nodeSplits, op) return operations end +""" + can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) +Return whether the given nodes can be fused. See [`NodeFusion`](@ref) for the requirements. 
+""" function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) if !is_child(n1, n2) || !is_child(n2, n3) # the checks are redundant but maybe a good sanity check @@ -44,6 +72,11 @@ function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) return true end +""" + can_reduce(n1::Node, n2::Node) + +Return whether the given two nodes can be reduced. See [`NodeReduction`](@ref) for the requirements. +""" function can_reduce(n1::Node, n2::Node) if (n1.task != n2.task) return false @@ -86,26 +119,49 @@ function can_reduce(n1::Node, n2::Node) return Set(n1.children) == Set(n2.children) end +""" + can_split(n::Node) + +Return whether the given node can be split. See [`NodeSplit`](@ref) for the requirements. +""" function can_split(n::Node) return length(parents(n)) > 1 end +""" + ==(op1::Operation, op2::Operation) + +Fallback implementation of operation equality. Return false. Actual comparisons are done by the overloads of same type operation comparisons. +""" function ==(op1::Operation, op2::Operation) return false end +""" + ==(op1::NodeFusion, op2::NodeFusion) + +Equality comparison between two node fusions. Two node fusions are considered equal if they have the same data node, i.e. the same second input. +""" function ==(op1::NodeFusion, op2::NodeFusion) # there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same return op1.input[2] == op2.input[2] end +""" + ==(op1::NodeReduction, op2::NodeReduction) + +Equality comparison between two node reductions. Two node reductions are considered equal if their first input nodes have the same id. +""" function ==(op1::NodeReduction, op2::NodeReduction) # node reductions are equal exactly if their first input is the same return op1.input[1].id == op2.input[1].id end +""" + ==(op1::NodeSplit, op2::NodeSplit) + +Equality comparison between two node splits. Two node splits are considered equal if they have the same input node. 
+""" function ==(op1::NodeSplit, op2::NodeSplit) return op1.input == op2.input end - -copy(id::UUID) = UUID(id.value) diff --git a/src/operation/validate.jl b/src/operation/validate.jl index 5b8dca5..5d41e87 100644 --- a/src/operation/validate.jl +++ b/src/operation/validate.jl @@ -2,6 +2,13 @@ # should be called with @assert # the functions throw their own errors though, to still have helpful error messages +""" + is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) + +Assert for a given node fusion input whether the nodes can be fused. For the requirements of a node fusion see [`NodeFusion`](@ref). + +Intended for use with `@assert` or `@test`. +""" function is_valid_node_fusion_input( graph::DAG, n1::ComputeTaskNode, @@ -52,6 +59,13 @@ function is_valid_node_fusion_input( return true end +""" + is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node}) + +Assert for a given node reduction input whether the nodes can be reduced. For the requirements of a node reduction see [`NodeReduction`](@ref). + +Intended for use with `@assert` or `@test`. +""" function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node}) for n in nodes if n ∉ graph @@ -88,6 +102,13 @@ function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node}) return true end +""" + is_valid_node_split_input(graph::DAG, n1::Node) + +Assert for a given node split input whether the node can be split. For the requirements of a node split see [`NodeSplit`](@ref). + +Intended for use with `@assert` or `@test`. +""" function is_valid_node_split_input(graph::DAG, n1::Node) if n1 ∉ graph throw( @@ -108,18 +129,39 @@ function is_valid_node_split_input(graph::DAG, n1::Node) return true end +""" + is_valid(graph::DAG, nr::NodeReduction) + +Assert for a given [`NodeReduction`](@ref) whether it is a valid operation in the graph. + +Intended for use with `@assert` or `@test`. 
+""" function is_valid(graph::DAG, nr::NodeReduction) @assert is_valid_node_reduction_input(graph, nr.input) @assert nr in graph.possibleOperations.nodeReductions "NodeReduction is not part of the graph's possible operations!" return true end +""" + is_valid(graph::DAG, ns::NodeSplit) + +Assert for a given [`NodeSplit`](@ref) whether it is a valid operation in the graph. + +Intended for use with `@assert` or `@test`. +""" function is_valid(graph::DAG, ns::NodeSplit) @assert is_valid_node_split_input(graph, ns.input) @assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!" return true end +""" + is_valid(graph::DAG, nf::NodeFusion) + +Assert for a given [`NodeFusion`](@ref) whether it is a valid operation in the graph. + +Intended for use with `@assert` or `@test`. +""" function is_valid(graph::DAG, nf::NodeFusion) @assert is_valid_node_fusion_input( graph,