diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl index 61432d4..6d7553c 100644 --- a/src/MetagraphOptimization.jl +++ b/src/MetagraphOptimization.jl @@ -26,7 +26,15 @@ include("graph.jl") include("task_functions.jl") include("node_functions.jl") include("graph_functions.jl") -include("graph_operations.jl") + +include("operations/utility.jl") +include("operations/apply.jl") +include("operations/clean.jl") +include("operations/find.jl") +include("operations/get.jl") + +include("graph_interface.jl") + include("utility.jl") include("abc_model/tasks.jl") diff --git a/src/graph_functions.jl b/src/graph_functions.jl index 8ff437e..3d1f3c3 100644 --- a/src/graph_functions.jl +++ b/src/graph_functions.jl @@ -3,31 +3,6 @@ using DataStructures in(node::Node, graph::DAG) = node in graph.nodes in(edge::Edge, graph::DAG) = edge in graph.edges -function isempty(operations::PossibleOperations) - return isempty(operations.nodeFusions) && - isempty(operations.nodeReductions) && - isempty(operations.nodeSplits) -end - -function length(operations::PossibleOperations) - return (nodeFusions = length(operations.nodeFusions), - nodeReductions = length(operations.nodeReductions), - nodeSplits = length(operations.nodeSplits)) -end - -function delete!(operations::PossibleOperations, op::NodeFusion) - delete!(operations.nodeFusions, op) - return operations -end -function delete!(operations::PossibleOperations, op::NodeReduction) - delete!(operations.nodeReductions, op) - return operations -end -function delete!(operations::PossibleOperations, op::NodeSplit) - delete!(operations.nodeSplits, op) - return operations -end - function is_parent(potential_parent, node) return potential_parent in node.parents end @@ -57,13 +32,13 @@ function parents(node::Node) return copy(node.parents) end -# siblings = all children of any parents, no duplicates, does not include the node itself +# siblings = all children of any parents, no duplicates, includes the node itself function siblings(node::Node) result = Set{Node}() + push!(result, node) for parent in node.parents union!(result, parent.children) end - delete!(result, node) return result end @@ -71,10 +46,10 @@ end # partners = all parents of any children, no duplicates, includes the node itself function partners(node::Node) result = Set{Node}() + push!(result, node) for child in node.children union!(result, child.parents) end - delete!(result, node) return result end @@ -259,65 +234,6 @@ function get_exit_node(graph::DAG) error("The given graph has no exit node! It is either empty or not acyclic!") end -function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) - if !is_child(n1, n2) || !is_child(n2, n3) - # the checks are redundant but maybe a good sanity check - return false - end - - if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1 - return false - end - - return true -end - -function can_reduce(n1::Node, n2::Node) - if (n1.task != n2.task) - return false - end - - n1_length = length(n1.children) - n2_length = length(n2.children) - - if (n1_length != n2_length) - return false - end - - # this seems to be the most common case so do this first - # doing it manually is a lot faster than using the sets for a general solution - if (n1_length == 2) - if (n1.children[1] != n2.children[1]) - if (n1.children[1] != n2.children[2]) - return false - end - # 1_1 == 2_2 - if (n1.children[2] != n2.children[1]) - return false - end - return true - end - - # 1_1 == 2_1 - if (n1.children[2] != n2.children[2]) - return false - end - return true - end - - # this is simple - if (n1_length == 1) - return n1.children[1] == n2.children[1] - end - - # this takes a long time - return Set(n1.children) == Set(n2.children) -end - -function can_split(n::Node) - return length(parents(n)) > 1 -end - # check whether the given graph is connected function is_valid(graph::DAG) nodeQueue = Deque{Node}() @@ -408,25 +324,3 @@ function length(diff::Diff) removedEdges = length(diff.removedEdges) ) end - -function ==(op1::Operation, op2::Operation) - return false -end - -function ==(op1::NodeFusion, op2::NodeFusion) - # there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same - return op1.input[2] == op2.input[2] -end - -function ==(op1::NodeReduction, op2::NodeReduction) - # only test the ids against each other - return op1.id == op2.id -end - -function ==(op1::NodeSplit, op2::NodeSplit) - return op1.input == op2.input -end - -NodeReduction(input::Vector{Node}) = NodeReduction(input, UUIDs.uuid1(rng[threadid()])) - -copy(id::UUID) = UUID(id.value) diff --git a/src/graph_interface.jl b/src/graph_interface.jl new file mode 100644 index 0000000..7cc414f --- /dev/null +++ b/src/graph_interface.jl @@ -0,0 +1,34 @@ +# user interface on the DAG + +# applies a new operation to the end of the graph +function push_operation!(graph::DAG, operation::Operation) + # 1.: Add the operation to the DAG + push!(graph.operationsToApply, operation) + + return nothing +end + +# reverts the latest applied operation, essentially like a ctrl+z for +function pop_operation!(graph::DAG) + # 1.: Remove the operation from the appliedChain of the DAG + if !isempty(graph.operationsToApply) + pop!(graph.operationsToApply) + elseif !isempty(graph.appliedOperations) + appliedOp = pop!(graph.appliedOperations) + revert_operation!(graph, appliedOp) + else + error("No more operations to pop!") + end + return nothing +end + +can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations) + +# reset the graph to its initial state with no operations applied +function reset_graph!(graph::DAG) + while (can_pop(graph)) + pop_operation!(graph) + end + + return nothing +end diff --git a/src/graph_operations.jl b/src/graph_operations.jl deleted file mode 100644 index 3aa4466..0000000 --- a/src/graph_operations.jl +++ /dev/null @@ -1,515 +0,0 @@ -using Base.Threads - -# outside interface - -# applies a new operation to the end of the graph -function push_operation!(graph::DAG, operation::Operation) - # 1.: Add the operation to the DAG - push!(graph.operationsToApply, operation) - - return nothing -end - -# reverts the latest applied operation, essentially like a ctrl+z for -function pop_operation!(graph::DAG) - # 1.: Remove the operation from the appliedChain of the DAG - if !isempty(graph.operationsToApply) - pop!(graph.operationsToApply) - elseif !isempty(graph.appliedOperations) - appliedOp = pop!(graph.appliedOperations) - revert_operation!(graph, appliedOp) - else - error("No more operations to pop!") - end - return nothing -end - -can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations) - -# reset the graph to its initial state with no operations applied -function reset_graph!(graph::DAG) - while (can_pop(graph)) - pop_operation!(graph) - end - - return nothing -end - -# implementation detail functions, don't export - -# applies all unapplied operations in the DAG -function apply_all!(graph::DAG) - while !isempty(graph.operationsToApply) - # get next operation to apply from front of the deque - op = popfirst!(graph.operationsToApply) - - # apply it - appliedOp = apply_operation!(graph, op) - - # push to the end of the appliedOperations deque - push!(graph.appliedOperations, appliedOp) - end - return nothing -end - - -function apply_operation!(graph::DAG, operation::Operation) - error("Unknown operation type!") -end - -function apply_operation!(graph::DAG, operation::NodeFusion) - diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3]) - return AppliedNodeFusion(operation, diff) -end - -function apply_operation!(graph::DAG, operation::NodeReduction) - diff = node_reduction!(graph, operation.input[1], operation.input[2]) - return AppliedNodeReduction(operation, diff) -end - -function apply_operation!(graph::DAG, operation::NodeSplit) - diff = node_split!(graph, operation.input) - return AppliedNodeSplit(operation, diff) -end - - -function revert_operation!(graph::DAG, operation::AppliedOperation) - error("Unknown operation type!") -end - -function revert_operation!(graph::DAG, operation::AppliedNodeFusion) - revert_diff!(graph, operation.diff) - return operation.operation -end - -function revert_operation!(graph::DAG, operation::AppliedNodeReduction) - revert_diff!(graph, operation.diff) - return operation.operation -end - -function revert_operation!(graph::DAG, operation::AppliedNodeSplit) - revert_diff!(graph, operation.diff) - return operation.operation -end - - -function revert_diff!(graph::DAG, diff) - # add removed nodes, remove added nodes, same for edges - # note the order - for edge in diff.addedEdges - remove_edge!(graph, edge, false) - end - for node in diff.addedNodes - remove_node!(graph, node, false) - end - - for node in diff.removedNodes - insert_node!(graph, node, false) - end - for edge in diff.removedEdges - insert_edge!(graph, edge, false) - end -end - -# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph -function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) - # clear snapshot - get_snapshot_diff(graph) - - if !(n1 in graph) || !(n2 in graph) || !(n3 in graph) - error("[Node Fusion] The given nodes are not part of the given graph") - end - - if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1) - # the checks are redundant but maybe a good sanity check - error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion") - end - - # save children and parents - n1_children = children(n1) - n3_parents = parents(n3) - n3_children = children(n3) - - if length(n2.parents) > 1 - error("[Node Fusion] The given data node has more than one parent") - end - if length(n2.children) > 1 - error("[Node Fusion] The given data node has more than one child") - end - if length(n1.parents) > 1 - error("[Node Fusion] The given n1 has more than one parent") - end - - required_edge1 = make_edge(n1, n2) - required_edge2 = make_edge(n2, n3) - - # remove the edges and nodes that will be replaced by the fused node - remove_edge!(graph, required_edge1) - remove_edge!(graph, required_edge2) - remove_node!(graph, n1) - remove_node!(graph, n2) - - # get n3's children now so it automatically excludes n2 - n3_children = children(n3) - remove_node!(graph, n3) - - # create new node with the fused compute task - new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}()) - insert_node!(graph, new_node) - - # use a set for combined children of n1 and n3 to not get duplicates - n1and3_children = Set{Node}() - - # remove edges from n1 children to n1 - for child in n1_children - remove_edge!(graph, make_edge(child, n1)) - push!(n1and3_children, child) - end - - # remove edges from n3 children to n3 - for child in n3_children - remove_edge!(graph, make_edge(child, n3)) - push!(n1and3_children, child) - end - - for child in n1and3_children - insert_edge!(graph, make_edge(child, new_node)) - end - - # "repoint" parents of n3 from new node - for parent in n3_parents - remove_edge!(graph, make_edge(n3, parent)) - insert_edge!(graph, make_edge(new_node, parent)) - end - - return get_snapshot_diff(graph) -end - -function node_reduction!(graph::DAG, n1::Node, n2::Node) - # clear snapshot - get_snapshot_diff(graph) - - #=if !(n1 in graph) || !(n2 in graph) - error("[Node Reduction] The given nodes are not part of the given graph") - end=# - - #=if typeof(n1) != typeof(n2) - error("[Node Reduction] The given nodes are not of the same type") - end=# - - # save n2 parents and children - n2_children = children(n2) - n2_parents = Set(n2.parents) - - #=if Set(n2_children) != Set(n1.children) - error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction") - end=# - - # remove n2 and all its parents and children - for child in n2_children - remove_edge!(graph, make_edge(child, n2)) - end - - - for parent in n2_parents - remove_edge!(graph, make_edge(n2, parent)) - end - - for parent in n1.parents - # delete parents in n1 that already exist in n2 - delete!(n2_parents, parent) - end - - for parent in n2_parents - # now add parents of n2 to n1 without duplicates - insert_edge!(graph, make_edge(n1, parent)) - end - - remove_node!(graph, n2) - - return get_snapshot_diff(graph) -end - -function node_split!(graph::DAG, n1::Node) - # clear snapshot - get_snapshot_diff(graph) - - #=if !(n1 in graph) - error("[Node Split] The given node is not part of the given graph") - end=# - - n1_parents = parents(n1) - n1_children = children(n1) - - #=if length(n1_parents) <= 1 - error("[Node Split] The given node does not have multiple parents which is required for node split") - end=# - - for parent in n1_parents - remove_edge!(graph, make_edge(n1, parent)) - end - for child in n1_children - remove_edge!(graph, make_edge(child, n1)) - end - remove_node!(graph, n1) - - for parent in n1_parents - n_copy = copy(n1) - insert_node!(graph, n_copy) - insert_edge!(graph, make_edge(n_copy, parent)) - - for child in n1_children - insert_edge!(graph, make_edge(child, n_copy)) - end - end - - return get_snapshot_diff(graph) -end - -# function to find node fusions involving the given node if it's a data node -# pushes the found fusion everywhere it needs to be and returns nothing -function find_fusions!(graph::DAG, node::DataTaskNode) - if length(node.parents) != 1 || length(node.children) != 1 - return nothing - end - - child_node = first(node.children) - parent_node = first(node.parents) - - #=if !(child_node in graph) || !(parent_node in graph) - error("Parents/Children that are not in the graph!!!") - end=# - - if length(child_node.parents) != 1 - return nothing - end - - nf = NodeFusion((child_node, node, parent_node)) - push!(graph.possibleOperations.nodeFusions, nf) - push!(child_node.operations, nf) - push!(node.operations, nf) - push!(parent_node.operations, nf) - - return nothing -end - -# function to find node fusions involving the given node if it's a compute node -# pushes the found fusion(s) everywhere it needs to be and returns nothing -function find_fusions!(graph::DAG, node::ComputeTaskNode) - # for loop that always runs once for a scoped block we can break out of - for _ in 1:1 - # assume this node as child of the chain - if length(node.parents) != 1 - break - end - node2 = first(node.parents) - if length(node2.parents) != 1 || length(node2.children) != 1 - break - end - node3 = first(node2.parents) - - #=if !(node2 in graph) || !(node3 in graph) - error("Parents/Children that are not in the graph!!!") - end=# - - nf = NodeFusion((node, node2, node3)) - push!(graph.possibleOperations.nodeFusions, nf) - push!(node.operations, nf) - push!(node2.operations, nf) - push!(node3.operations, nf) - end - - for _ in 1:1 - # assume this node as parent of the chain - if length(node.children) < 1 - break - end - node2 = first(node.children) - if length(node2.parents) != 1 || length(node2.children) != 1 - break - end - node1 = first(node2.children) - if (length(node1.parents) > 1) - break - end - - #=if !(node2 in graph) || !(node1 in graph) - error("Parents/Children that are not in the graph!!!") - end=# - - nf = NodeFusion((node1, node2, node)) - push!(graph.possibleOperations.nodeFusions, nf) - push!(node1.operations, nf) - push!(node2.operations, nf) - push!(node.operations, nf) - end - - return nothing -end - -function find_reductions!(graph::DAG, node::Node) - reductionVector = nothing - # possible reductions are with nodes that are partners, i.e. parents of children - for partner in partners(node) - if can_reduce(node, partner) - if reductionVector === nothing - # only when there's at least one reduction partner, insert the vector - reductionVector = Vector{Node}() - push!(reductionVector, node) - end - - push!(reductionVector, partner) - end - end - - if reductionVector !== nothing - nr = NodeReduction(reductionVector) - push!(graph.possibleOperations.nodeReductions, nr) - for node in reductionVector - push!(node.operations, nr) - end - end - - return nothing -end - -function find_splits!(graph::DAG, node::Node) - if (can_split(node)) - ns = NodeSplit(node) - push!(graph.possibleOperations.nodeSplits, ns) - push!(node.operations, ns) - end - - return nothing -end - -# "clean" the operations on a dirty node -function clean_node!(graph::DAG, node::Node) - find_fusions!(graph, node) - find_reductions!(graph, node) - find_splits!(graph, node) -end - -# function to generate all possible optmizations on the graph -function generate_options(graph::DAG) - generatedOperations = [Vector{Operation}() for _ in 1:nthreads()] - - # make sure the graph is fully generated through - apply_all!(graph) - - nodeArray = collect(graph.nodes) - - # find possible node fusions - @threads for node in nodeArray - if (typeof(node) <: DataTaskNode) - if length(node.parents) != 1 - # data node can only have a single parent - continue - end - parent_node = first(node.parents) - - if length(node.children) != 1 - # this node is an entry node or has multiple children which should not be possible - continue - end - child_node = first(node.children) - if (length(child_node.parents) != 1) - continue - end - - push!(generatedOperations[threadid()], NodeFusion((child_node, node, parent_node))) - end - end - - # TODO figure out how to parallelize this - # find possible node reductions - visitedNodes = Set{Node}() - - for node in graph.nodes - if (node in visitedNodes) - continue - end - - push!(visitedNodes, node) - - reductionVector = nothing - partners_ = partners(node) - - t = typeof(node) - - # possible reductions are with nodes that are partners, i.e. parents of children - for partner in partners_ - # see proof Node Reduction 1 - if (t != typeof(partner)) - continue - end - push!(visitedNodes, partner) - - - if can_reduce(node, partner) - if reductionVector === nothing - # only when there's at least one reduction partner, insert the vector - reductionVector = Vector{Node}() - push!(reductionVector, node) - end - - push!(reductionVector, partner) - end - end - - if reductionVector !== nothing - push!(generatedOperations[threadid()], NodeReduction(reductionVector)) - end - end - - # find possible node splits - @threads for node in nodeArray - if (can_split(node)) - push!(generatedOperations[threadid()], NodeSplit(node)) - end - end - - # TODO figure out how to parallelize this - # insert generated operations from every thread into the final result - for genOps in generatedOperations - for op in genOps - insert_operation!(graph.possibleOperations, op) - end - end - - empty!(graph.dirtyNodes) -end - -function get_operations(graph::DAG) - apply_all!(graph) - - if isempty(graph.possibleOperations) - generate_options(graph) - end - - for node in graph.dirtyNodes - clean_node!(graph, node) - end - empty!(graph.dirtyNodes) - - return graph.possibleOperations -end - -function insert_operation!(operations::PossibleOperations, nf::NodeFusion) - push!(operations.nodeFusions, nf) - push!(nf.input[1].operations, nf) - push!(nf.input[2].operations, nf) - push!(nf.input[3].operations, nf) -end - -function insert_operation!(operations::PossibleOperations, nr::NodeReduction) - push!(operations.nodeReductions, nr) - for n in nr.input - push!(n.operations, nr) - end -end - -function insert_operation!(operations::PossibleOperations, ns::NodeSplit) - push!(operations.nodeSplits, ns) - push!(ns.input.operations, ns) -end diff --git a/src/nodes.jl b/src/nodes.jl index 6edb697..6e92f26 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -2,7 +2,7 @@ using Random using UUIDs using Base.Threads -rng = [Random.MersenneTwister(0) for _ in 1:nthreads()] +rng = [Random.MersenneTwister(0) for _ in 1:32] abstract type Node end diff --git a/src/operations/apply.jl b/src/operations/apply.jl new file mode 100644 index 0000000..ac12cd9 --- /dev/null +++ b/src/operations/apply.jl @@ -0,0 +1,229 @@ +# functions that apply graph operations + +# applies all unapplied operations in the DAG +function apply_all!(graph::DAG) + while !isempty(graph.operationsToApply) + # get next operation to apply from front of the deque + op = popfirst!(graph.operationsToApply) + + # apply it + appliedOp = apply_operation!(graph, op) + + # push to the end of the appliedOperations deque + push!(graph.appliedOperations, appliedOp) + end + return nothing +end + +function apply_operation!(graph::DAG, operation::Operation) + error("Unknown operation type!") +end + +function apply_operation!(graph::DAG, operation::NodeFusion) + diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3]) + return AppliedNodeFusion(operation, diff) +end + +function apply_operation!(graph::DAG, operation::NodeReduction) + diff = node_reduction!(graph, operation.input[1], operation.input[2]) + return AppliedNodeReduction(operation, diff) +end + +function apply_operation!(graph::DAG, operation::NodeSplit) + diff = node_split!(graph, operation.input) + return AppliedNodeSplit(operation, diff) +end + + +function revert_operation!(graph::DAG, operation::AppliedOperation) + error("Unknown operation type!") +end + +function revert_operation!(graph::DAG, operation::AppliedNodeFusion) + revert_diff!(graph, operation.diff) + return operation.operation +end + +function revert_operation!(graph::DAG, operation::AppliedNodeReduction) + revert_diff!(graph, operation.diff) + return operation.operation +end + +function revert_operation!(graph::DAG, operation::AppliedNodeSplit) + revert_diff!(graph, operation.diff) + return operation.operation +end + + +function revert_diff!(graph::DAG, diff) + # add removed nodes, remove added nodes, same for edges + # note the order + for edge in diff.addedEdges + remove_edge!(graph, edge, false) + end + for node in diff.addedNodes + remove_node!(graph, node, false) + end + + for node in diff.removedNodes + insert_node!(graph, node, false) + end + for edge in diff.removedEdges + insert_edge!(graph, edge, false) + end +end + +# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph +function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) + # clear snapshot + get_snapshot_diff(graph) + + if !(n1 in graph) || !(n2 in graph) || !(n3 in graph) + error("[Node Fusion] The given nodes are not part of the given graph") + end + + if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1) + # the checks are redundant but maybe a good sanity check + error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion") + end + + # save children and parents + n1_children = children(n1) + n3_parents = parents(n3) + n3_children = children(n3) + + if length(n2.parents) > 1 + error("[Node Fusion] The given data node has more than one parent") + end + if length(n2.children) > 1 + error("[Node Fusion] The given data node has more than one child") + end + if length(n1.parents) > 1 + error("[Node Fusion] The given n1 has more than one parent") + end + + required_edge1 = make_edge(n1, n2) + required_edge2 = make_edge(n2, n3) + + # remove the edges and nodes that will be replaced by the fused node + remove_edge!(graph, required_edge1) + remove_edge!(graph, required_edge2) + remove_node!(graph, n1) + remove_node!(graph, n2) + + # get n3's children now so it automatically excludes n2 + n3_children = children(n3) + remove_node!(graph, n3) + + # create new node with the fused compute task + new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}()) + insert_node!(graph, new_node) + + # use a set for combined children of n1 and n3 to not get duplicates + n1and3_children = Set{Node}() + + # remove edges from n1 children to n1 + for child in n1_children + remove_edge!(graph, make_edge(child, n1)) + push!(n1and3_children, child) + end + + # remove edges from n3 children to n3 + for child in n3_children + remove_edge!(graph, make_edge(child, n3)) + push!(n1and3_children, child) + end + + for child in n1and3_children + insert_edge!(graph, make_edge(child, new_node)) + end + + # "repoint" parents of n3 from new node + for parent in n3_parents + remove_edge!(graph, make_edge(n3, parent)) + insert_edge!(graph, make_edge(new_node, parent)) + end + + return get_snapshot_diff(graph) +end + +function node_reduction!(graph::DAG, n1::Node, n2::Node) + # clear snapshot + get_snapshot_diff(graph) + + #=if !(n1 in graph) || !(n2 in graph) + error("[Node Reduction] The given nodes are not part of the given graph") + end=# + + #=if typeof(n1) != typeof(n2) + error("[Node Reduction] The given nodes are not of the same type") + end=# + + # save n2 parents and children + n2_children = children(n2) + n2_parents = Set(n2.parents) + + #=if Set(n2_children) != Set(n1.children) + error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction") + end=# + + # remove n2 and all its parents and children + for child in n2_children + remove_edge!(graph, make_edge(child, n2)) + end + + + for parent in n2_parents + remove_edge!(graph, make_edge(n2, parent)) + end + + for parent in n1.parents + # delete parents in n1 that already exist in n2 + delete!(n2_parents, parent) + end + + for parent in n2_parents + # now add parents of n2 to n1 without duplicates + insert_edge!(graph, make_edge(n1, parent)) + end + + remove_node!(graph, n2) + + return get_snapshot_diff(graph) +end + +function node_split!(graph::DAG, n1::Node) + # clear snapshot + get_snapshot_diff(graph) + + #=if !(n1 in graph) + error("[Node Split] The given node is not part of the given graph") + end=# + + n1_parents = parents(n1) + n1_children = children(n1) + + #=if length(n1_parents) <= 1 + error("[Node Split] The given node does not have multiple parents which is required for node split") + end=# + + for parent in n1_parents + remove_edge!(graph, make_edge(n1, parent)) + end + for child in n1_children + remove_edge!(graph, make_edge(child, n1)) + end + remove_node!(graph, n1) + + for parent in n1_parents + n_copy = copy(n1) + insert_node!(graph, n_copy) + insert_edge!(graph, make_edge(n_copy, parent)) + + for child in n1_children + insert_edge!(graph, make_edge(child, n_copy)) + end + end + + return get_snapshot_diff(graph) +end diff --git a/src/operations/clean.jl b/src/operations/clean.jl new file mode 100644 index 0000000..ea85ce1 --- /dev/null +++ b/src/operations/clean.jl @@ -0,0 +1,127 @@ +# functions for "cleaning" nodes, i.e. regenerating the possible operations for a node + +# function to find node fusions involving the given node if it's a data node +# pushes the found fusion everywhere it needs to be and returns nothing +function find_fusions!(graph::DAG, node::DataTaskNode) + if length(node.parents) != 1 || length(node.children) != 1 + return nothing + end + + child_node = first(node.children) + parent_node = first(node.parents) + + #=if !(child_node in graph) || !(parent_node in graph) + error("Parents/Children that are not in the graph!!!") + end=# + + if length(child_node.parents) != 1 + return nothing + end + + nf = NodeFusion((child_node, node, parent_node)) + push!(graph.possibleOperations.nodeFusions, nf) + push!(child_node.operations, nf) + push!(node.operations, nf) + push!(parent_node.operations, nf) + + return nothing +end + +# function to find node fusions involving the given node if it's a compute node +# pushes the found fusion(s) everywhere it needs to be and returns nothing +function find_fusions!(graph::DAG, node::ComputeTaskNode) + # for loop that always runs once for a scoped block we can break out of + for _ in 1:1 + # assume this node as child of the chain + if length(node.parents) != 1 + break + end + node2 = first(node.parents) + if length(node2.parents) != 1 || length(node2.children) != 1 + break + end + node3 = first(node2.parents) + + #=if !(node2 in graph) || !(node3 in graph) + error("Parents/Children that are not in the graph!!!") + end=# + + nf = NodeFusion((node, node2, node3)) + push!(graph.possibleOperations.nodeFusions, nf) + push!(node.operations, nf) + push!(node2.operations, nf) + push!(node3.operations, nf) + end + + for _ in 1:1 + # assume this node as parent of the chain + if length(node.children) < 1 + break + end + node2 = first(node.children) + if length(node2.parents) != 1 || length(node2.children) != 1 + break + end + node1 = first(node2.children) + if (length(node1.parents) > 1) + break + end + + #=if !(node2 in graph) || !(node1 in graph) + error("Parents/Children that are not in the graph!!!") + end=# + + nf = NodeFusion((node1, node2, node)) + push!(graph.possibleOperations.nodeFusions, nf) + push!(node1.operations, nf) + push!(node2.operations, nf) + push!(node.operations, nf) + end + + return nothing +end + +function find_reductions!(graph::DAG, node::Node) + reductionVector = nothing + # possible reductions are with nodes that are partners, i.e. parents of children + partners_ = partners(node) + delete!(partners_, node) + for partner in partners_ + if can_reduce(node, partner) + if reductionVector === nothing + # only when there's at least one reduction partner, insert the vector + reductionVector = Vector{Node}() + push!(reductionVector, node) + end + + push!(reductionVector, partner) + end + end + + if reductionVector !== nothing + nr = NodeReduction(reductionVector) + push!(graph.possibleOperations.nodeReductions, nr) + for node in reductionVector + push!(node.operations, nr) + end + end + + return nothing +end + +function find_splits!(graph::DAG, node::Node) + if (can_split(node)) + ns = NodeSplit(node) + push!(graph.possibleOperations.nodeSplits, ns) + push!(node.operations, ns) + end + + return nothing +end + +# "clean" the operations on a dirty node +function clean_node!(graph::DAG, node::Node) + find_fusions!(graph, node) + find_reductions!(graph, node) + find_splits!(graph, node) +end diff --git a/src/operations/find.jl b/src/operations/find.jl new file mode 100644 index 0000000..c825249 --- /dev/null +++ b/src/operations/find.jl @@ -0,0 +1,173 @@ +# functions that find operations on the inital graph + +using Base.Threads + +function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks::Dict{Node, SpinLock}) + push!(operations.nodeFusions, nf) + n1 = nf.input[1]; n2 = nf.input[2]; n3 = nf.input[3] + + lock(locks[n1]) do; push!(nf.input[1].operations, nf); end + lock(locks[n2]) do; push!(nf.input[2].operations, nf); end + lock(locks[n3]) do; push!(nf.input[3].operations, nf); end +end + +function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock}) + push!(operations.nodeReductions, nr) + for n in nr.input + lock(locks[n]) do; push!(n.operations, nr); end + end +end + +function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock}) + push!(operations.nodeSplits, ns) + lock(locks[ns.input]) do; push!(ns.input.operations, ns); end +end + +function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock}) + for vec in nodeReductions + for op in vec + insert_operation!(operations, op, locks) + end + end +end + +function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock}) + for vec in nodeFusions + for op in vec + insert_operation!(operations, op, locks) + end + end +end + +function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock}) + for vec in nodeSplits + for op in vec + insert_operation!(operations, op, locks) + end + end +end + +# function to generate all possible operations on the graph +function generate_options(graph::DAG) + locks = Dict{Node, SpinLock}() + for n in graph.nodes + locks[n] = SpinLock() + end + + generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()] + generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()] + generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()] + + # make sure the graph is fully generated through + apply_all!(graph) + + # --- find possible node reductions --- + + # find some useful partition of nodes without generating duplicate node reductions + nodePartitions = [Vector{Set{Node}}() for _ in 1:nthreads()] + avgNodes = 0. # the average number of nodes across all the node partitions + nodeSet = copy(graph.nodes) + + partitionPointer = 1 + rotatePointer(i) = (i % nthreads()) + 1 + + while !isempty(nodeSet) + # cycle partition pointer to a set with fewer than average nodes + nodes = partners(first(nodeSet)) + setdiff!(nodeSet, nodes) + + if length(nodes) == 1 + # nothing to reduce here anyways + continue + end + + partitionPointer = rotatePointer(partitionPointer) + + push!(nodePartitions[partitionPointer], nodes) + avgNodes = avgNodes + length(nodes) / nthreads() + end + + @threads for partition in nodePartitions + for partners_ in partition + reductionVector = nothing + + node = pop!(partners_) + + t = typeof(node) + + # possible reductions are with nodes that are partners, i.e. parents of children + for partner in partners_ + if (t != typeof(partner)) + continue + end + + if !can_reduce(node, partner) + continue + end + + if reductionVector === nothing + # only when there's at least one reduction partner, insert the vector + reductionVector = Vector{Node}() + push!(reductionVector, node) + end + + push!(reductionVector, partner) + end + + if reductionVector !== nothing + push!(generatedReductions[threadid()], NodeReduction(reductionVector)) + end + end + end + + # launch thread for node reduction insertion + nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks) + schedule(nr_task) + + # --- find possible node fusions --- + nodeArray = collect(graph.nodes) + + @threads for node in nodeArray + if (typeof(node) <: DataTaskNode) + if length(node.parents) != 1 + # data node can only have a single parent + continue + end + parent_node = first(node.parents) + + if length(node.children) != 1 + # this node is an entry node or has multiple children which should not be possible + continue + end + child_node = first(node.children) + if (length(child_node.parents) != 1) + continue + end + + push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node))) + end + end + + # launch thread for node fusion insertion + nf_task = @task nf_insertion!(graph.possibleOperations, generatedFusions, locks) + schedule(nf_task) + + # find possible node splits + @threads for node in nodeArray + if (can_split(node)) + push!(generatedSplits[threadid()], NodeSplit(node)) + end + end + + # launch thread for node split insertion + ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits, locks) + schedule(ns_task) + + empty!(graph.dirtyNodes) + + wait(nr_task) + wait(nf_task) + wait(ns_task) + + return nothing +end diff --git a/src/operations/get.jl b/src/operations/get.jl new file mode 100644 index 0000000..81ef1c0 --- /dev/null +++ b/src/operations/get.jl @@ -0,0 +1,18 @@ +# function to return the possible operations of a graph + +using Base.Threads + +function get_operations(graph::DAG) + apply_all!(graph) + + if isempty(graph.possibleOperations) + generate_options(graph) + end + + for node in graph.dirtyNodes + clean_node!(graph, node) + end + empty!(graph.dirtyNodes) + + return graph.possibleOperations +end diff --git a/src/operations/utility.jl b/src/operations/utility.jl new file mode 100644 index 0000000..4ce7b3a --- /dev/null +++ b/src/operations/utility.jl @@ -0,0 +1,109 @@ + +function isempty(operations::PossibleOperations) + return isempty(operations.nodeFusions) && + isempty(operations.nodeReductions) && + isempty(operations.nodeSplits) +end + +function length(operations::PossibleOperations) + return (nodeFusions = length(operations.nodeFusions), + nodeReductions = length(operations.nodeReductions), + nodeSplits = length(operations.nodeSplits)) +end + +function delete!(operations::PossibleOperations, op::NodeFusion) + delete!(operations.nodeFusions, op) + return operations +end + +function delete!(operations::PossibleOperations, op::NodeReduction) + delete!(operations.nodeReductions, op) + return operations +end + +function delete!(operations::PossibleOperations, op::NodeSplit) + delete!(operations.nodeSplits, op) + return operations +end + + +function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) + if !is_child(n1, n2) || !is_child(n2, n3) + # the checks are redundant but maybe a good sanity check + return false + end + + if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1 + return false + end + + return true +end + +function can_reduce(n1::Node, n2::Node) + if (n1.task != n2.task) + return false + end + + n1_length = length(n1.children) + n2_length = length(n2.children) + + if (n1_length != n2_length) + return false + end + + # this seems to be the most common case so do this first + # doing it manually is a lot faster than using the sets for a general solution + if (n1_length == 2) + if (n1.children[1] != n2.children[1]) + if (n1.children[1] != n2.children[2]) + return false + end + # 1_1 == 2_2 + if (n1.children[2] != n2.children[1]) + return false + end + return true + end + + # 1_1 == 2_1 + if (n1.children[2] != n2.children[2]) + return false + end + return true + end + + # this is simple + if (n1_length == 1) + return n1.children[1] == n2.children[1] + end + + # this takes a long time + return Set(n1.children) == Set(n2.children) +end + +function can_split(n::Node) + return length(parents(n)) > 1 +end + +function ==(op1::Operation, op2::Operation) + return false +end + +function ==(op1::NodeFusion, op2::NodeFusion) + # there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same + return op1.input[2] == op2.input[2] +end + +function ==(op1::NodeReduction, op2::NodeReduction) + # only test the ids against each other + return op1.id == op2.id +end + +function ==(op1::NodeSplit, op2::NodeSplit) + return op1.input == op2.input +end + +NodeReduction(input::Vector{Node}) = NodeReduction(input, UUIDs.uuid1(rng[threadid()])) + +copy(id::UUID) = UUID(id.value) diff --git a/test/unit_tests_graph.jl b/test/unit_tests_graph.jl index 212c0f8..85c9d0e 100644 --- a/test/unit_tests_graph.jl +++ b/test/unit_tests_graph.jl @@ -127,8 +127,8 @@ import MetagraphOptimization.partners @test MetagraphOptimization.get_exit_node(graph) == d_exit - @test length(partners(s0)) == 0 - @test length(siblings(s0)) == 0 + @test length(partners(s0)) == 1 + @test length(siblings(s0)) == 1 operations = get_operations(graph) @test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)