Some file reordering and parallelization work

This commit is contained in:
Anton Reinhard 2023-08-21 12:54:45 +02:00
parent 895e4b2a12
commit 2e96e6520e
11 changed files with 705 additions and 628 deletions

View File

@ -26,7 +26,15 @@ include("graph.jl")
include("task_functions.jl")
include("node_functions.jl")
include("graph_functions.jl")
include("graph_operations.jl")
include("operations/utility.jl")
include("operations/apply.jl")
include("operations/clean.jl")
include("operations/find.jl")
include("operations/get.jl")
include("graph_interface.jl")
include("utility.jl")
include("abc_model/tasks.jl")

View File

@ -3,31 +3,6 @@ using DataStructures
in(node::Node, graph::DAG) = node in graph.nodes
in(edge::Edge, graph::DAG) = edge in graph.edges
function isempty(operations::PossibleOperations)
return isempty(operations.nodeFusions) &&
isempty(operations.nodeReductions) &&
isempty(operations.nodeSplits)
end
function length(operations::PossibleOperations)
return (nodeFusions = length(operations.nodeFusions),
nodeReductions = length(operations.nodeReductions),
nodeSplits = length(operations.nodeSplits))
end
function delete!(operations::PossibleOperations, op::NodeFusion)
delete!(operations.nodeFusions, op)
return operations
end
function delete!(operations::PossibleOperations, op::NodeReduction)
delete!(operations.nodeReductions, op)
return operations
end
function delete!(operations::PossibleOperations, op::NodeSplit)
delete!(operations.nodeSplits, op)
return operations
end
function is_parent(potential_parent, node)
return potential_parent in node.parents
end
@ -57,13 +32,13 @@ function parents(node::Node)
return copy(node.parents)
end
# siblings = all children of any parents, no duplicates, does not include the node itself
# siblings = all children of any parents, no duplicates, includes the node itself
function siblings(node::Node)
result = Set{Node}()
push!(result, node)
for parent in node.parents
union!(result, parent.children)
end
delete!(result, node)
return result
end
@ -71,10 +46,10 @@ end
# partners = all parents of any children, no duplicates, includes the node itself
function partners(node::Node)
result = Set{Node}()
push!(result, node)
for child in node.children
union!(result, child.parents)
end
delete!(result, node)
return result
end
@ -259,65 +234,6 @@ function get_exit_node(graph::DAG)
error("The given graph has no exit node! It is either empty or not acyclic!")
end
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !is_child(n1, n2) || !is_child(n2, n3)
# the checks are redundant but maybe a good sanity check
return false
end
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
return false
end
return true
end
function can_reduce(n1::Node, n2::Node)
if (n1.task != n2.task)
return false
end
n1_length = length(n1.children)
n2_length = length(n2.children)
if (n1_length != n2_length)
return false
end
# this seems to be the most common case so do this first
# doing it manually is a lot faster than using the sets for a general solution
if (n1_length == 2)
if (n1.children[1] != n2.children[1])
if (n1.children[1] != n2.children[2])
return false
end
# 1_1 == 2_2
if (n1.children[2] != n2.children[1])
return false
end
return true
end
# 1_1 == 2_1
if (n1.children[2] != n2.children[2])
return false
end
return true
end
# this is simple
if (n1_length == 1)
return n1.children[1] == n2.children[1]
end
# this takes a long time
return Set(n1.children) == Set(n2.children)
end
function can_split(n::Node)
return length(parents(n)) > 1
end
# check whether the given graph is connected
function is_valid(graph::DAG)
nodeQueue = Deque{Node}()
@ -408,25 +324,3 @@ function length(diff::Diff)
removedEdges = length(diff.removedEdges)
)
end
function ==(op1::Operation, op2::Operation)
return false
end
function ==(op1::NodeFusion, op2::NodeFusion)
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
return op1.input[2] == op2.input[2]
end
function ==(op1::NodeReduction, op2::NodeReduction)
# only test the ids against each other
return op1.id == op2.id
end
function ==(op1::NodeSplit, op2::NodeSplit)
return op1.input == op2.input
end
NodeReduction(input::Vector{Node}) = NodeReduction(input, UUIDs.uuid1(rng[threadid()]))
copy(id::UUID) = UUID(id.value)

34
src/graph_interface.jl Normal file
View File

@ -0,0 +1,34 @@
# user interface on the DAG
# applies a new operation to the end of the graph
function push_operation!(graph::DAG, operation::Operation)
# 1.: Add the operation to the DAG
push!(graph.operationsToApply, operation)
return nothing
end
# reverts the latest applied operation, essentially like a ctrl+z for
function pop_operation!(graph::DAG)
# 1.: Remove the operation from the appliedChain of the DAG
if !isempty(graph.operationsToApply)
pop!(graph.operationsToApply)
elseif !isempty(graph.appliedOperations)
appliedOp = pop!(graph.appliedOperations)
revert_operation!(graph, appliedOp)
else
error("No more operations to pop!")
end
return nothing
end
can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations)
# reset the graph to its initial state with no operations applied
function reset_graph!(graph::DAG)
while (can_pop(graph))
pop_operation!(graph)
end
return nothing
end

View File

@ -1,515 +0,0 @@
using Base.Threads
# outside interface
# applies a new operation to the end of the graph
function push_operation!(graph::DAG, operation::Operation)
# 1.: Add the operation to the DAG
push!(graph.operationsToApply, operation)
return nothing
end
# reverts the latest applied operation, essentially like a ctrl+z for
function pop_operation!(graph::DAG)
# 1.: Remove the operation from the appliedChain of the DAG
if !isempty(graph.operationsToApply)
pop!(graph.operationsToApply)
elseif !isempty(graph.appliedOperations)
appliedOp = pop!(graph.appliedOperations)
revert_operation!(graph, appliedOp)
else
error("No more operations to pop!")
end
return nothing
end
can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations)
# reset the graph to its initial state with no operations applied
function reset_graph!(graph::DAG)
while (can_pop(graph))
pop_operation!(graph)
end
return nothing
end
# implementation detail functions, don't export
# applies all unapplied operations in the DAG
function apply_all!(graph::DAG)
while !isempty(graph.operationsToApply)
# get next operation to apply from front of the deque
op = popfirst!(graph.operationsToApply)
# apply it
appliedOp = apply_operation!(graph, op)
# push to the end of the appliedOperations deque
push!(graph.appliedOperations, appliedOp)
end
return nothing
end
function apply_operation!(graph::DAG, operation::Operation)
error("Unknown operation type!")
end
function apply_operation!(graph::DAG, operation::NodeFusion)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
return AppliedNodeFusion(operation, diff)
end
function apply_operation!(graph::DAG, operation::NodeReduction)
diff = node_reduction!(graph, operation.input[1], operation.input[2])
return AppliedNodeReduction(operation, diff)
end
function apply_operation!(graph::DAG, operation::NodeSplit)
diff = node_split!(graph, operation.input)
return AppliedNodeSplit(operation, diff)
end
function revert_operation!(graph::DAG, operation::AppliedOperation)
error("Unknown operation type!")
end
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_diff!(graph::DAG, diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
remove_edge!(graph, edge, false)
end
for node in diff.addedNodes
remove_node!(graph, node, false)
end
for node in diff.removedNodes
insert_node!(graph, node, false)
end
for edge in diff.removedEdges
insert_edge!(graph, edge, false)
end
end
# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
# clear snapshot
get_snapshot_diff(graph)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
error("[Node Fusion] The given nodes are not part of the given graph")
end
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
# the checks are redundant but maybe a good sanity check
error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion")
end
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
n3_children = children(n3)
if length(n2.parents) > 1
error("[Node Fusion] The given data node has more than one parent")
end
if length(n2.children) > 1
error("[Node Fusion] The given data node has more than one child")
end
if length(n1.parents) > 1
error("[Node Fusion] The given n1 has more than one parent")
end
required_edge1 = make_edge(n1, n2)
required_edge2 = make_edge(n2, n3)
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, required_edge1)
remove_edge!(graph, required_edge2)
remove_node!(graph, n1)
remove_node!(graph, n2)
# get n3's children now so it automatically excludes n2
n3_children = children(n3)
remove_node!(graph, n3)
# create new node with the fused compute task
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
insert_node!(graph, new_node)
# use a set for combined children of n1 and n3 to not get duplicates
n1and3_children = Set{Node}()
# remove edges from n1 children to n1
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
push!(n1and3_children, child)
end
# remove edges from n3 children to n3
for child in n3_children
remove_edge!(graph, make_edge(child, n3))
push!(n1and3_children, child)
end
for child in n1and3_children
insert_edge!(graph, make_edge(child, new_node))
end
# "repoint" parents of n3 from new node
for parent in n3_parents
remove_edge!(graph, make_edge(n3, parent))
insert_edge!(graph, make_edge(new_node, parent))
end
return get_snapshot_diff(graph)
end
function node_reduction!(graph::DAG, n1::Node, n2::Node)
# clear snapshot
get_snapshot_diff(graph)
#=if !(n1 in graph) || !(n2 in graph)
error("[Node Reduction] The given nodes are not part of the given graph")
end=#
#=if typeof(n1) != typeof(n2)
error("[Node Reduction] The given nodes are not of the same type")
end=#
# save n2 parents and children
n2_children = children(n2)
n2_parents = Set(n2.parents)
#=if Set(n2_children) != Set(n1.children)
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
end=#
# remove n2 and all its parents and children
for child in n2_children
remove_edge!(graph, make_edge(child, n2))
end
for parent in n2_parents
remove_edge!(graph, make_edge(n2, parent))
end
for parent in n1.parents
# delete parents in n1 that already exist in n2
delete!(n2_parents, parent)
end
for parent in n2_parents
# now add parents of n2 to n1 without duplicates
insert_edge!(graph, make_edge(n1, parent))
end
remove_node!(graph, n2)
return get_snapshot_diff(graph)
end
function node_split!(graph::DAG, n1::Node)
# clear snapshot
get_snapshot_diff(graph)
#=if !(n1 in graph)
error("[Node Split] The given node is not part of the given graph")
end=#
n1_parents = parents(n1)
n1_children = children(n1)
#=if length(n1_parents) <= 1
error("[Node Split] The given node does not have multiple parents which is required for node split")
end=#
for parent in n1_parents
remove_edge!(graph, make_edge(n1, parent))
end
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
end
remove_node!(graph, n1)
for parent in n1_parents
n_copy = copy(n1)
insert_node!(graph, n_copy)
insert_edge!(graph, make_edge(n_copy, parent))
for child in n1_children
insert_edge!(graph, make_edge(child, n_copy))
end
end
return get_snapshot_diff(graph)
end
# function to find node fusions involving the given node if it's a data node
# pushes the found fusion everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::DataTaskNode)
if length(node.parents) != 1 || length(node.children) != 1
return nothing
end
child_node = first(node.children)
parent_node = first(node.parents)
#=if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end=#
if length(child_node.parents) != 1
return nothing
end
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.operations, nf)
push!(node.operations, nf)
push!(parent_node.operations, nf)
return nothing
end
# function to find node fusions involving the given node if it's a compute node
# pushes the found fusion(s) everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# for loop that always runs once for a scoped block we can break out of
for _ in 1:1
# assume this node as child of the chain
if length(node.parents) != 1
break
end
node2 = first(node.parents)
if length(node2.parents) != 1 || length(node2.children) != 1
break
end
node3 = first(node2.parents)
#=if !(node2 in graph) || !(node3 in graph)
error("Parents/Children that are not in the graph!!!")
end=#
nf = NodeFusion((node, node2, node3))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node.operations, nf)
push!(node2.operations, nf)
push!(node3.operations, nf)
end
for _ in 1:1
# assume this node as parent of the chain
if length(node.children) < 1
break
end
node2 = first(node.children)
if length(node2.parents) != 1 || length(node2.children) != 1
break
end
node1 = first(node2.children)
if (length(node1.parents) > 1)
break
end
#=if !(node2 in graph) || !(node1 in graph)
error("Parents/Children that are not in the graph!!!")
end=#
nf = NodeFusion((node1, node2, node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node1.operations, nf)
push!(node2.operations, nf)
push!(node.operations, nf)
end
return nothing
end
function find_reductions!(graph::DAG, node::Node)
reductionVector = nothing
# possible reductions are with nodes that are partners, i.e. parents of children
for partner in partners(node)
if can_reduce(node, partner)
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
push!(reductionVector, node)
end
push!(reductionVector, partner)
end
end
if reductionVector !== nothing
nr = NodeReduction(reductionVector)
push!(graph.possibleOperations.nodeReductions, nr)
for node in reductionVector
push!(node.operations, nr)
end
end
return nothing
end
function find_splits!(graph::DAG, node::Node)
if (can_split(node))
ns = NodeSplit(node)
push!(graph.possibleOperations.nodeSplits, ns)
push!(node.operations, ns)
end
return nothing
end
# "clean" the operations on a dirty node
function clean_node!(graph::DAG, node::Node)
find_fusions!(graph, node)
find_reductions!(graph, node)
find_splits!(graph, node)
end
# function to generate all possible optmizations on the graph
function generate_options(graph::DAG)
generatedOperations = [Vector{Operation}() for _ in 1:nthreads()]
# make sure the graph is fully generated through
apply_all!(graph)
nodeArray = collect(graph.nodes)
# find possible node fusions
@threads for node in nodeArray
if (typeof(node) <: DataTaskNode)
if length(node.parents) != 1
# data node can only have a single parent
continue
end
parent_node = first(node.parents)
if length(node.children) != 1
# this node is an entry node or has multiple children which should not be possible
continue
end
child_node = first(node.children)
if (length(child_node.parents) != 1)
continue
end
push!(generatedOperations[threadid()], NodeFusion((child_node, node, parent_node)))
end
end
# TODO figure out how to parallelize this
# find possible node reductions
visitedNodes = Set{Node}()
for node in graph.nodes
if (node in visitedNodes)
continue
end
push!(visitedNodes, node)
reductionVector = nothing
partners_ = partners(node)
t = typeof(node)
# possible reductions are with nodes that are partners, i.e. parents of children
for partner in partners_
# see proof Node Reduction 1
if (t != typeof(partner))
continue
end
push!(visitedNodes, partner)
if can_reduce(node, partner)
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
push!(reductionVector, node)
end
push!(reductionVector, partner)
end
end
if reductionVector !== nothing
push!(generatedOperations[threadid()], NodeReduction(reductionVector))
end
end
# find possible node splits
@threads for node in nodeArray
if (can_split(node))
push!(generatedOperations[threadid()], NodeSplit(node))
end
end
# TODO figure out how to parallelize this
# insert generated operations from every thread into the final result
for genOps in generatedOperations
for op in genOps
insert_operation!(graph.possibleOperations, op)
end
end
empty!(graph.dirtyNodes)
end
function get_operations(graph::DAG)
apply_all!(graph)
if isempty(graph.possibleOperations)
generate_options(graph)
end
for node in graph.dirtyNodes
clean_node!(graph, node)
end
empty!(graph.dirtyNodes)
return graph.possibleOperations
end
function insert_operation!(operations::PossibleOperations, nf::NodeFusion)
push!(operations.nodeFusions, nf)
push!(nf.input[1].operations, nf)
push!(nf.input[2].operations, nf)
push!(nf.input[3].operations, nf)
end
function insert_operation!(operations::PossibleOperations, nr::NodeReduction)
push!(operations.nodeReductions, nr)
for n in nr.input
push!(n.operations, nr)
end
end
function insert_operation!(operations::PossibleOperations, ns::NodeSplit)
push!(operations.nodeSplits, ns)
push!(ns.input.operations, ns)
end

View File

@ -2,7 +2,7 @@ using Random
using UUIDs
using Base.Threads
rng = [Random.MersenneTwister(0) for _ in 1:nthreads()]
rng = [Random.MersenneTwister(0) for _ in 1:32]
abstract type Node end

229
src/operations/apply.jl Normal file
View File

@ -0,0 +1,229 @@
# functions that apply graph operations
# applies all unapplied operations in the DAG
function apply_all!(graph::DAG)
while !isempty(graph.operationsToApply)
# get next operation to apply from front of the deque
op = popfirst!(graph.operationsToApply)
# apply it
appliedOp = apply_operation!(graph, op)
# push to the end of the appliedOperations deque
push!(graph.appliedOperations, appliedOp)
end
return nothing
end
function apply_operation!(graph::DAG, operation::Operation)
error("Unknown operation type!")
end
function apply_operation!(graph::DAG, operation::NodeFusion)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
return AppliedNodeFusion(operation, diff)
end
function apply_operation!(graph::DAG, operation::NodeReduction)
diff = node_reduction!(graph, operation.input[1], operation.input[2])
return AppliedNodeReduction(operation, diff)
end
function apply_operation!(graph::DAG, operation::NodeSplit)
diff = node_split!(graph, operation.input)
return AppliedNodeSplit(operation, diff)
end
function revert_operation!(graph::DAG, operation::AppliedOperation)
error("Unknown operation type!")
end
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_diff!(graph::DAG, diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
remove_edge!(graph, edge, false)
end
for node in diff.addedNodes
remove_node!(graph, node, false)
end
for node in diff.removedNodes
insert_node!(graph, node, false)
end
for edge in diff.removedEdges
insert_edge!(graph, edge, false)
end
end
# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
# clear snapshot
get_snapshot_diff(graph)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
error("[Node Fusion] The given nodes are not part of the given graph")
end
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
# the checks are redundant but maybe a good sanity check
error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion")
end
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
n3_children = children(n3)
if length(n2.parents) > 1
error("[Node Fusion] The given data node has more than one parent")
end
if length(n2.children) > 1
error("[Node Fusion] The given data node has more than one child")
end
if length(n1.parents) > 1
error("[Node Fusion] The given n1 has more than one parent")
end
required_edge1 = make_edge(n1, n2)
required_edge2 = make_edge(n2, n3)
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, required_edge1)
remove_edge!(graph, required_edge2)
remove_node!(graph, n1)
remove_node!(graph, n2)
# get n3's children now so it automatically excludes n2
n3_children = children(n3)
remove_node!(graph, n3)
# create new node with the fused compute task
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
insert_node!(graph, new_node)
# use a set for combined children of n1 and n3 to not get duplicates
n1and3_children = Set{Node}()
# remove edges from n1 children to n1
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
push!(n1and3_children, child)
end
# remove edges from n3 children to n3
for child in n3_children
remove_edge!(graph, make_edge(child, n3))
push!(n1and3_children, child)
end
for child in n1and3_children
insert_edge!(graph, make_edge(child, new_node))
end
# "repoint" parents of n3 from new node
for parent in n3_parents
remove_edge!(graph, make_edge(n3, parent))
insert_edge!(graph, make_edge(new_node, parent))
end
return get_snapshot_diff(graph)
end
function node_reduction!(graph::DAG, n1::Node, n2::Node)
# clear snapshot
get_snapshot_diff(graph)
#=if !(n1 in graph) || !(n2 in graph)
error("[Node Reduction] The given nodes are not part of the given graph")
end=#
#=if typeof(n1) != typeof(n2)
error("[Node Reduction] The given nodes are not of the same type")
end=#
# save n2 parents and children
n2_children = children(n2)
n2_parents = Set(n2.parents)
#=if Set(n2_children) != Set(n1.children)
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
end=#
# remove n2 and all its parents and children
for child in n2_children
remove_edge!(graph, make_edge(child, n2))
end
for parent in n2_parents
remove_edge!(graph, make_edge(n2, parent))
end
for parent in n1.parents
# delete parents in n1 that already exist in n2
delete!(n2_parents, parent)
end
for parent in n2_parents
# now add parents of n2 to n1 without duplicates
insert_edge!(graph, make_edge(n1, parent))
end
remove_node!(graph, n2)
return get_snapshot_diff(graph)
end
function node_split!(graph::DAG, n1::Node)
# clear snapshot
get_snapshot_diff(graph)
#=if !(n1 in graph)
error("[Node Split] The given node is not part of the given graph")
end=#
n1_parents = parents(n1)
n1_children = children(n1)
#=if length(n1_parents) <= 1
error("[Node Split] The given node does not have multiple parents which is required for node split")
end=#
for parent in n1_parents
remove_edge!(graph, make_edge(n1, parent))
end
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
end
remove_node!(graph, n1)
for parent in n1_parents
n_copy = copy(n1)
insert_node!(graph, n_copy)
insert_edge!(graph, make_edge(n_copy, parent))
for child in n1_children
insert_edge!(graph, make_edge(child, n_copy))
end
end
return get_snapshot_diff(graph)
end

127
src/operations/clean.jl Normal file
View File

@ -0,0 +1,127 @@
# functions for "cleaning" nodes, i.e. regenerating the possible operations for a node
# function to find node fusions involving the given node if it's a data node
# pushes the found fusion everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::DataTaskNode)
if length(node.parents) != 1 || length(node.children) != 1
return nothing
end
child_node = first(node.children)
parent_node = first(node.parents)
#=if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end=#
if length(child_node.parents) != 1
return nothing
end
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.operations, nf)
push!(node.operations, nf)
push!(parent_node.operations, nf)
return nothing
end
# function to find node fusions involving the given node if it's a compute node
# pushes the found fusion(s) everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# for loop that always runs once for a scoped block we can break out of
for _ in 1:1
# assume this node as child of the chain
if length(node.parents) != 1
break
end
node2 = first(node.parents)
if length(node2.parents) != 1 || length(node2.children) != 1
break
end
node3 = first(node2.parents)
#=if !(node2 in graph) || !(node3 in graph)
error("Parents/Children that are not in the graph!!!")
end=#
nf = NodeFusion((node, node2, node3))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node.operations, nf)
push!(node2.operations, nf)
push!(node3.operations, nf)
end
for _ in 1:1
# assume this node as parent of the chain
if length(node.children) < 1
break
end
node2 = first(node.children)
if length(node2.parents) != 1 || length(node2.children) != 1
break
end
node1 = first(node2.children)
if (length(node1.parents) > 1)
break
end
#=if !(node2 in graph) || !(node1 in graph)
error("Parents/Children that are not in the graph!!!")
end=#
nf = NodeFusion((node1, node2, node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node1.operations, nf)
push!(node2.operations, nf)
push!(node.operations, nf)
end
return nothing
end
function find_reductions!(graph::DAG, node::Node)
reductionVector = nothing
# possible reductions are with nodes that are partners, i.e. parents of children
partners_ = partners(node)
delete!(partners_, node)
for partner in partners_
if can_reduce(node, partner)
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
push!(reductionVector, node)
end
push!(reductionVector, partner)
end
end
if reductionVector !== nothing
nr = NodeReduction(reductionVector)
push!(graph.possibleOperations.nodeReductions, nr)
for node in reductionVector
push!(node.operations, nr)
end
end
return nothing
end
function find_splits!(graph::DAG, node::Node)
if (can_split(node))
ns = NodeSplit(node)
push!(graph.possibleOperations.nodeSplits, ns)
push!(node.operations, ns)
end
return nothing
end
# "clean" the operations on a dirty node
function clean_node!(graph::DAG, node::Node)
find_fusions!(graph, node)
find_reductions!(graph, node)
find_splits!(graph, node)
end

173
src/operations/find.jl Normal file
View File

@ -0,0 +1,173 @@
# functions that find operations on the inital graph
using Base.Threads
function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks::Dict{Node, SpinLock})
push!(operations.nodeFusions, nf)
n1 = nf.input[1]; n2 = nf.input[2]; n3 = nf.input[3]
lock(locks[n1]) do; push!(nf.input[1].operations, nf); end
lock(locks[n2]) do; push!(nf.input[2].operations, nf); end
lock(locks[n3]) do; push!(nf.input[3].operations, nf); end
end
function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock})
push!(operations.nodeReductions, nr)
for n in nr.input
lock(locks[n]) do; push!(n.operations, nr); end
end
end
function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock})
push!(operations.nodeSplits, ns)
lock(locks[ns.input]) do; push!(ns.input.operations, ns); end
end
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock})
for vec in nodeReductions
for op in vec
insert_operation!(operations, op, locks)
end
end
end
function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock})
for vec in nodeFusions
for op in vec
insert_operation!(operations, op, locks)
end
end
end
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock})
for vec in nodeSplits
for op in vec
insert_operation!(operations, op, locks)
end
end
end
# function to generate all possible operations on the graph
function generate_options(graph::DAG)
locks = Dict{Node, SpinLock}()
for n in graph.nodes
locks[n] = SpinLock()
end
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
# make sure the graph is fully generated through
apply_all!(graph)
# --- find possible node reductions ---
# find some useful partition of nodes without generating duplicate node reductions
nodePartitions = [Vector{Set{Node}}() for _ in 1:nthreads()]
avgNodes = 0. # the average number of nodes across all the node partitions
nodeSet = copy(graph.nodes)
partitionPointer = 1
rotatePointer(i) = (i % nthreads()) + 1
while !isempty(nodeSet)
# cycle partition pointer to a set with fewer than average nodes
nodes = partners(first(nodeSet))
setdiff!(nodeSet, nodes)
if length(nodes) == 1
# nothing to reduce here anyways
continue
end
partitionPointer = rotatePointer(partitionPointer)
push!(nodePartitions[partitionPointer], nodes)
avgNodes = avgNodes + length(nodes) / nthreads()
end
@threads for partition in nodePartitions
for partners_ in partition
reductionVector = nothing
node = pop!(partners_)
t = typeof(node)
# possible reductions are with nodes that are partners, i.e. parents of children
for partner in partners_
if (t != typeof(partner))
continue
end
if !can_reduce(node, partner)
continue
end
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
push!(reductionVector, node)
end
push!(reductionVector, partner)
end
if reductionVector !== nothing
push!(generatedReductions[threadid()], NodeReduction(reductionVector))
end
end
end
# launch thread for node reduction insertion
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks)
schedule(nr_task)
# --- find possible node fusions ---
nodeArray = collect(graph.nodes)
@threads for node in nodeArray
if (typeof(node) <: DataTaskNode)
if length(node.parents) != 1
# data node can only have a single parent
continue
end
parent_node = first(node.parents)
if length(node.children) != 1
# this node is an entry node or has multiple children which should not be possible
continue
end
child_node = first(node.children)
if (length(child_node.parents) != 1)
continue
end
push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
end
end
# launch thread for node fusion insertion
nf_task = @task nf_insertion!(graph.possibleOperations, generatedFusions, locks)
schedule(nf_task)
# find possible node splits
@threads for node in nodeArray
if (can_split(node))
push!(generatedSplits[threadid()], NodeSplit(node))
end
end
# launch thread for node split insertion
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits, locks)
schedule(ns_task)
empty!(graph.dirtyNodes)
wait(nr_task)
wait(nf_task)
wait(ns_task)
return nothing
end

18
src/operations/get.jl Normal file
View File

@ -0,0 +1,18 @@
# function to return the possible operations of a graph
using Base.Threads
function get_operations(graph::DAG)
apply_all!(graph)
if isempty(graph.possibleOperations)
generate_options(graph)
end
for node in graph.dirtyNodes
clean_node!(graph, node)
end
empty!(graph.dirtyNodes)
return graph.possibleOperations
end

109
src/operations/utility.jl Normal file
View File

@ -0,0 +1,109 @@
function isempty(operations::PossibleOperations)
return isempty(operations.nodeFusions) &&
isempty(operations.nodeReductions) &&
isempty(operations.nodeSplits)
end
function length(operations::PossibleOperations)
return (nodeFusions = length(operations.nodeFusions),
nodeReductions = length(operations.nodeReductions),
nodeSplits = length(operations.nodeSplits))
end
function delete!(operations::PossibleOperations, op::NodeFusion)
delete!(operations.nodeFusions, op)
return operations
end
function delete!(operations::PossibleOperations, op::NodeReduction)
delete!(operations.nodeReductions, op)
return operations
end
function delete!(operations::PossibleOperations, op::NodeSplit)
delete!(operations.nodeSplits, op)
return operations
end
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !is_child(n1, n2) || !is_child(n2, n3)
# the checks are redundant but maybe a good sanity check
return false
end
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
return false
end
return true
end
function can_reduce(n1::Node, n2::Node)
if (n1.task != n2.task)
return false
end
n1_length = length(n1.children)
n2_length = length(n2.children)
if (n1_length != n2_length)
return false
end
# this seems to be the most common case so do this first
# doing it manually is a lot faster than using the sets for a general solution
if (n1_length == 2)
if (n1.children[1] != n2.children[1])
if (n1.children[1] != n2.children[2])
return false
end
# 1_1 == 2_2
if (n1.children[2] != n2.children[1])
return false
end
return true
end
# 1_1 == 2_1
if (n1.children[2] != n2.children[2])
return false
end
return true
end
# this is simple
if (n1_length == 1)
return n1.children[1] == n2.children[1]
end
# this takes a long time
return Set(n1.children) == Set(n2.children)
end
function can_split(n::Node)
return length(parents(n)) > 1
end
function ==(op1::Operation, op2::Operation)
return false
end
function ==(op1::NodeFusion, op2::NodeFusion)
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
return op1.input[2] == op2.input[2]
end
function ==(op1::NodeReduction, op2::NodeReduction)
# only test the ids against each other
return op1.id == op2.id
end
function ==(op1::NodeSplit, op2::NodeSplit)
return op1.input == op2.input
end
NodeReduction(input::Vector{Node}) = NodeReduction(input, UUIDs.uuid1(rng[threadid()]))
copy(id::UUID) = UUID(id.value)

View File

@ -127,8 +127,8 @@ import MetagraphOptimization.partners
@test MetagraphOptimization.get_exit_node(graph) == d_exit
@test length(partners(s0)) == 0
@test length(siblings(s0)) == 0
@test length(partners(s0)) == 1
@test length(siblings(s0)) == 1
operations = get_operations(graph)
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)