Rework node operations storage, remove make_edge from insert_edge calls

This commit is contained in:
2023-08-23 19:28:45 +02:00
parent a81aafbf20
commit c365233ea4
16 changed files with 421 additions and 363 deletions

View File

@ -38,6 +38,7 @@ include("operations/clean.jl")
include("operations/find.jl")
include("operations/get.jl")
include("operations/print.jl")
include("operations/validate.jl")
include("graph_interface.jl")

View File

@ -42,7 +42,7 @@ function parse_abc(filename::String, verbose::Bool = false)
sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false)
global_data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
insert_edge!(graph, make_edge(sum_node, global_data_out), false, false)
insert_edge!(graph, sum_node, global_data_out, false, false)
# remember the data out nodes for connection
dataOutNodes = Dict()
@ -64,10 +64,10 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_u = insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node
data_out = insert_node!(graph, make_node(DataTask(3)), false, false) # transfer data out from u
insert_edge!(graph, make_edge(data_in, compute_P), false, false)
insert_edge!(graph, make_edge(compute_P, data_Pu), false, false)
insert_edge!(graph, make_edge(data_Pu, compute_u), false, false)
insert_edge!(graph, make_edge(compute_u, data_out), false, false)
insert_edge!(graph, data_in, compute_P, false, false)
insert_edge!(graph, compute_P, data_Pu, false, false)
insert_edge!(graph, data_Pu, compute_u, false, false)
insert_edge!(graph, compute_u, data_out, false, false)
# remember the data_out node for future edges
dataOutNodes[node] = data_out
@ -80,34 +80,34 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
data_out = insert_node!(graph, make_node(DataTask(5)), false, false)
if (occursin(regex_c, capt.captures[1]))
if (occursin(regex_c, in1))
# put an S node after this input
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_S), false, false)
insert_edge!(graph, make_edge(compute_S, data_S_v), false, false)
insert_edge!(graph, dataOutNodes[in1], compute_S, false, false)
insert_edge!(graph, compute_S, data_S_v, false, false)
insert_edge!(graph, make_edge(data_S_v, compute_v), false, false)
insert_edge!(graph, data_S_v, compute_v, false, false)
else
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_v), false, false)
insert_edge!(graph, dataOutNodes[in1], compute_v, false, false)
end
if (occursin(regex_c, capt.captures[2]))
if (occursin(regex_c, in2))
# i think the current generator only puts the combined particles in the first space, so this case might never be entered
# put an S node after this input
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_S), false, false)
insert_edge!(graph, make_edge(compute_S, data_S_v), false, false)
insert_edge!(graph, dataOutNodes[in2], compute_S, false, false)
insert_edge!(graph, compute_S, data_S_v, false, false)
insert_edge!(graph, make_edge(data_S_v, compute_v), false, false)
insert_edge!(graph, data_S_v, compute_v, false, false)
else
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_v), false, false)
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
end
insert_edge!(graph, make_edge(compute_v, data_out), false, false)
insert_edge!(graph, compute_v, data_out, false, false)
dataOutNodes[node] = data_out
elseif occursin(regex_m, node)
@ -121,26 +121,26 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
data_v = insert_node!(graph, make_node(DataTask(5)), false, false)
insert_edge!(graph, make_edge(dataOutNodes[in2], compute_v), false, false)
insert_edge!(graph, make_edge(dataOutNodes[in3], compute_v), false, false)
insert_edge!(graph, make_edge(compute_v, data_v), false, false)
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
insert_edge!(graph, dataOutNodes[in3], compute_v, false, false)
insert_edge!(graph, compute_v, data_v, false, false)
# combine with the v of the combined other input
compute_S2 = insert_node!(graph, make_node(ComputeTaskS2()), false, false)
data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
insert_edge!(graph, make_edge(data_v, compute_S2), false, false)
insert_edge!(graph, make_edge(dataOutNodes[in1], compute_S2), false, false)
insert_edge!(graph, make_edge(compute_S2, data_out), false, false)
insert_edge!(graph, data_v, compute_S2, false, false)
insert_edge!(graph, dataOutNodes[in1], compute_S2, false, false)
insert_edge!(graph, compute_S2, data_out, false, false)
insert_edge!(graph, make_edge(data_out, sum_node), false, false)
insert_edge!(graph, data_out, sum_node, false, false)
elseif occursin(regex_plus, node)
if (verbose)
println("\rReading Nodes Complete ")
println("Added ", length(graph.nodes), " nodes")
end
else
error("Unknown node '", node, "' while reading from file ", filename)
@assert false ("Unknown node '$node' while reading from file $filename")
end
end

View File

@ -67,27 +67,66 @@ end
is_entry_node(node::Node) = length(node.children) == 0
is_exit_node(node::Node) = length(node.parents) == 0
# function to invalidate the operation caches for a given operation
function invalidate_caches!(graph::DAG, operation::Operation)
# function to invalidate the operation caches for a given NodeFusion
function invalidate_caches!(graph::DAG, operation::NodeFusion)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
# (we can iterate over tuples and vectors just fine)
filter!(!=(operation), operation.input[1].nodeFusions)
filter!(!=(operation), operation.input[3].nodeFusions)
operation.input[2].nodeFusion = missing
return nothing
end
# function to invalidate the operation caches for a given NodeReduction
function invalidate_caches!(graph::DAG, operation::NodeReduction)
delete!(graph.possibleOperations, operation)
for node in operation.input
filter!(!=(operation), node.operations)
node.nodeReduction = missing
end
return nothing
end
# function to invalidate the operation caches for a given Node Split specifically
# function to invalidate the operation caches for a given NodeSplit
function invalidate_caches!(graph::DAG, operation::NodeSplit)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
# for node split there is only one node
filter!(x -> x != operation, operation.input.operations)
operation.input.nodeSplit = missing
return nothing
end
# function to invalidate the operation caches of a ComputeTaskNode
function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
if !ismissing(node.nodeReduction)
invalidate_caches!(graph, node.nodeReduction)
end
if !ismissing(node.nodeSplit)
invalidate_caches!(graph, node.nodeSplit)
end
while !isempty(node.nodeFusions)
invalidate_caches!(graph, pop!(node.nodeFusions))
end
return nothing
end
# function to invalidate the operation caches of a DataTaskNode
function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
if !ismissing(node.nodeReduction)
invalidate_caches!(graph, node.nodeReduction)
end
if !ismissing(node.nodeSplit)
invalidate_caches!(graph, node.nodeSplit)
end
if !ismissing(node.nodeFusion)
invalidate_caches!(graph, node.nodeFusion)
end
return nothing
end
@ -110,93 +149,72 @@ function insert_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
return node
end
function insert_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true)
node1 = edge.edge[1]
node2 = edge.edge[2]
function insert_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
# @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
# 1: mute
#=if (node2 in node1.parents) || (node1 in node2.children)
if !(node2 in node1.parents && node1 in node2.children)
error("One-sided edge")
end
error("Edge to insert already exists")
end=#
# edge points from child to parent
push!(node1.parents, node2)
push!(node2.children, node1)
# 2: keep track
if (track) push!(graph.diff.addedEdges, edge) end
if (track) push!(graph.diff.addedEdges, make_edge(node1, node2)) end
# 3: invalidate caches
if (!invalidate_cache) return edge end
if (!invalidate_cache) return nothing end
invalidate_operation_caches!(graph, node1)
invalidate_operation_caches!(graph, node2)
while !isempty(node1.operations)
invalidate_caches!(graph, first(node1.operations))
end
while !isempty(node2.operations)
invalidate_caches!(graph, first(node2.operations))
end
push!(graph.dirtyNodes, node1)
push!(graph.dirtyNodes, node2)
return edge
return nothing
end
function remove_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
# @assert node in graph.nodes "Trying to remove a node that's not in the graph"
# 1: mute
#=if !(node in graph.nodes)
error("Trying to remove a node that's not in the graph")
end=#
delete!(graph.nodes, node)
# 2: keep track
if (track) push!(graph.diff.removedNodes, node) end
# 3: invalidate caches
if (!invalidate_cache) return node end
if (!invalidate_cache) return nothing end
while !isempty(node.operations)
invalidate_caches!(graph, first(node.operations))
end
invalidate_operation_caches!(graph, node)
delete!(graph.dirtyNodes, node)
return nothing
end
function remove_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true)
node1 = edge.edge[1]
node2 = edge.edge[2]
function remove_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
# 1: mute
pre_length1 = length(node1.parents)
pre_length2 = length(node2.children)
filter!(x -> x != node2, node1.parents)
filter!(x -> x != node1, node2.children)
#=removed = pre_length1 - length(node1.parents)
if (removed > 1)
error("removed $removed from node1's parents")
end
#=@assert begin
removed = pre_length1 - length(node1.parents)
removed <= 1
end "removed more than one node from node1's parents"=#
removed = pre_length2 - length(node2.children)
if (removed > 1)
error("removed $removed from node2's children")
end=#
#=@assert begin
removed = pre_length2 - length(node2.children)
removed <= 1
end "removed more than one node from node2's children"=#
# 2: keep track
if (track) push!(graph.diff.removedEdges, edge) end
if (track) push!(graph.diff.removedEdges, make_edge(node1, node2)) end
# 3: invalidate caches
if (!invalidate_cache) return nothing end
while !isempty(node1.operations)
invalidate_caches!(graph, first(node1.operations))
end
while !isempty(node2.operations)
invalidate_caches!(graph, first(node2.operations))
end
invalidate_operation_caches!(graph, node1)
invalidate_operation_caches!(graph, node2)
if (node1 in graph)
push!(graph.dirtyNodes, node1)
end
@ -241,7 +259,7 @@ function get_exit_node(graph::DAG)
return node
end
end
error("The given graph has no exit node! It is either empty or not acyclic!")
@assert false "The given graph has no exit node! It is either empty or not acyclic!"
end
# check whether the given graph is connected

View File

@ -46,5 +46,6 @@ function ==(n1::DataTaskNode, n2::DataTaskNode)
return n1.id == n2.id
end
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations))
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations))
copy(m::Missing) = missing
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusions))
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusion))

View File

@ -2,6 +2,7 @@ using Random
using UUIDs
using Base.Threads
# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/)
rng = [Random.MersenneTwister(0) for _ in 1:32]
abstract type Node end
@ -10,7 +11,7 @@ abstract type Node end
# the specific operations are declared in graph.jl
abstract type Operation end
struct DataTaskNode <: Node
mutable struct DataTaskNode <: Node
task::AbstractDataTask
# use vectors as sets have way too much memory overhead
@ -21,21 +22,33 @@ struct DataTaskNode <: Node
# however, it can be copied when splitting a node
id::Base.UUID
# a vector holding references to the graph operations involving this node
operations::Vector{Operation}
# the NodeReduction involving this node, if it exists
# Can't use the NodeReduction type here because it's not yet defined
nodeReduction::Union{Operation, Missing}
# the NodeSplit involving this node, if it exists
nodeSplit::Union{Operation, Missing}
# the node fusion involving this node, if it exists
nodeFusion::Union{Operation, Missing}
end
# same as DataTaskNode
struct ComputeTaskNode <: Node
mutable struct ComputeTaskNode <: Node
task::AbstractComputeTask
parents::Vector{Node}
children::Vector{Node}
id::Base.UUID
operations::Vector{Operation}
nodeReduction::Union{Operation, Missing}
nodeSplit::Union{Operation, Missing}
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
nodeFusions::Vector{Operation}
end
DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}())
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}())
DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing)
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, Vector{NodeFusion}())
struct Edge
# edge points from child to parent

View File

@ -2,237 +2,197 @@
# applies all unapplied operations in the DAG
function apply_all!(graph::DAG)
while !isempty(graph.operationsToApply)
# get next operation to apply from front of the deque
op = popfirst!(graph.operationsToApply)
while !isempty(graph.operationsToApply)
# get next operation to apply from front of the deque
op = popfirst!(graph.operationsToApply)
# apply it
appliedOp = apply_operation!(graph, op)
# apply it
appliedOp = apply_operation!(graph, op)
# push to the end of the appliedOperations deque
push!(graph.appliedOperations, appliedOp)
end
return nothing
# push to the end of the appliedOperations deque
push!(graph.appliedOperations, appliedOp)
end
return nothing
end
function apply_operation!(graph::DAG, operation::Operation)
error("Unknown operation type!")
error("Unknown operation type!")
end
function apply_operation!(graph::DAG, operation::NodeFusion)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
return AppliedNodeFusion(operation, diff)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
return AppliedNodeFusion(operation, diff)
end
function apply_operation!(graph::DAG, operation::NodeReduction)
diff = node_reduction!(graph, operation.input)
return AppliedNodeReduction(operation, diff)
diff = node_reduction!(graph, operation.input)
return AppliedNodeReduction(operation, diff)
end
function apply_operation!(graph::DAG, operation::NodeSplit)
diff = node_split!(graph, operation.input)
return AppliedNodeSplit(operation, diff)
diff = node_split!(graph, operation.input)
return AppliedNodeSplit(operation, diff)
end
function revert_operation!(graph::DAG, operation::AppliedOperation)
error("Unknown operation type!")
error("Unknown operation type!")
end
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
revert_diff!(graph, operation.diff)
return operation.operation
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
revert_diff!(graph, operation.diff)
return operation.operation
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
revert_diff!(graph, operation.diff)
return operation.operation
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_diff!(graph::DAG, diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
remove_edge!(graph, edge, false)
end
for node in diff.addedNodes
remove_node!(graph, node, false)
end
function revert_diff!(graph::DAG, diff::Diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
remove_edge!(graph, edge.edge[1], edge.edge[2], false)
end
for node in diff.addedNodes
remove_node!(graph, node, false)
end
for node in diff.removedNodes
insert_node!(graph, node, false)
end
for edge in diff.removedEdges
insert_edge!(graph, edge, false)
end
for node in diff.removedNodes
insert_node!(graph, node, false)
end
for edge in diff.removedEdges
insert_edge!(graph, edge.edge[1], edge.edge[2], false)
end
end
# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
# clear snapshot
get_snapshot_diff(graph)
# @assert is_valid_node_fusion_input(graph, n1, n2, n3)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
error("[Node Fusion] The given nodes are not part of the given graph")
end
# clear snapshot
get_snapshot_diff(graph)
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
# the checks are redundant but maybe a good sanity check
error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion")
end
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
n3_children = children(n3)
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
n3_children = children(n3)
if length(n2.parents) > 1
error("[Node Fusion] The given data node has more than one parent")
end
if length(n2.children) > 1
error("[Node Fusion] The given data node has more than one child")
end
if length(n1.parents) > 1
error("[Node Fusion] The given n1 has more than one parent")
end
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, n1, n2)
remove_edge!(graph, n2, n3)
remove_node!(graph, n1)
remove_node!(graph, n2)
required_edge1 = make_edge(n1, n2)
required_edge2 = make_edge(n2, n3)
# get n3's children now so it automatically excludes n2
n3_children = children(n3)
remove_node!(graph, n3)
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, required_edge1)
remove_edge!(graph, required_edge2)
remove_node!(graph, n1)
remove_node!(graph, n2)
# create new node with the fused compute task
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
insert_node!(graph, new_node)
# get n3's children now so it automatically excludes n2
n3_children = children(n3)
remove_node!(graph, n3)
# use a set for combined children of n1 and n3 to not get duplicates
n1and3_children = Set{Node}()
# create new node with the fused compute task
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
insert_node!(graph, new_node)
# remove edges from n1 children to n1
for child in n1_children
remove_edge!(graph, child, n1)
push!(n1and3_children, child)
end
# use a set for combined children of n1 and n3 to not get duplicates
n1and3_children = Set{Node}()
# remove edges from n3 children to n3
for child in n3_children
remove_edge!(graph, child, n3)
push!(n1and3_children, child)
end
# remove edges from n1 children to n1
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
push!(n1and3_children, child)
end
for child in n1and3_children
insert_edge!(graph, child, new_node)
end
# remove edges from n3 children to n3
for child in n3_children
remove_edge!(graph, make_edge(child, n3))
push!(n1and3_children, child)
end
# "repoint" parents of n3 from new node
for parent in n3_parents
remove_edge!(graph, n3, parent)
insert_edge!(graph, new_node, parent)
end
for child in n1and3_children
insert_edge!(graph, make_edge(child, new_node))
end
# "repoint" parents of n3 from new node
for parent in n3_parents
remove_edge!(graph, make_edge(n3, parent))
insert_edge!(graph, make_edge(new_node, parent))
end
return get_snapshot_diff(graph)
return get_snapshot_diff(graph)
end
function node_reduction!(graph::DAG, nodes::Vector{Node})
# clear snapshot
get_snapshot_diff(graph)
# @assert is_valid_node_reduction_input(graph, nodes)
t = typeof(nodes[1].task)
for n in nodes
if n graph
error("[Node Reduction] The given nodes are not part of the given graph")
end
# clear snapshot
get_snapshot_diff(graph)
if typeof(n.task) != t
error("[Node Reduction] The given nodes are not of the same type")
end
end
n1 = nodes[1]
n1_children = children(n1)
n1_parents = Set(n1.parents)
new_parents = Set{Node}()
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
n = nodes[i]
for child in n1_children
remove_edge!(graph, child, n)
end
n1 = nodes[1]
n1_children = children(n1)
for n in nodes
if Set(n1_children) != Set(n.children)
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
end
end
for parent in parents(n)
remove_edge!(graph, n, parent)
n1_parents = Set(n1.parents)
new_parents = Set{Node}()
# collect all parents
push!(new_parents, parent)
end
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
n = nodes[i]
for child in n1_children
remove_edge!(graph, make_edge(child, n))
end
remove_node!(graph, n)
end
for parent in parents(n)
remove_edge!(graph, make_edge(n, parent))
setdiff!(new_parents, n1_parents)
# collect all parents
push!(new_parents, parent)
end
for parent in new_parents
# now add parents of all input nodes to n1 without duplicates
insert_edge!(graph, n1, parent)
end
remove_node!(graph, n)
end
setdiff!(new_parents, n1_parents)
for parent in new_parents
# now add parents of n2 to n1 without duplicates
insert_edge!(graph, make_edge(n1, parent))
end
return get_snapshot_diff(graph)
return get_snapshot_diff(graph)
end
function node_split!(graph::DAG, n1::Node)
# clear snapshot
get_snapshot_diff(graph)
# @assert is_valid_node_split_input(graph, n1)
#=if !(n1 in graph)
error("[Node Split] The given node is not part of the given graph")
end=#
# clear snapshot
get_snapshot_diff(graph)
n1_parents = parents(n1)
n1_children = children(n1)
n1_parents = parents(n1)
n1_children = children(n1)
#=if length(n1_parents) <= 1
error("[Node Split] The given node does not have multiple parents which is required for node split")
end=#
for parent in n1_parents
remove_edge!(graph, n1, parent)
end
for child in n1_children
remove_edge!(graph, child, n1)
end
remove_node!(graph, n1)
for parent in n1_parents
remove_edge!(graph, make_edge(n1, parent))
end
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
end
remove_node!(graph, n1)
for parent in n1_parents
n_copy = copy(n1)
insert_node!(graph, n_copy)
insert_edge!(graph, n_copy, parent)
for parent in n1_parents
n_copy = copy(n1)
insert_node!(graph, n_copy)
insert_edge!(graph, make_edge(n_copy, parent))
for child in n1_children
insert_edge!(graph, child, n_copy)
end
end
for child in n1_children
insert_edge!(graph, make_edge(child, n_copy))
end
end
return get_snapshot_diff(graph)
return get_snapshot_diff(graph)
end

View File

@ -4,10 +4,8 @@
# pushes the found fusion everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::DataTaskNode)
# if there is already a fusion here, skip
for op in node.operations
if typeof(op) <: NodeFusion
return nothing
end
if !ismissing(node.nodeFusion)
return nothing
end
if length(node.parents) != 1 || length(node.children) != 1
@ -17,9 +15,9 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
child_node = first(node.children)
parent_node = first(node.parents)
#=if !(child_node in graph) || !(parent_node in graph)
if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end=#
end
if length(child_node.parents) != 1
return nothing
@ -27,9 +25,9 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.operations, nf)
push!(node.operations, nf)
push!(parent_node.operations, nf)
push!(child_node.nodeFusions, nf)
node.nodeFusion = nf
push!(parent_node.nodeFusions, nf)
return nothing
end
@ -37,7 +35,6 @@ end
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# just find fusions in neighbouring DataTaskNodes
for child in node.children
find_fusions!(graph, child)
end
@ -51,10 +48,8 @@ end
function find_reductions!(graph::DAG, node::Node)
# there can only be one reduction per node, avoid adding duplicates
for op in node.operations
if typeof(op) <: NodeReduction
return nothing
end
if !ismissing(node.nodeReduction)
return nothing
end
reductionVector = nothing
@ -62,7 +57,14 @@ function find_reductions!(graph::DAG, node::Node)
partners_ = partners(node)
delete!(partners_, node)
for partner in partners_
if partner graph.nodes
error("Partner is not part of the graph")
end
if can_reduce(node, partner)
if Set(node.children) != Set(partner.children)
error("Not equal children")
end
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
@ -77,7 +79,12 @@ function find_reductions!(graph::DAG, node::Node)
nr = NodeReduction(reductionVector)
push!(graph.possibleOperations.nodeReductions, nr)
for node in reductionVector
push!(node.operations, nr)
if !ismissing(node.nodeReduction)
# it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now
# this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations
invalidate_caches!(graph, node.nodeReduction)
end
node.nodeReduction = nr
end
end
@ -85,10 +92,14 @@ function find_reductions!(graph::DAG, node::Node)
end
function find_splits!(graph::DAG, node::Node)
if !ismissing(node.nodeSplit)
return nothing
end
if (can_split(node))
ns = NodeSplit(node)
push!(graph.possibleOperations.nodeSplits, ns)
push!(node.operations, ns)
node.nodeSplit = ns
end
return nothing

View File

@ -2,49 +2,28 @@
using Base.Threads
function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks::Dict{Node, SpinLock})
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
n1 = nf.input[1]; n2 = nf.input[2]; n3 = nf.input[3]
lock(locks[n1]) do; push!(nf.input[1].operations, nf); end
lock(locks[n2]) do; push!(nf.input[2].operations, nf); end
lock(locks[n3]) do; push!(nf.input[3].operations, nf); end
lock(locks[n1]) do; push!(nf.input[1].nodeFusions, nf); end
nf.input[2].nodeFusion = nf
lock(locks[n3]) do; push!(nf.input[3].nodeFusions, nf); end
return nothing
end
function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock})
# since node parents were sorted before, the NodeReductions contain elements in a known order
# this, together with the locking, means that we can safely do the following without inserting duplicates
first = true
function insert_operation!(nr::NodeReduction)
for n in nr.input
skip_duplicate = false
# careful here, this is a manual lock (because of the break)
lock(locks[n])
if first
first = false
for op in n.operations
if typeof(op) <: NodeReduction
skip_duplicate = true
break
end
end
if skip_duplicate
unlock(locks[n])
break
end
end
push!(n.operations, nr)
unlock(locks[n])
n.nodeReduction = nr
end
return nothing
end
function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock})
lock(locks[ns.input]) do; push!(ns.input.operations, ns); end
function insert_operation!(ns::NodeSplit)
ns.input.nodeSplit = ns
return nothing
end
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock})
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
total_len = 0
for vec in nodeReductions
total_len += length(vec)
@ -58,7 +37,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
@threads for vec in nodeReductions
for op in vec
insert_operation!(operations, op, locks)
insert_operation!(op)
end
end
@ -67,7 +46,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
return nothing
end
function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock})
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
total_len = 0
for vec in nodeFusions
total_len += length(vec)
@ -79,9 +58,16 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto
end
schedule(t)
locks = Dict{ComputeTaskNode, SpinLock}()
for n in graph.nodes
if (typeof(n) <: ComputeTaskNode)
locks[n] = SpinLock()
end
end
@threads for vec in nodeFusions
for op in vec
insert_operation!(operations, op, locks)
insert_operation!(op, locks)
end
end
@ -90,7 +76,7 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto
return nothing
end
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock})
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
total_len = 0
for vec in nodeSplits
total_len += length(vec)
@ -104,7 +90,7 @@ function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector
@threads for vec in nodeSplits
for op in vec
insert_operation!(operations, op, locks)
insert_operation!(op)
end
end
@ -115,11 +101,6 @@ end
# function to generate all possible operations on the graph
function generate_options(graph::DAG)
locks = Dict{Node, SpinLock}()
for n in graph.nodes
locks[n] = SpinLock()
end
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
@ -174,7 +155,7 @@ function generate_options(graph::DAG)
# launch thread for node reduction insertion
# remove duplicates
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks)
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
schedule(nr_task)
# --- find possible node fusions ---
@ -200,7 +181,7 @@ function generate_options(graph::DAG)
end
# launch thread for node fusion insertion
nf_task = @task nf_insertion!(graph.possibleOperations, generatedFusions, locks)
nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
schedule(nf_task)
# find possible node splits
@ -211,7 +192,7 @@ function generate_options(graph::DAG)
end
# launch thread for node split insertion
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits, locks)
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
schedule(ns_task)
empty!(graph.dirtyNodes)

View File

@ -0,0 +1,61 @@
# functions to throw assertion errors for inconsistent or wrong node operations
# should be called with @assert
# the functions throw their own errors though, to still have helpful error messages
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
end
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
throw(AssertionError("[Node Fusion] The given nodes are not connected by edges which is required for node fusion"))
end
if length(n2.parents) > 1
throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
end
if length(n2.children) > 1
throw(AssertionError("[Node Fusion] The given data node has more than one child"))
end
if length(n1.parents) > 1
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
end
return true
end
function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
for n in nodes
if n graph
throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
end
end
t = typeof(nodes[1].task)
for n in nodes
if typeof(n.task) != t
throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
end
end
n1_children = nodes[1].children
for n in nodes
if Set(n1_children) != Set(n.children)
throw(AssertionError("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction"))
end
end
return true
end
function is_valid_node_split_input(graph::DAG, n1::Node)
if n1 graph
throw(AssertionError("[Node Split] The given node is not part of the given graph"))
end
if length(n1.parents) <= 1
throw(AssertionError("[Node Split] The given node does not have multiple parents which is required for node split"))
end
return true
end