Rework node operations storage, remove make_edge from insert_edge calls

This commit is contained in:
Anton Reinhard 2023-08-23 19:28:45 +02:00
parent a81aafbf20
commit c365233ea4
16 changed files with 421 additions and 363 deletions

View File

@ -24,10 +24,10 @@ jobs:
version: '1.9.1'
- name: Install dependencies
run: julia --project -e 'import Pkg; Pkg.instantiate()'
run: julia --project=./ -e 'import Pkg; Pkg.instantiate()'
- name: Run tests
run: julia --project -t 4 -e 'import Pkg; Pkg.test()'
run: julia --project=./ -t 4 -e 'import Pkg; Pkg.test()' -O0
- name: Run examples
run: julia --project=examples/ -t 4 -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); include("examples/import_bench.jl")'
run: julia --project=examples/ -t 4 -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); include("examples/import_bench.jl")' -O3

View File

@ -8,13 +8,13 @@ For all the julia calls, use `-t n` to give julia `n` threads.
Instantiate the project first:
`julia --project -e 'import Pkg; Pkg.instantiate()'`
`julia --project=./ -e 'import Pkg; Pkg.instantiate()'`
### Run Tests
To run all tests, run
`julia --project=. -e 'import Pkg; Pkg.test()'`
`julia --project=./ -e 'import Pkg; Pkg.test()' -O0`
### Run Examples
@ -24,7 +24,7 @@ Get the correct environment for the examples folder:
Then execute a specific example:
`julia --project=examples examples/<file>.jl`
`julia --project=examples -O3 examples/<file>.jl`
## Concepts

View File

@ -1,6 +1,7 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4"
PProf = "e4faabce-9ead-11e9-39d9-4379958e3056"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"

View File

@ -15,3 +15,16 @@
(AB->ABBBBBBB, 6) 887.160 ms (5596691 allocations: 763.42 MiB)
(AB->ABBBBBBB, 7) 898.757 ms (5596762 allocations: 789.91 MiB)
(AB->ABBBBBBB, 8) 497.545 ms (5596820 allocations: 759.66 MiB)
Initial:
$ julia --project=examples/ -e 'using BenchmarkTools; using MetagraphOptimization; parse_abc("input/AB->AB.txt"); @time g = parse_abc("input/AB->ABBBBBBBBB.txt")'
65.370947 seconds (626.10 M allocations: 37.381 GiB, 53.59% gc time, 0.01% compilation time)
Removing make_edge from calls in parse:
50.053920 seconds (593.41 M allocations: 32.921 GiB, 49.70% gc time, 0.09% compilation time)
Nodes operation storage rework (and O3):
31.997128 seconds (450.66 M allocations: 25.294 GiB, 31.56% gc time, 0.14% compilation time)

View File

@ -6,20 +6,20 @@ julia --project=./examples -t 4 -e 'import Pkg; Pkg.instantiate()'
#for i in $(seq $minthreads $maxthreads)
# printf "(AB->AB, $i) "
# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->AB.txt"))'
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->AB.txt"))'
#end
#for i in $(seq $minthreads $maxthreads)
# printf "(AB->ABBB, $i) "
# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBB.txt"))'
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBB.txt"))'
#end
#for i in $(seq $minthreads $maxthreads)
# printf "(AB->ABBBBB, $i) "
# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBBBB.txt"))'
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBB.txt"))'
#end
for i in $(seq $minthreads $maxthreads)
printf "(AB->ABBBBBBB, $i) "
julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBBBBBB.txt"))'
julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBBBB.txt"))'
end

View File

@ -38,6 +38,7 @@ include("operations/clean.jl")
include("operations/find.jl")
include("operations/get.jl")
include("operations/print.jl")
include("operations/validate.jl")
include("graph_interface.jl")

View File

@ -42,7 +42,7 @@ function parse_abc(filename::String, verbose::Bool = false)
sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false)
global_data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
insert_edge!(graph, make_edge(sum_node, global_data_out), false, false)
insert_edge!(graph, sum_node, global_data_out, false, false)
# remember the data out nodes for connection
dataOutNodes = Dict()
@ -64,10 +64,10 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_u = insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node
data_out = insert_node!(graph, make_node(DataTask(3)), false, false) # transfer data out from u
insert_edge!(graph, make_edge(data_in, compute_P), false, false)
insert_edge!(graph, make_edge(compute_P, data_Pu), false, false)
insert_edge!(graph, make_edge(data_Pu, compute_u), false, false)
insert_edge!(graph, make_edge(compute_u, data_out), false, false)
insert_edge!(graph, data_in, compute_P, false, false)
insert_edge!(graph, compute_P, data_Pu, false, false)
insert_edge!(graph, data_Pu, compute_u, false, false)
insert_edge!(graph, compute_u, data_out, false, false)
# remember the data_out node for future edges
dataOutNodes[node] = data_out
@ -80,34 +80,34 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
data_out = insert_node!(graph, make_node(DataTask(5)), false, false)
if (occursin(regex_c, capt.captures[1]))
if (occursin(regex_c, in1))
# put an S node after this input
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_S), false, false)
insert_edge!(graph, make_edge(compute_S, data_S_v), false, false)
insert_edge!(graph, dataOutNodes[in1], compute_S, false, false)
insert_edge!(graph, compute_S, data_S_v, false, false)
insert_edge!(graph, make_edge(data_S_v, compute_v), false, false)
insert_edge!(graph, data_S_v, compute_v, false, false)
else
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_v), false, false)
insert_edge!(graph, dataOutNodes[in1], compute_v, false, false)
end
if (occursin(regex_c, capt.captures[2]))
if (occursin(regex_c, in2))
# i think the current generator only puts the combined particles in the first space, so this case might never be entered
# put an S node after this input
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_S), false, false)
insert_edge!(graph, make_edge(compute_S, data_S_v), false, false)
insert_edge!(graph, dataOutNodes[in2], compute_S, false, false)
insert_edge!(graph, compute_S, data_S_v, false, false)
insert_edge!(graph, make_edge(data_S_v, compute_v), false, false)
insert_edge!(graph, data_S_v, compute_v, false, false)
else
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_v), false, false)
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
end
insert_edge!(graph, make_edge(compute_v, data_out), false, false)
insert_edge!(graph, compute_v, data_out, false, false)
dataOutNodes[node] = data_out
elseif occursin(regex_m, node)
@ -121,26 +121,26 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
data_v = insert_node!(graph, make_node(DataTask(5)), false, false)
insert_edge!(graph, make_edge(dataOutNodes[in2], compute_v), false, false)
insert_edge!(graph, make_edge(dataOutNodes[in3], compute_v), false, false)
insert_edge!(graph, make_edge(compute_v, data_v), false, false)
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
insert_edge!(graph, dataOutNodes[in3], compute_v, false, false)
insert_edge!(graph, compute_v, data_v, false, false)
# combine with the v of the combined other input
compute_S2 = insert_node!(graph, make_node(ComputeTaskS2()), false, false)
data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
insert_edge!(graph, make_edge(data_v, compute_S2), false, false)
insert_edge!(graph, make_edge(dataOutNodes[in1], compute_S2), false, false)
insert_edge!(graph, make_edge(compute_S2, data_out), false, false)
insert_edge!(graph, data_v, compute_S2, false, false)
insert_edge!(graph, dataOutNodes[in1], compute_S2, false, false)
insert_edge!(graph, compute_S2, data_out, false, false)
insert_edge!(graph, make_edge(data_out, sum_node), false, false)
insert_edge!(graph, data_out, sum_node, false, false)
elseif occursin(regex_plus, node)
if (verbose)
println("\rReading Nodes Complete ")
println("Added ", length(graph.nodes), " nodes")
end
else
error("Unknown node '", node, "' while reading from file ", filename)
@assert false ("Unknown node '$node' while reading from file $filename")
end
end

View File

@ -67,27 +67,66 @@ end
is_entry_node(node::Node) = length(node.children) == 0
is_exit_node(node::Node) = length(node.parents) == 0
# function to invalidate the operation caches for a given operation
function invalidate_caches!(graph::DAG, operation::Operation)
# Remove a NodeFusion from the graph's possible operations and from the
# per-node caches of the three nodes it involves.
function invalidate_caches!(graph::DAG, operation::NodeFusion)
    delete!(graph.possibleOperations, operation)

    lower_compute = operation.input[1]
    data_node = operation.input[2]
    upper_compute = operation.input[3]

    # the two outer nodes cache their fusions in a vector: filter this one out
    for compute_node in (lower_compute, upper_compute)
        filter!(cached -> cached != operation, compute_node.nodeFusions)
    end

    # the middle node caches a single fusion only: clear it
    data_node.nodeFusion = missing

    return nothing
end
# function to invalidate the operation caches for a given NodeReduction
function invalidate_caches!(graph::DAG, operation::NodeReduction)
delete!(graph.possibleOperations, operation)
for node in operation.input
filter!(!=(operation), node.operations)
node.nodeReduction = missing
end
return nothing
end
# function to invalidate the operation caches for a given Node Split specifically
# function to invalidate the operation caches for a given NodeSplit
function invalidate_caches!(graph::DAG, operation::NodeSplit)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
# for node split there is only one node
filter!(x -> x != operation, operation.input.operations)
operation.input.nodeSplit = missing
return nothing
end
# Invalidate every cached operation involving the given ComputeTaskNode:
# its NodeReduction and NodeSplit (when set) and all of its NodeFusions.
function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
    ismissing(node.nodeReduction) || invalidate_caches!(graph, node.nodeReduction)
    ismissing(node.nodeSplit) || invalidate_caches!(graph, node.nodeSplit)

    # drain the vector with pop! rather than iterating over it, since
    # invalidating a fusion may itself modify the nodes' fusion vectors
    pending = node.nodeFusions
    while length(pending) > 0
        invalidate_caches!(graph, pop!(pending))
    end

    return nothing
end
# Invalidate every cached operation involving the given DataTaskNode:
# its NodeReduction, its NodeSplit and its single NodeFusion, where set.
function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
    # each field is re-read right before use; unset caches are skipped
    ismissing(node.nodeReduction) || invalidate_caches!(graph, node.nodeReduction)
    ismissing(node.nodeSplit) || invalidate_caches!(graph, node.nodeSplit)
    ismissing(node.nodeFusion) || invalidate_caches!(graph, node.nodeFusion)

    return nothing
end
@ -110,93 +149,72 @@ function insert_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
return node
end
function insert_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true)
node1 = edge.edge[1]
node2 = edge.edge[2]
function insert_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
# @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
# 1: mute
#=if (node2 in node1.parents) || (node1 in node2.children)
if !(node2 in node1.parents && node1 in node2.children)
error("One-sided edge")
end
error("Edge to insert already exists")
end=#
# edge points from child to parent
push!(node1.parents, node2)
push!(node2.children, node1)
# 2: keep track
if (track) push!(graph.diff.addedEdges, edge) end
if (track) push!(graph.diff.addedEdges, make_edge(node1, node2)) end
# 3: invalidate caches
if (!invalidate_cache) return edge end
if (!invalidate_cache) return nothing end
invalidate_operation_caches!(graph, node1)
invalidate_operation_caches!(graph, node2)
while !isempty(node1.operations)
invalidate_caches!(graph, first(node1.operations))
end
while !isempty(node2.operations)
invalidate_caches!(graph, first(node2.operations))
end
push!(graph.dirtyNodes, node1)
push!(graph.dirtyNodes, node2)
return edge
return nothing
end
function remove_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
# @assert node in graph.nodes "Trying to remove a node that's not in the graph"
# 1: mute
#=if !(node in graph.nodes)
error("Trying to remove a node that's not in the graph")
end=#
delete!(graph.nodes, node)
# 2: keep track
if (track) push!(graph.diff.removedNodes, node) end
# 3: invalidate caches
if (!invalidate_cache) return node end
if (!invalidate_cache) return nothing end
while !isempty(node.operations)
invalidate_caches!(graph, first(node.operations))
end
invalidate_operation_caches!(graph, node)
delete!(graph.dirtyNodes, node)
return nothing
end
function remove_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true)
node1 = edge.edge[1]
node2 = edge.edge[2]
function remove_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
# 1: mute
pre_length1 = length(node1.parents)
pre_length2 = length(node2.children)
filter!(x -> x != node2, node1.parents)
filter!(x -> x != node1, node2.children)
#=removed = pre_length1 - length(node1.parents)
if (removed > 1)
error("removed $removed from node1's parents")
end
#=@assert begin
removed = pre_length1 - length(node1.parents)
removed <= 1
end "removed more than one node from node1's parents"=#
removed = pre_length2 - length(node2.children)
if (removed > 1)
error("removed $removed from node2's children")
end=#
#=@assert begin
removed = pre_length2 - length(node2.children)
removed <= 1
end "removed more than one node from node2's children"=#
# 2: keep track
if (track) push!(graph.diff.removedEdges, edge) end
if (track) push!(graph.diff.removedEdges, make_edge(node1, node2)) end
# 3: invalidate caches
if (!invalidate_cache) return nothing end
while !isempty(node1.operations)
invalidate_caches!(graph, first(node1.operations))
end
while !isempty(node2.operations)
invalidate_caches!(graph, first(node2.operations))
end
invalidate_operation_caches!(graph, node1)
invalidate_operation_caches!(graph, node2)
if (node1 in graph)
push!(graph.dirtyNodes, node1)
end
@ -241,7 +259,7 @@ function get_exit_node(graph::DAG)
return node
end
end
error("The given graph has no exit node! It is either empty or not acyclic!")
@assert false "The given graph has no exit node! It is either empty or not acyclic!"
end
# check whether the given graph is connected

View File

@ -46,5 +46,6 @@ function ==(n1::DataTaskNode, n2::DataTaskNode)
return n1.id == n2.id
end
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations))
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations))
# Copying `missing` yields `missing`, so optional (possibly unset) node fields
# can be passed through `copy` without a separate check.
# NOTE(review): if `copy` here extends `Base.copy`, this adds a method on a
# type owned by Base — confirm that is intended.
copy(::Missing) = missing
# Copy a ComputeTaskNode. The copy receives a freshly generated id; every
# other field is passed through `copy`.
function copy(n::ComputeTaskNode)
    return ComputeTaskNode(
        copy(n.task),
        copy(n.parents),
        copy(n.children),
        UUIDs.uuid1(rng[threadid()]),
        copy(n.nodeReduction),
        copy(n.nodeSplit),
        copy(n.nodeFusions),
    )
end
# Copy a DataTaskNode. The copy receives a freshly generated id; every other
# field is passed through `copy`.
function copy(n::DataTaskNode)
    return DataTaskNode(
        copy(n.task),
        copy(n.parents),
        copy(n.children),
        UUIDs.uuid1(rng[threadid()]),
        copy(n.nodeReduction),
        copy(n.nodeSplit),
        copy(n.nodeFusion),
    )
end

View File

@ -2,6 +2,7 @@ using Random
using UUIDs
using Base.Threads
# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/)
rng = [Random.MersenneTwister(0) for _ in 1:32]
abstract type Node end
@ -10,7 +11,7 @@ abstract type Node end
# the specific operations are declared in graph.jl
abstract type Operation end
struct DataTaskNode <: Node
mutable struct DataTaskNode <: Node
task::AbstractDataTask
# use vectors as sets have way too much memory overhead
@ -21,21 +22,33 @@ struct DataTaskNode <: Node
# however, it can be copied when splitting a node
id::Base.UUID
# a vector holding references to the graph operations involving this node
operations::Vector{Operation}
# the NodeReduction involving this node, if it exists
# Can't use the NodeReduction type here because it's not yet defined
nodeReduction::Union{Operation, Missing}
# the NodeSplit involving this node, if it exists
nodeSplit::Union{Operation, Missing}
# the node fusion involving this node, if it exists
nodeFusion::Union{Operation, Missing}
end
# same as DataTaskNode
struct ComputeTaskNode <: Node
mutable struct ComputeTaskNode <: Node
task::AbstractComputeTask
parents::Vector{Node}
children::Vector{Node}
id::Base.UUID
operations::Vector{Operation}
nodeReduction::Union{Operation, Missing}
nodeSplit::Union{Operation, Missing}
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
nodeFusions::Vector{Operation}
end
DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}())
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}())
# Convenience constructor: a DataTaskNode for task `t` with no parents, no
# children, a freshly generated id and no cached operations.
function DataTaskNode(t::AbstractDataTask)
    return DataTaskNode(
        t,
        Vector{Node}(),
        Vector{Node}(),
        UUIDs.uuid1(rng[threadid()]),
        missing,
        missing,
        missing,
    )
end
# Convenience constructor: a ComputeTaskNode for task `t` with no parents, no
# children, a freshly generated id and no cached operations.
function ComputeTaskNode(t::AbstractComputeTask)
    return ComputeTaskNode(
        t,
        Vector{Node}(),
        Vector{Node}(),
        UUIDs.uuid1(rng[threadid()]),
        missing,
        missing,
        # use the declared field element type directly; passing a differently
        # typed vector would be converted (copied) by the default constructor
        Vector{Operation}(),
    )
end
struct Edge
# edge points from child to parent

View File

@ -2,237 +2,197 @@
# applies all unapplied operations in the DAG
function apply_all!(graph::DAG)
while !isempty(graph.operationsToApply)
# get next operation to apply from front of the deque
op = popfirst!(graph.operationsToApply)
while !isempty(graph.operationsToApply)
# get next operation to apply from front of the deque
op = popfirst!(graph.operationsToApply)
# apply it
appliedOp = apply_operation!(graph, op)
# apply it
appliedOp = apply_operation!(graph, op)
# push to the end of the appliedOperations deque
push!(graph.appliedOperations, appliedOp)
end
return nothing
# push to the end of the appliedOperations deque
push!(graph.appliedOperations, appliedOp)
end
return nothing
end
function apply_operation!(graph::DAG, operation::Operation)
error("Unknown operation type!")
error("Unknown operation type!")
end
function apply_operation!(graph::DAG, operation::NodeFusion)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
return AppliedNodeFusion(operation, diff)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
return AppliedNodeFusion(operation, diff)
end
function apply_operation!(graph::DAG, operation::NodeReduction)
diff = node_reduction!(graph, operation.input)
return AppliedNodeReduction(operation, diff)
diff = node_reduction!(graph, operation.input)
return AppliedNodeReduction(operation, diff)
end
function apply_operation!(graph::DAG, operation::NodeSplit)
diff = node_split!(graph, operation.input)
return AppliedNodeSplit(operation, diff)
diff = node_split!(graph, operation.input)
return AppliedNodeSplit(operation, diff)
end
function revert_operation!(graph::DAG, operation::AppliedOperation)
error("Unknown operation type!")
error("Unknown operation type!")
end
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
revert_diff!(graph, operation.diff)
return operation.operation
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
revert_diff!(graph, operation.diff)
return operation.operation
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
revert_diff!(graph, operation.diff)
return operation.operation
revert_diff!(graph, operation.diff)
return operation.operation
end
function revert_diff!(graph::DAG, diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
remove_edge!(graph, edge, false)
end
for node in diff.addedNodes
remove_node!(graph, node, false)
end
function revert_diff!(graph::DAG, diff::Diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
remove_edge!(graph, edge.edge[1], edge.edge[2], false)
end
for node in diff.addedNodes
remove_node!(graph, node, false)
end
for node in diff.removedNodes
insert_node!(graph, node, false)
end
for edge in diff.removedEdges
insert_edge!(graph, edge, false)
end
for node in diff.removedNodes
insert_node!(graph, node, false)
end
for edge in diff.removedEdges
insert_edge!(graph, edge.edge[1], edge.edge[2], false)
end
end
# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
# clear snapshot
get_snapshot_diff(graph)
# @assert is_valid_node_fusion_input(graph, n1, n2, n3)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
error("[Node Fusion] The given nodes are not part of the given graph")
end
# clear snapshot
get_snapshot_diff(graph)
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
# the checks are redundant but maybe a good sanity check
error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion")
end
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
n3_children = children(n3)
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
n3_children = children(n3)
if length(n2.parents) > 1
error("[Node Fusion] The given data node has more than one parent")
end
if length(n2.children) > 1
error("[Node Fusion] The given data node has more than one child")
end
if length(n1.parents) > 1
error("[Node Fusion] The given n1 has more than one parent")
end
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, n1, n2)
remove_edge!(graph, n2, n3)
remove_node!(graph, n1)
remove_node!(graph, n2)
required_edge1 = make_edge(n1, n2)
required_edge2 = make_edge(n2, n3)
# get n3's children now so it automatically excludes n2
n3_children = children(n3)
remove_node!(graph, n3)
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, required_edge1)
remove_edge!(graph, required_edge2)
remove_node!(graph, n1)
remove_node!(graph, n2)
# create new node with the fused compute task
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
insert_node!(graph, new_node)
# get n3's children now so it automatically excludes n2
n3_children = children(n3)
remove_node!(graph, n3)
# use a set for combined children of n1 and n3 to not get duplicates
n1and3_children = Set{Node}()
# create new node with the fused compute task
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
insert_node!(graph, new_node)
# remove edges from n1 children to n1
for child in n1_children
remove_edge!(graph, child, n1)
push!(n1and3_children, child)
end
# use a set for combined children of n1 and n3 to not get duplicates
n1and3_children = Set{Node}()
# remove edges from n3 children to n3
for child in n3_children
remove_edge!(graph, child, n3)
push!(n1and3_children, child)
end
# remove edges from n1 children to n1
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
push!(n1and3_children, child)
end
for child in n1and3_children
insert_edge!(graph, child, new_node)
end
# remove edges from n3 children to n3
for child in n3_children
remove_edge!(graph, make_edge(child, n3))
push!(n1and3_children, child)
end
# "repoint" parents of n3 from new node
for parent in n3_parents
remove_edge!(graph, n3, parent)
insert_edge!(graph, new_node, parent)
end
for child in n1and3_children
insert_edge!(graph, make_edge(child, new_node))
end
# "repoint" parents of n3 from new node
for parent in n3_parents
remove_edge!(graph, make_edge(n3, parent))
insert_edge!(graph, make_edge(new_node, parent))
end
return get_snapshot_diff(graph)
return get_snapshot_diff(graph)
end
function node_reduction!(graph::DAG, nodes::Vector{Node})
# clear snapshot
get_snapshot_diff(graph)
# @assert is_valid_node_reduction_input(graph, nodes)
t = typeof(nodes[1].task)
for n in nodes
if n ∉ graph
error("[Node Reduction] The given nodes are not part of the given graph")
end
# clear snapshot
get_snapshot_diff(graph)
if typeof(n.task) != t
error("[Node Reduction] The given nodes are not of the same type")
end
end
n1 = nodes[1]
n1_children = children(n1)
n1_parents = Set(n1.parents)
new_parents = Set{Node}()
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
n = nodes[i]
for child in n1_children
remove_edge!(graph, child, n)
end
n1 = nodes[1]
n1_children = children(n1)
for n in nodes
if Set(n1_children) != Set(n.children)
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
end
end
for parent in parents(n)
remove_edge!(graph, n, parent)
n1_parents = Set(n1.parents)
new_parents = Set{Node}()
# collect all parents
push!(new_parents, parent)
end
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
n = nodes[i]
for child in n1_children
remove_edge!(graph, make_edge(child, n))
end
remove_node!(graph, n)
end
for parent in parents(n)
remove_edge!(graph, make_edge(n, parent))
setdiff!(new_parents, n1_parents)
# collect all parents
push!(new_parents, parent)
end
for parent in new_parents
# now add parents of all input nodes to n1 without duplicates
insert_edge!(graph, n1, parent)
end
remove_node!(graph, n)
end
setdiff!(new_parents, n1_parents)
for parent in new_parents
# now add parents of n2 to n1 without duplicates
insert_edge!(graph, make_edge(n1, parent))
end
return get_snapshot_diff(graph)
return get_snapshot_diff(graph)
end
function node_split!(graph::DAG, n1::Node)
# clear snapshot
get_snapshot_diff(graph)
# @assert is_valid_node_split_input(graph, n1)
#=if !(n1 in graph)
error("[Node Split] The given node is not part of the given graph")
end=#
# clear snapshot
get_snapshot_diff(graph)
n1_parents = parents(n1)
n1_children = children(n1)
n1_parents = parents(n1)
n1_children = children(n1)
#=if length(n1_parents) <= 1
error("[Node Split] The given node does not have multiple parents which is required for node split")
end=#
for parent in n1_parents
remove_edge!(graph, n1, parent)
end
for child in n1_children
remove_edge!(graph, child, n1)
end
remove_node!(graph, n1)
for parent in n1_parents
remove_edge!(graph, make_edge(n1, parent))
end
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
end
remove_node!(graph, n1)
for parent in n1_parents
n_copy = copy(n1)
insert_node!(graph, n_copy)
insert_edge!(graph, n_copy, parent)
for parent in n1_parents
n_copy = copy(n1)
insert_node!(graph, n_copy)
insert_edge!(graph, make_edge(n_copy, parent))
for child in n1_children
insert_edge!(graph, child, n_copy)
end
end
for child in n1_children
insert_edge!(graph, make_edge(child, n_copy))
end
end
return get_snapshot_diff(graph)
return get_snapshot_diff(graph)
end

View File

@ -4,10 +4,8 @@
# pushes the found fusion everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::DataTaskNode)
# if there is already a fusion here, skip
for op in node.operations
if typeof(op) <: NodeFusion
return nothing
end
if !ismissing(node.nodeFusion)
return nothing
end
if length(node.parents) != 1 || length(node.children) != 1
@ -17,9 +15,9 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
child_node = first(node.children)
parent_node = first(node.parents)
#=if !(child_node in graph) || !(parent_node in graph)
if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end=#
end
if length(child_node.parents) != 1
return nothing
@ -27,9 +25,9 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.operations, nf)
push!(node.operations, nf)
push!(parent_node.operations, nf)
push!(child_node.nodeFusions, nf)
node.nodeFusion = nf
push!(parent_node.nodeFusions, nf)
return nothing
end
@ -37,7 +35,6 @@ end
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# just find fusions in neighbouring DataTaskNodes
for child in node.children
find_fusions!(graph, child)
end
@ -51,10 +48,8 @@ end
function find_reductions!(graph::DAG, node::Node)
# there can only be one reduction per node, avoid adding duplicates
for op in node.operations
if typeof(op) <: NodeReduction
return nothing
end
if !ismissing(node.nodeReduction)
return nothing
end
reductionVector = nothing
@ -62,7 +57,14 @@ function find_reductions!(graph::DAG, node::Node)
partners_ = partners(node)
delete!(partners_, node)
for partner in partners_
if partner ∉ graph.nodes
error("Partner is not part of the graph")
end
if can_reduce(node, partner)
if Set(node.children) != Set(partner.children)
error("Not equal children")
end
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
@ -77,7 +79,12 @@ function find_reductions!(graph::DAG, node::Node)
nr = NodeReduction(reductionVector)
push!(graph.possibleOperations.nodeReductions, nr)
for node in reductionVector
push!(node.operations, nr)
if !ismissing(node.nodeReduction)
# it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now
# this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations
invalidate_caches!(graph, node.nodeReduction)
end
node.nodeReduction = nr
end
end
@ -85,10 +92,14 @@ function find_reductions!(graph::DAG, node::Node)
end
function find_splits!(graph::DAG, node::Node)
if !ismissing(node.nodeSplit)
return nothing
end
if (can_split(node))
ns = NodeSplit(node)
push!(graph.possibleOperations.nodeSplits, ns)
push!(node.operations, ns)
node.nodeSplit = ns
end
return nothing

View File

@ -2,49 +2,28 @@
using Base.Threads
function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks::Dict{Node, SpinLock})
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
n1 = nf.input[1]; n2 = nf.input[2]; n3 = nf.input[3]
lock(locks[n1]) do; push!(nf.input[1].operations, nf); end
lock(locks[n2]) do; push!(nf.input[2].operations, nf); end
lock(locks[n3]) do; push!(nf.input[3].operations, nf); end
lock(locks[n1]) do; push!(nf.input[1].nodeFusions, nf); end
nf.input[2].nodeFusion = nf
lock(locks[n3]) do; push!(nf.input[3].nodeFusions, nf); end
return nothing
end
function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock})
# since node parents were sorted before, the NodeReductions contain elements in a known order
# this, together with the locking, means that we can safely do the following without inserting duplicates
first = true
function insert_operation!(nr::NodeReduction)
for n in nr.input
skip_duplicate = false
# careful here, this is a manual lock (because of the break)
lock(locks[n])
if first
first = false
for op in n.operations
if typeof(op) <: NodeReduction
skip_duplicate = true
break
end
end
if skip_duplicate
unlock(locks[n])
break
end
end
push!(n.operations, nr)
unlock(locks[n])
n.nodeReduction = nr
end
return nothing
end
function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock})
lock(locks[ns.input]) do; push!(ns.input.operations, ns); end
function insert_operation!(ns::NodeSplit)
ns.input.nodeSplit = ns
return nothing
end
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock})
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
total_len = 0
for vec in nodeReductions
total_len += length(vec)
@ -58,7 +37,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
@threads for vec in nodeReductions
for op in vec
insert_operation!(operations, op, locks)
insert_operation!(op)
end
end
@ -67,7 +46,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
return nothing
end
function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock})
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
total_len = 0
for vec in nodeFusions
total_len += length(vec)
@ -79,9 +58,16 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto
end
schedule(t)
locks = Dict{ComputeTaskNode, SpinLock}()
for n in graph.nodes
if (typeof(n) <: ComputeTaskNode)
locks[n] = SpinLock()
end
end
@threads for vec in nodeFusions
for op in vec
insert_operation!(operations, op, locks)
insert_operation!(op, locks)
end
end
@ -90,7 +76,7 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto
return nothing
end
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock})
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
total_len = 0
for vec in nodeSplits
total_len += length(vec)
@ -104,7 +90,7 @@ function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector
@threads for vec in nodeSplits
for op in vec
insert_operation!(operations, op, locks)
insert_operation!(op)
end
end
@ -115,11 +101,6 @@ end
# function to generate all possible operations on the graph
function generate_options(graph::DAG)
locks = Dict{Node, SpinLock}()
for n in graph.nodes
locks[n] = SpinLock()
end
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
@ -174,7 +155,7 @@ function generate_options(graph::DAG)
# launch thread for node reduction insertion
# remove duplicates
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks)
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
schedule(nr_task)
# --- find possible node fusions ---
@ -200,7 +181,7 @@ function generate_options(graph::DAG)
end
# launch thread for node fusion insertion
nf_task = @task nf_insertion!(graph.possibleOperations, generatedFusions, locks)
nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
schedule(nf_task)
# find possible node splits
@ -211,7 +192,7 @@ function generate_options(graph::DAG)
end
# launch thread for node split insertion
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits, locks)
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
schedule(ns_task)
empty!(graph.dirtyNodes)

View File

@ -0,0 +1,61 @@
# functions to throw assertion errors for inconsistent or wrong node operations
# should be called with @assert
# the functions throw their own errors though, to still have helpful error messages
# Validate the input (n1, n2, n3) of a node fusion on `graph`.
#
# Checks, in order:
#   1. all three nodes are contained in `graph`
#   2. the is_child/is_parent relations between n1-n2 and n2-n3 hold in both directions
#   3. the data node n2 has at most one parent and at most one child
#   4. n1 has at most one parent
#
# Returns `true` when every check passes. Otherwise throws an `AssertionError`
# describing the first failed check, so the function can be wrapped in `@assert`
# while still producing a helpful message.
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    all_in_graph = (n1 in graph) && (n2 in graph) && (n3 in graph)
    all_in_graph || throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))

    # the edge relations are verified from both endpoints
    connected = is_child(n1, n2) && is_child(n2, n3) && is_parent(n3, n2) && is_parent(n2, n1)
    connected || throw(AssertionError("[Node Fusion] The given nodes are not connected by edges which is required for node fusion"))

    length(n2.parents) <= 1 || throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
    length(n2.children) <= 1 || throw(AssertionError("[Node Fusion] The given data node has more than one child"))
    length(n1.parents) <= 1 || throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))

    return true
end
# Validate the input `nodes` of a node reduction on `graph`.
#
# Checks, in order:
#   1. `nodes` is not empty
#   2. every node is contained in `graph`
#   3. the `task` of every node has the same type as the task of the first node
#   4. every node has the same set of children as the first node
#
# Returns `true` when every check passes. Otherwise throws an `AssertionError`
# describing the first failed check, so the function can be wrapped in `@assert`
# while still producing a helpful message.
function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
    # without this guard the `nodes[1]` accesses below would raise a BoundsError
    if isempty(nodes)
        throw(AssertionError("[Node Reduction] The given node vector is empty"))
    end

    for n in nodes
        if !(n in graph)
            throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
        end
    end

    t = typeof(nodes[1].task)
    for n in nodes
        if typeof(n.task) != t
            throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
        end
    end

    # children are compared as sets, so their order within each node does not matter;
    # the reference set is built once instead of once per iteration
    n1_children = Set(nodes[1].children)
    for n in nodes
        if n1_children != Set(n.children)
            throw(AssertionError("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction"))
        end
    end

    return true
end
# Validate the input `n1` of a node split on `graph`.
#
# Checks, in order:
#   1. `n1` is contained in `graph`
#   2. `n1` has more than one parent
#
# Returns `true` when every check passes. Otherwise throws an `AssertionError`
# describing the first failed check, so the function can be wrapped in `@assert`
# while still producing a helpful message.
function is_valid_node_split_input(graph::DAG, n1::Node)
    if !(n1 in graph)
        throw(AssertionError("[Node Split] The given node is not part of the given graph"))
    end

    if length(n1.parents) <= 1
        throw(AssertionError("[Node Split] The given node does not have multiple parents which is required for node split"))
    end

    return true
end

View File

@ -1,7 +1,6 @@
import MetagraphOptimization.insert_node!
import MetagraphOptimization.insert_edge!
import MetagraphOptimization.make_node
import MetagraphOptimization.make_edge
@testset "Unit Tests Node Reduction" begin
graph = MetagraphOptimization.DAG()
@ -30,27 +29,27 @@ import MetagraphOptimization.make_edge
BD = insert_node!(graph, make_node(DataTask(5)), false)
CD = insert_node!(graph, make_node(DataTask(5)), false)
insert_edge!(graph, make_edge(s0, d_exit), false)
insert_edge!(graph, make_edge(ED, s0), false)
insert_edge!(graph, make_edge(FD, s0), false)
insert_edge!(graph, make_edge(EC, ED), false)
insert_edge!(graph, make_edge(FC, FD), false)
insert_edge!(graph, s0, d_exit, false)
insert_edge!(graph, ED, s0, false)
insert_edge!(graph, FD, s0, false)
insert_edge!(graph, EC, ED, false)
insert_edge!(graph, FC, FD, false)
insert_edge!(graph, make_edge(A1D, EC), false)
insert_edge!(graph, make_edge(B1D_1, EC), false)
insert_edge!(graph, A1D, EC, false)
insert_edge!(graph, B1D_1, EC, false)
insert_edge!(graph, make_edge(B1D_2, FC), false)
insert_edge!(graph, make_edge(C1D, FC), false)
insert_edge!(graph, B1D_2, FC, false)
insert_edge!(graph, C1D, FC, false)
insert_edge!(graph, make_edge(A1C, A1D), false)
insert_edge!(graph, make_edge(B1C_1, B1D_1), false)
insert_edge!(graph, make_edge(B1C_2, B1D_2), false)
insert_edge!(graph, make_edge(C1C, C1D), false)
insert_edge!(graph, A1C, A1D, false)
insert_edge!(graph, B1C_1, B1D_1, false)
insert_edge!(graph, B1C_2, B1D_2, false)
insert_edge!(graph, C1C, C1D, false)
insert_edge!(graph, make_edge(AD, A1C), false)
insert_edge!(graph, make_edge(BD, B1C_1), false)
insert_edge!(graph, make_edge(BD, B1C_2), false)
insert_edge!(graph, make_edge(CD, C1C), false)
insert_edge!(graph, AD, A1C, false)
insert_edge!(graph, BD, B1C_1, false)
insert_edge!(graph, BD, B1C_2, false)
insert_edge!(graph, CD, C1C, false)
@test is_exit_node(d_exit)
@test is_entry_node(AD)

View File

@ -1,7 +1,6 @@
import MetagraphOptimization.insert_node!
import MetagraphOptimization.insert_edge!
import MetagraphOptimization.make_node
import MetagraphOptimization.make_edge
import MetagraphOptimization.siblings
import MetagraphOptimization.partners
@ -69,38 +68,38 @@ import MetagraphOptimization.partners
@test length(graph.dirtyNodes) == 26
# now for all the edges
insert_edge!(graph, make_edge(d_PB, PB), false)
insert_edge!(graph, make_edge(d_PA, PA), false)
insert_edge!(graph, make_edge(d_PBp, PBp), false)
insert_edge!(graph, make_edge(d_PAp, PAp), false)
insert_edge!(graph, d_PB, PB, false)
insert_edge!(graph, d_PA, PA, false)
insert_edge!(graph, d_PBp, PBp, false)
insert_edge!(graph, d_PAp, PAp, false)
insert_edge!(graph, make_edge(PB, d_PB_uB), false)
insert_edge!(graph, make_edge(PA, d_PA_uA), false)
insert_edge!(graph, make_edge(PBp, d_PBp_uBp), false)
insert_edge!(graph, make_edge(PAp, d_PAp_uAp), false)
insert_edge!(graph, PB, d_PB_uB, false)
insert_edge!(graph, PA, d_PA_uA, false)
insert_edge!(graph, PBp, d_PBp_uBp, false)
insert_edge!(graph, PAp, d_PAp_uAp, false)
insert_edge!(graph, make_edge(d_PB_uB, uB), false)
insert_edge!(graph, make_edge(d_PA_uA, uA), false)
insert_edge!(graph, make_edge(d_PBp_uBp, uBp), false)
insert_edge!(graph, make_edge(d_PAp_uAp, uAp), false)
insert_edge!(graph, d_PB_uB, uB, false)
insert_edge!(graph, d_PA_uA, uA, false)
insert_edge!(graph, d_PBp_uBp, uBp, false)
insert_edge!(graph, d_PAp_uAp, uAp, false)
insert_edge!(graph, make_edge(uB, d_uB_v0), false)
insert_edge!(graph, make_edge(uA, d_uA_v0), false)
insert_edge!(graph, make_edge(uBp, d_uBp_v1), false)
insert_edge!(graph, make_edge(uAp, d_uAp_v1), false)
insert_edge!(graph, uB, d_uB_v0, false)
insert_edge!(graph, uA, d_uA_v0, false)
insert_edge!(graph, uBp, d_uBp_v1, false)
insert_edge!(graph, uAp, d_uAp_v1, false)
insert_edge!(graph, make_edge(d_uB_v0, v0), false)
insert_edge!(graph, make_edge(d_uA_v0, v0), false)
insert_edge!(graph, make_edge(d_uBp_v1, v1), false)
insert_edge!(graph, make_edge(d_uAp_v1, v1), false)
insert_edge!(graph, d_uB_v0, v0, false)
insert_edge!(graph, d_uA_v0, v0, false)
insert_edge!(graph, d_uBp_v1, v1, false)
insert_edge!(graph, d_uAp_v1, v1, false)
insert_edge!(graph, make_edge(v0, d_v0_s0), false)
insert_edge!(graph, make_edge(v1, d_v1_s0), false)
insert_edge!(graph, v0, d_v0_s0, false)
insert_edge!(graph, v1, d_v1_s0, false)
insert_edge!(graph, make_edge(d_v0_s0, s0), false)
insert_edge!(graph, make_edge(d_v1_s0, s0), false)
insert_edge!(graph, d_v0_s0, s0, false)
insert_edge!(graph, d_v1_s0, s0, false)
insert_edge!(graph, make_edge(s0, d_exit), false)
insert_edge!(graph, s0, d_exit, false)
@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0