Merge pull request 'Performance Improvements' (#4) from performance into main
Reviewed-on: Rubydragon/MetagraphOptimization.jl#4
This commit is contained in:
commit
383c92ec47
@ -24,10 +24,10 @@ jobs:
|
||||
version: '1.9.1'
|
||||
|
||||
- name: Install dependencies
|
||||
run: julia --project -e 'import Pkg; Pkg.instantiate()'
|
||||
run: julia --project=./ -e 'import Pkg; Pkg.instantiate()'
|
||||
|
||||
- name: Run tests
|
||||
run: julia --project -t 4 -e 'import Pkg; Pkg.test()'
|
||||
run: julia --project=./ -t 4 -e 'import Pkg; Pkg.test()' -O0
|
||||
|
||||
- name: Run examples
|
||||
run: julia --project=examples/ -t 4 -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); include("examples/import_bench.jl")'
|
||||
run: julia --project=examples/ -t 4 -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); include("examples/import_bench.jl")' -O3
|
||||
|
6
.gitignore
vendored
6
.gitignore
vendored
@ -1,10 +1,10 @@
|
||||
# ---> Julia
|
||||
# Files generated by invoking Julia with --code-coverage
|
||||
*.jl.cov
|
||||
*.jl.*.cov
|
||||
*.cov
|
||||
*.cov
|
||||
|
||||
# Files generated by invoking Julia with --track-allocation
|
||||
*.jl.mem
|
||||
*.mem
|
||||
|
||||
# System-specific files and directories generated by the BinaryProvider and BinDeps packages
|
||||
# They contain absolute paths specific to the host computer, and so should not be committed
|
||||
|
@ -8,13 +8,13 @@ For all the julia calls, use `-t n` to give julia `n` threads.
|
||||
|
||||
Instantiate the project first:
|
||||
|
||||
`julia --project -e 'import Pkg; Pkg.instantiate()'`
|
||||
`julia --project=./ -e 'import Pkg; Pkg.instantiate()'`
|
||||
|
||||
### Run Tests
|
||||
|
||||
To run all tests, run
|
||||
|
||||
`julia --project=. -e 'import Pkg; Pkg.test()'`
|
||||
`julia --project=./ -e 'import Pkg; Pkg.test()' -O0`
|
||||
|
||||
### Run Examples
|
||||
|
||||
@ -24,7 +24,7 @@ Get the correct environment for the examples folder:
|
||||
|
||||
Then execute a specific example:
|
||||
|
||||
`julia --project=examples examples/<file>.jl`
|
||||
`julia --project=examples examples/<file>.jl -O3`
|
||||
|
||||
## Concepts
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
[deps]
|
||||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
|
||||
MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4"
|
||||
PProf = "e4faabce-9ead-11e9-39d9-4379958e3056"
|
||||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
||||
ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"
|
||||
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
|
||||
|
@ -15,3 +15,16 @@
|
||||
(AB->ABBBBBBB, 6) 887.160 ms (5596691 allocations: 763.42 MiB)
|
||||
(AB->ABBBBBBB, 7) 898.757 ms (5596762 allocations: 789.91 MiB)
|
||||
(AB->ABBBBBBB, 8) 497.545 ms (5596820 allocations: 759.66 MiB)
|
||||
|
||||
|
||||
Initial:
|
||||
|
||||
$ julia --project=examples/ -e 'using BenchmarkTools; using MetagraphOptimization; parse_abc("input/AB->AB.txt"); @time g = parse_abc("input/AB->ABBBBBBBBB.txt")'
|
||||
65.370947 seconds (626.10 M allocations: 37.381 GiB, 53.59% gc time, 0.01% compilation time)
|
||||
|
||||
Removing make_edge from calls in parse:
|
||||
50.053920 seconds (593.41 M allocations: 32.921 GiB, 49.70% gc time, 0.09% compilation time)
|
||||
|
||||
Nodes operation storage rework (and O3):
|
||||
31.997128 seconds (450.66 M allocations: 25.294 GiB, 31.56% gc time, 0.14% compilation time)
|
||||
|
@ -6,20 +6,20 @@ julia --project=./examples -t 4 -e 'import Pkg; Pkg.instantiate()'
|
||||
|
||||
#for i in $(seq $minthreads $maxthreads)
|
||||
# printf "(AB->AB, $i) "
|
||||
# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->AB.txt"))'
|
||||
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->AB.txt"))'
|
||||
#end
|
||||
|
||||
#for i in $(seq $minthreads $maxthreads)
|
||||
# printf "(AB->ABBB, $i) "
|
||||
# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBB.txt"))'
|
||||
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBB.txt"))'
|
||||
#end
|
||||
|
||||
#for i in $(seq $minthreads $maxthreads)
|
||||
# printf "(AB->ABBBBB, $i) "
|
||||
# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBBBB.txt"))'
|
||||
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBB.txt"))'
|
||||
#end
|
||||
|
||||
for i in $(seq $minthreads $maxthreads)
|
||||
printf "(AB->ABBBBBBB, $i) "
|
||||
julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBBBBBB.txt"))'
|
||||
julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBBBB.txt"))'
|
||||
end
|
||||
|
@ -38,6 +38,7 @@ include("operations/clean.jl")
|
||||
include("operations/find.jl")
|
||||
include("operations/get.jl")
|
||||
include("operations/print.jl")
|
||||
include("operations/validate.jl")
|
||||
|
||||
include("graph_interface.jl")
|
||||
|
||||
|
@ -42,7 +42,7 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false)
|
||||
global_data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
|
||||
insert_edge!(graph, make_edge(sum_node, global_data_out), false, false)
|
||||
insert_edge!(graph, sum_node, global_data_out, false, false)
|
||||
|
||||
# remember the data out nodes for connection
|
||||
dataOutNodes = Dict()
|
||||
@ -64,10 +64,10 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
compute_u = insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node
|
||||
data_out = insert_node!(graph, make_node(DataTask(3)), false, false) # transfer data out from u
|
||||
|
||||
insert_edge!(graph, make_edge(data_in, compute_P), false, false)
|
||||
insert_edge!(graph, make_edge(compute_P, data_Pu), false, false)
|
||||
insert_edge!(graph, make_edge(data_Pu, compute_u), false, false)
|
||||
insert_edge!(graph, make_edge(compute_u, data_out), false, false)
|
||||
insert_edge!(graph, data_in, compute_P, false, false)
|
||||
insert_edge!(graph, compute_P, data_Pu, false, false)
|
||||
insert_edge!(graph, data_Pu, compute_u, false, false)
|
||||
insert_edge!(graph, compute_u, data_out, false, false)
|
||||
|
||||
# remember the data_out node for future edges
|
||||
dataOutNodes[node] = data_out
|
||||
@ -80,34 +80,34 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
|
||||
data_out = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
if (occursin(regex_c, capt.captures[1]))
|
||||
if (occursin(regex_c, in1))
|
||||
# put an S node after this input
|
||||
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
|
||||
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_S), false, false)
|
||||
insert_edge!(graph, make_edge(compute_S, data_S_v), false, false)
|
||||
insert_edge!(graph, dataOutNodes[in1], compute_S, false, false)
|
||||
insert_edge!(graph, compute_S, data_S_v, false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(data_S_v, compute_v), false, false)
|
||||
insert_edge!(graph, data_S_v, compute_v, false, false)
|
||||
else
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_v), false, false)
|
||||
insert_edge!(graph, dataOutNodes[in1], compute_v, false, false)
|
||||
end
|
||||
|
||||
if (occursin(regex_c, capt.captures[2]))
|
||||
if (occursin(regex_c, in2))
|
||||
# i think the current generator only puts the combined particles in the first space, so this case might never be entered
|
||||
# put an S node after this input
|
||||
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
|
||||
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_S), false, false)
|
||||
insert_edge!(graph, make_edge(compute_S, data_S_v), false, false)
|
||||
insert_edge!(graph, dataOutNodes[in2], compute_S, false, false)
|
||||
insert_edge!(graph, compute_S, data_S_v, false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(data_S_v, compute_v), false, false)
|
||||
insert_edge!(graph, data_S_v, compute_v, false, false)
|
||||
else
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_v), false, false)
|
||||
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
|
||||
end
|
||||
|
||||
insert_edge!(graph, make_edge(compute_v, data_out), false, false)
|
||||
insert_edge!(graph, compute_v, data_out, false, false)
|
||||
dataOutNodes[node] = data_out
|
||||
|
||||
elseif occursin(regex_m, node)
|
||||
@ -121,26 +121,26 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
|
||||
data_v = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(dataOutNodes[in2], compute_v), false, false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[in3], compute_v), false, false)
|
||||
insert_edge!(graph, make_edge(compute_v, data_v), false, false)
|
||||
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
|
||||
insert_edge!(graph, dataOutNodes[in3], compute_v, false, false)
|
||||
insert_edge!(graph, compute_v, data_v, false, false)
|
||||
|
||||
# combine with the v of the combined other input
|
||||
compute_S2 = insert_node!(graph, make_node(ComputeTaskS2()), false, false)
|
||||
data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(data_v, compute_S2), false, false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[in1], compute_S2), false, false)
|
||||
insert_edge!(graph, make_edge(compute_S2, data_out), false, false)
|
||||
insert_edge!(graph, data_v, compute_S2, false, false)
|
||||
insert_edge!(graph, dataOutNodes[in1], compute_S2, false, false)
|
||||
insert_edge!(graph, compute_S2, data_out, false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(data_out, sum_node), false, false)
|
||||
insert_edge!(graph, data_out, sum_node, false, false)
|
||||
elseif occursin(regex_plus, node)
|
||||
if (verbose)
|
||||
println("\rReading Nodes Complete ")
|
||||
println("Added ", length(graph.nodes), " nodes")
|
||||
end
|
||||
else
|
||||
error("Unknown node '", node, "' while reading from file ", filename)
|
||||
@assert false ("Unknown node '$node' while reading from file $filename")
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -67,27 +67,66 @@ end
|
||||
is_entry_node(node::Node) = length(node.children) == 0
|
||||
is_exit_node(node::Node) = length(node.parents) == 0
|
||||
|
||||
# function to invalidate the operation caches for a given operation
|
||||
function invalidate_caches!(graph::DAG, operation::Operation)
|
||||
# function to invalidate the operation caches for a given NodeFusion
|
||||
function invalidate_caches!(graph::DAG, operation::NodeFusion)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# (we can iterate over tuples and vectors just fine)
|
||||
filter!(!=(operation), operation.input[1].nodeFusions)
|
||||
filter!(!=(operation), operation.input[3].nodeFusions)
|
||||
|
||||
operation.input[2].nodeFusion = missing
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to invalidate the operation caches for a given NodeReduction
|
||||
function invalidate_caches!(graph::DAG, operation::NodeReduction)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
for node in operation.input
|
||||
filter!(!=(operation), node.operations)
|
||||
node.nodeReduction = missing
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to invalidate the operation caches for a given Node Split specifically
|
||||
# function to invalidate the operation caches for a given NodeSplit
|
||||
function invalidate_caches!(graph::DAG, operation::NodeSplit)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# for node split there is only one node
|
||||
filter!(x -> x != operation, operation.input.operations)
|
||||
operation.input.nodeSplit = missing
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to invalidate the operation caches of a ComputeTaskNode
|
||||
function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
|
||||
if !ismissing(node.nodeReduction)
|
||||
invalidate_caches!(graph, node.nodeReduction)
|
||||
end
|
||||
if !ismissing(node.nodeSplit)
|
||||
invalidate_caches!(graph, node.nodeSplit)
|
||||
end
|
||||
while !isempty(node.nodeFusions)
|
||||
invalidate_caches!(graph, pop!(node.nodeFusions))
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to invalidate the operation caches of a DataTaskNode
|
||||
function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
|
||||
if !ismissing(node.nodeReduction)
|
||||
invalidate_caches!(graph, node.nodeReduction)
|
||||
end
|
||||
if !ismissing(node.nodeSplit)
|
||||
invalidate_caches!(graph, node.nodeSplit)
|
||||
end
|
||||
if !ismissing(node.nodeFusion)
|
||||
invalidate_caches!(graph, node.nodeFusion)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
@ -110,93 +149,72 @@ function insert_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
|
||||
return node
|
||||
end
|
||||
|
||||
function insert_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true)
|
||||
node1 = edge.edge[1]
|
||||
node2 = edge.edge[2]
|
||||
function insert_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
|
||||
# @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
|
||||
|
||||
# 1: mute
|
||||
#=if (node2 in node1.parents) || (node1 in node2.children)
|
||||
if !(node2 in node1.parents && node1 in node2.children)
|
||||
error("One-sided edge")
|
||||
end
|
||||
error("Edge to insert already exists")
|
||||
end=#
|
||||
|
||||
# edge points from child to parent
|
||||
push!(node1.parents, node2)
|
||||
push!(node2.children, node1)
|
||||
|
||||
# 2: keep track
|
||||
if (track) push!(graph.diff.addedEdges, edge) end
|
||||
if (track) push!(graph.diff.addedEdges, make_edge(node1, node2)) end
|
||||
|
||||
# 3: invalidate caches
|
||||
if (!invalidate_cache) return edge end
|
||||
if (!invalidate_cache) return nothing end
|
||||
|
||||
invalidate_operation_caches!(graph, node1)
|
||||
invalidate_operation_caches!(graph, node2)
|
||||
|
||||
while !isempty(node1.operations)
|
||||
invalidate_caches!(graph, first(node1.operations))
|
||||
end
|
||||
while !isempty(node2.operations)
|
||||
invalidate_caches!(graph, first(node2.operations))
|
||||
end
|
||||
push!(graph.dirtyNodes, node1)
|
||||
push!(graph.dirtyNodes, node2)
|
||||
|
||||
return edge
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
|
||||
# @assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
|
||||
# 1: mute
|
||||
#=if !(node in graph.nodes)
|
||||
error("Trying to remove a node that's not in the graph")
|
||||
end=#
|
||||
delete!(graph.nodes, node)
|
||||
|
||||
# 2: keep track
|
||||
if (track) push!(graph.diff.removedNodes, node) end
|
||||
|
||||
# 3: invalidate caches
|
||||
if (!invalidate_cache) return node end
|
||||
if (!invalidate_cache) return nothing end
|
||||
|
||||
while !isempty(node.operations)
|
||||
invalidate_caches!(graph, first(node.operations))
|
||||
end
|
||||
invalidate_operation_caches!(graph, node)
|
||||
delete!(graph.dirtyNodes, node)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true)
|
||||
node1 = edge.edge[1]
|
||||
node2 = edge.edge[2]
|
||||
|
||||
function remove_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
|
||||
# 1: mute
|
||||
pre_length1 = length(node1.parents)
|
||||
pre_length2 = length(node2.children)
|
||||
filter!(x -> x != node2, node1.parents)
|
||||
filter!(x -> x != node1, node2.children)
|
||||
|
||||
#=removed = pre_length1 - length(node1.parents)
|
||||
if (removed > 1)
|
||||
error("removed $removed from node1's parents")
|
||||
end
|
||||
#=@assert begin
|
||||
removed = pre_length1 - length(node1.parents)
|
||||
removed <= 1
|
||||
end "removed more than one node from node1's parents"=#
|
||||
|
||||
removed = pre_length2 - length(node2.children)
|
||||
if (removed > 1)
|
||||
error("removed $removed from node2's children")
|
||||
end=#
|
||||
#=@assert begin
|
||||
removed = pre_length2 - length(node2.children)
|
||||
removed <= 1
|
||||
end "removed more than one node from node2's children"=#
|
||||
|
||||
# 2: keep track
|
||||
if (track) push!(graph.diff.removedEdges, edge) end
|
||||
if (track) push!(graph.diff.removedEdges, make_edge(node1, node2)) end
|
||||
|
||||
# 3: invalidate caches
|
||||
if (!invalidate_cache) return nothing end
|
||||
|
||||
while !isempty(node1.operations)
|
||||
invalidate_caches!(graph, first(node1.operations))
|
||||
end
|
||||
while !isempty(node2.operations)
|
||||
invalidate_caches!(graph, first(node2.operations))
|
||||
end
|
||||
invalidate_operation_caches!(graph, node1)
|
||||
invalidate_operation_caches!(graph, node2)
|
||||
if (node1 in graph)
|
||||
push!(graph.dirtyNodes, node1)
|
||||
end
|
||||
@ -241,7 +259,7 @@ function get_exit_node(graph::DAG)
|
||||
return node
|
||||
end
|
||||
end
|
||||
error("The given graph has no exit node! It is either empty or not acyclic!")
|
||||
@assert false "The given graph has no exit node! It is either empty or not acyclic!"
|
||||
end
|
||||
|
||||
# check whether the given graph is connected
|
||||
|
@ -46,5 +46,6 @@ function ==(n1::DataTaskNode, n2::DataTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations))
|
||||
copy(m::Missing) = missing
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusions))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusion))
|
||||
|
27
src/nodes.jl
27
src/nodes.jl
@ -2,6 +2,7 @@ using Random
|
||||
using UUIDs
|
||||
using Base.Threads
|
||||
|
||||
# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/)
|
||||
rng = [Random.MersenneTwister(0) for _ in 1:32]
|
||||
|
||||
abstract type Node end
|
||||
@ -10,7 +11,7 @@ abstract type Node end
|
||||
# the specific operations are declared in graph.jl
|
||||
abstract type Operation end
|
||||
|
||||
struct DataTaskNode <: Node
|
||||
mutable struct DataTaskNode <: Node
|
||||
task::AbstractDataTask
|
||||
|
||||
# use vectors as sets have way too much memory overhead
|
||||
@ -21,21 +22,33 @@ struct DataTaskNode <: Node
|
||||
# however, it can be copied when splitting a node
|
||||
id::Base.UUID
|
||||
|
||||
# a vector holding references to the graph operations involving this node
|
||||
operations::Vector{Operation}
|
||||
# the NodeReduction involving this node, if it exists
|
||||
# Can't use the NodeReduction type here because it's not yet defined
|
||||
nodeReduction::Union{Operation, Missing}
|
||||
|
||||
# the NodeSplit involving this node, if it exists
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# the node fusion involving this node, if it exists
|
||||
nodeFusion::Union{Operation, Missing}
|
||||
end
|
||||
|
||||
# same as DataTaskNode
|
||||
struct ComputeTaskNode <: Node
|
||||
mutable struct ComputeTaskNode <: Node
|
||||
task::AbstractComputeTask
|
||||
parents::Vector{Node}
|
||||
children::Vector{Node}
|
||||
id::Base.UUID
|
||||
operations::Vector{Operation}
|
||||
|
||||
nodeReduction::Union{Operation, Missing}
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
|
||||
nodeFusions::Vector{Operation}
|
||||
end
|
||||
|
||||
DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}())
|
||||
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}())
|
||||
DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing)
|
||||
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, Vector{NodeFusion}())
|
||||
|
||||
struct Edge
|
||||
# edge points from child to parent
|
||||
|
@ -2,237 +2,197 @@
|
||||
|
||||
# applies all unapplied operations in the DAG
|
||||
function apply_all!(graph::DAG)
|
||||
while !isempty(graph.operationsToApply)
|
||||
# get next operation to apply from front of the deque
|
||||
op = popfirst!(graph.operationsToApply)
|
||||
while !isempty(graph.operationsToApply)
|
||||
# get next operation to apply from front of the deque
|
||||
op = popfirst!(graph.operationsToApply)
|
||||
|
||||
# apply it
|
||||
appliedOp = apply_operation!(graph, op)
|
||||
# apply it
|
||||
appliedOp = apply_operation!(graph, op)
|
||||
|
||||
# push to the end of the appliedOperations deque
|
||||
push!(graph.appliedOperations, appliedOp)
|
||||
end
|
||||
return nothing
|
||||
# push to the end of the appliedOperations deque
|
||||
push!(graph.appliedOperations, appliedOp)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::Operation)
|
||||
error("Unknown operation type!")
|
||||
error("Unknown operation type!")
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::NodeFusion)
|
||||
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
|
||||
return AppliedNodeFusion(operation, diff)
|
||||
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
|
||||
return AppliedNodeFusion(operation, diff)
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::NodeReduction)
|
||||
diff = node_reduction!(graph, operation.input)
|
||||
return AppliedNodeReduction(operation, diff)
|
||||
diff = node_reduction!(graph, operation.input)
|
||||
return AppliedNodeReduction(operation, diff)
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::NodeSplit)
|
||||
diff = node_split!(graph, operation.input)
|
||||
return AppliedNodeSplit(operation, diff)
|
||||
diff = node_split!(graph, operation.input)
|
||||
return AppliedNodeSplit(operation, diff)
|
||||
end
|
||||
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedOperation)
|
||||
error("Unknown operation type!")
|
||||
error("Unknown operation type!")
|
||||
end
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
end
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
end
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
end
|
||||
|
||||
|
||||
function revert_diff!(graph::DAG, diff)
|
||||
# add removed nodes, remove added nodes, same for edges
|
||||
# note the order
|
||||
for edge in diff.addedEdges
|
||||
remove_edge!(graph, edge, false)
|
||||
end
|
||||
for node in diff.addedNodes
|
||||
remove_node!(graph, node, false)
|
||||
end
|
||||
function revert_diff!(graph::DAG, diff::Diff)
|
||||
# add removed nodes, remove added nodes, same for edges
|
||||
# note the order
|
||||
for edge in diff.addedEdges
|
||||
remove_edge!(graph, edge.edge[1], edge.edge[2], false)
|
||||
end
|
||||
for node in diff.addedNodes
|
||||
remove_node!(graph, node, false)
|
||||
end
|
||||
|
||||
for node in diff.removedNodes
|
||||
insert_node!(graph, node, false)
|
||||
end
|
||||
for edge in diff.removedEdges
|
||||
insert_edge!(graph, edge, false)
|
||||
end
|
||||
for node in diff.removedNodes
|
||||
insert_node!(graph, node, false)
|
||||
end
|
||||
for edge in diff.removedEdges
|
||||
insert_edge!(graph, edge.edge[1], edge.edge[2], false)
|
||||
end
|
||||
end
|
||||
|
||||
# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph
|
||||
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
# @assert is_valid_node_fusion_input(graph, n1, n2, n3)
|
||||
|
||||
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
|
||||
error("[Node Fusion] The given nodes are not part of the given graph")
|
||||
end
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
|
||||
# the checks are redundant but maybe a good sanity check
|
||||
error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion")
|
||||
end
|
||||
|
||||
# save children and parents
|
||||
n1_children = children(n1)
|
||||
n3_parents = parents(n3)
|
||||
n3_children = children(n3)
|
||||
# save children and parents
|
||||
n1_children = children(n1)
|
||||
n3_parents = parents(n3)
|
||||
n3_children = children(n3)
|
||||
|
||||
if length(n2.parents) > 1
|
||||
error("[Node Fusion] The given data node has more than one parent")
|
||||
end
|
||||
if length(n2.children) > 1
|
||||
error("[Node Fusion] The given data node has more than one child")
|
||||
end
|
||||
if length(n1.parents) > 1
|
||||
error("[Node Fusion] The given n1 has more than one parent")
|
||||
end
|
||||
# remove the edges and nodes that will be replaced by the fused node
|
||||
remove_edge!(graph, n1, n2)
|
||||
remove_edge!(graph, n2, n3)
|
||||
remove_node!(graph, n1)
|
||||
remove_node!(graph, n2)
|
||||
|
||||
required_edge1 = make_edge(n1, n2)
|
||||
required_edge2 = make_edge(n2, n3)
|
||||
# get n3's children now so it automatically excludes n2
|
||||
n3_children = children(n3)
|
||||
remove_node!(graph, n3)
|
||||
|
||||
# remove the edges and nodes that will be replaced by the fused node
|
||||
remove_edge!(graph, required_edge1)
|
||||
remove_edge!(graph, required_edge2)
|
||||
remove_node!(graph, n1)
|
||||
remove_node!(graph, n2)
|
||||
# create new node with the fused compute task
|
||||
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
|
||||
insert_node!(graph, new_node)
|
||||
|
||||
# get n3's children now so it automatically excludes n2
|
||||
n3_children = children(n3)
|
||||
remove_node!(graph, n3)
|
||||
# use a set for combined children of n1 and n3 to not get duplicates
|
||||
n1and3_children = Set{Node}()
|
||||
|
||||
# create new node with the fused compute task
|
||||
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}())
|
||||
insert_node!(graph, new_node)
|
||||
# remove edges from n1 children to n1
|
||||
for child in n1_children
|
||||
remove_edge!(graph, child, n1)
|
||||
push!(n1and3_children, child)
|
||||
end
|
||||
|
||||
# use a set for combined children of n1 and n3 to not get duplicates
|
||||
n1and3_children = Set{Node}()
|
||||
# remove edges from n3 children to n3
|
||||
for child in n3_children
|
||||
remove_edge!(graph, child, n3)
|
||||
push!(n1and3_children, child)
|
||||
end
|
||||
|
||||
# remove edges from n1 children to n1
|
||||
for child in n1_children
|
||||
remove_edge!(graph, make_edge(child, n1))
|
||||
push!(n1and3_children, child)
|
||||
end
|
||||
for child in n1and3_children
|
||||
insert_edge!(graph, child, new_node)
|
||||
end
|
||||
|
||||
# remove edges from n3 children to n3
|
||||
for child in n3_children
|
||||
remove_edge!(graph, make_edge(child, n3))
|
||||
push!(n1and3_children, child)
|
||||
end
|
||||
# "repoint" parents of n3 from new node
|
||||
for parent in n3_parents
|
||||
remove_edge!(graph, n3, parent)
|
||||
insert_edge!(graph, new_node, parent)
|
||||
end
|
||||
|
||||
for child in n1and3_children
|
||||
insert_edge!(graph, make_edge(child, new_node))
|
||||
end
|
||||
|
||||
# "repoint" parents of n3 from new node
|
||||
for parent in n3_parents
|
||||
remove_edge!(graph, make_edge(n3, parent))
|
||||
insert_edge!(graph, make_edge(new_node, parent))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
return get_snapshot_diff(graph)
|
||||
end
|
||||
|
||||
function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
# @assert is_valid_node_reduction_input(graph, nodes)
|
||||
|
||||
t = typeof(nodes[1].task)
|
||||
for n in nodes
|
||||
if n ∉ graph
|
||||
error("[Node Reduction] The given nodes are not part of the given graph")
|
||||
end
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
if typeof(n.task) != t
|
||||
error("[Node Reduction] The given nodes are not of the same type")
|
||||
end
|
||||
end
|
||||
n1 = nodes[1]
|
||||
n1_children = children(n1)
|
||||
|
||||
n1_parents = Set(n1.parents)
|
||||
new_parents = Set{Node}()
|
||||
|
||||
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
|
||||
for i in 2:length(nodes)
|
||||
n = nodes[i]
|
||||
for child in n1_children
|
||||
remove_edge!(graph, child, n)
|
||||
end
|
||||
|
||||
n1 = nodes[1]
|
||||
n1_children = children(n1)
|
||||
for n in nodes
|
||||
if Set(n1_children) != Set(n.children)
|
||||
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
|
||||
end
|
||||
end
|
||||
for parent in parents(n)
|
||||
remove_edge!(graph, n, parent)
|
||||
|
||||
n1_parents = Set(n1.parents)
|
||||
new_parents = Set{Node}()
|
||||
# collect all parents
|
||||
push!(new_parents, parent)
|
||||
end
|
||||
|
||||
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
|
||||
for i in 2:length(nodes)
|
||||
n = nodes[i]
|
||||
for child in n1_children
|
||||
remove_edge!(graph, make_edge(child, n))
|
||||
end
|
||||
remove_node!(graph, n)
|
||||
end
|
||||
|
||||
for parent in parents(n)
|
||||
remove_edge!(graph, make_edge(n, parent))
|
||||
setdiff!(new_parents, n1_parents)
|
||||
|
||||
# collect all parents
|
||||
push!(new_parents, parent)
|
||||
end
|
||||
for parent in new_parents
|
||||
# now add parents of all input nodes to n1 without duplicates
|
||||
insert_edge!(graph, n1, parent)
|
||||
end
|
||||
|
||||
remove_node!(graph, n)
|
||||
end
|
||||
|
||||
setdiff!(new_parents, n1_parents)
|
||||
|
||||
for parent in new_parents
|
||||
# now add parents of n2 to n1 without duplicates
|
||||
insert_edge!(graph, make_edge(n1, parent))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
return get_snapshot_diff(graph)
|
||||
end
|
||||
|
||||
function node_split!(graph::DAG, n1::Node)
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
# @assert is_valid_node_split_input(graph, n1)
|
||||
|
||||
#=if !(n1 in graph)
|
||||
error("[Node Split] The given node is not part of the given graph")
|
||||
end=#
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
n1_parents = parents(n1)
|
||||
n1_children = children(n1)
|
||||
n1_parents = parents(n1)
|
||||
n1_children = children(n1)
|
||||
|
||||
#=if length(n1_parents) <= 1
|
||||
error("[Node Split] The given node does not have multiple parents which is required for node split")
|
||||
end=#
|
||||
for parent in n1_parents
|
||||
remove_edge!(graph, n1, parent)
|
||||
end
|
||||
for child in n1_children
|
||||
remove_edge!(graph, child, n1)
|
||||
end
|
||||
remove_node!(graph, n1)
|
||||
|
||||
for parent in n1_parents
|
||||
remove_edge!(graph, make_edge(n1, parent))
|
||||
end
|
||||
for child in n1_children
|
||||
remove_edge!(graph, make_edge(child, n1))
|
||||
end
|
||||
remove_node!(graph, n1)
|
||||
for parent in n1_parents
|
||||
n_copy = copy(n1)
|
||||
insert_node!(graph, n_copy)
|
||||
insert_edge!(graph, n_copy, parent)
|
||||
|
||||
for parent in n1_parents
|
||||
n_copy = copy(n1)
|
||||
insert_node!(graph, n_copy)
|
||||
insert_edge!(graph, make_edge(n_copy, parent))
|
||||
for child in n1_children
|
||||
insert_edge!(graph, child, n_copy)
|
||||
end
|
||||
end
|
||||
|
||||
for child in n1_children
|
||||
insert_edge!(graph, make_edge(child, n_copy))
|
||||
end
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
return get_snapshot_diff(graph)
|
||||
end
|
||||
|
@ -4,10 +4,8 @@
|
||||
# pushes the found fusion everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
# if there is already a fusion here, skip
|
||||
for op in node.operations
|
||||
if typeof(op) <: NodeFusion
|
||||
return nothing
|
||||
end
|
||||
if !ismissing(node.nodeFusion)
|
||||
return nothing
|
||||
end
|
||||
|
||||
if length(node.parents) != 1 || length(node.children) != 1
|
||||
@ -17,9 +15,9 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
child_node = first(node.children)
|
||||
parent_node = first(node.parents)
|
||||
|
||||
#=if !(child_node in graph) || !(parent_node in graph)
|
||||
if !(child_node in graph) || !(parent_node in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end=#
|
||||
end
|
||||
|
||||
if length(child_node.parents) != 1
|
||||
return nothing
|
||||
@ -27,9 +25,9 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
|
||||
nf = NodeFusion((child_node, node, parent_node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(child_node.operations, nf)
|
||||
push!(node.operations, nf)
|
||||
push!(parent_node.operations, nf)
|
||||
push!(child_node.nodeFusions, nf)
|
||||
node.nodeFusion = nf
|
||||
push!(parent_node.nodeFusions, nf)
|
||||
|
||||
return nothing
|
||||
end
|
||||
@ -37,7 +35,6 @@ end
|
||||
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
# just find fusions in neighbouring DataTaskNodes
|
||||
|
||||
for child in node.children
|
||||
find_fusions!(graph, child)
|
||||
end
|
||||
@ -51,10 +48,8 @@ end
|
||||
|
||||
function find_reductions!(graph::DAG, node::Node)
|
||||
# there can only be one reduction per node, avoid adding duplicates
|
||||
for op in node.operations
|
||||
if typeof(op) <: NodeReduction
|
||||
return nothing
|
||||
end
|
||||
if !ismissing(node.nodeReduction)
|
||||
return nothing
|
||||
end
|
||||
|
||||
reductionVector = nothing
|
||||
@ -62,7 +57,14 @@ function find_reductions!(graph::DAG, node::Node)
|
||||
partners_ = partners(node)
|
||||
delete!(partners_, node)
|
||||
for partner in partners_
|
||||
if partner ∉ graph.nodes
|
||||
error("Partner is not part of the graph")
|
||||
end
|
||||
|
||||
if can_reduce(node, partner)
|
||||
if Set(node.children) != Set(partner.children)
|
||||
error("Not equal children")
|
||||
end
|
||||
if reductionVector === nothing
|
||||
# only when there's at least one reduction partner, insert the vector
|
||||
reductionVector = Vector{Node}()
|
||||
@ -77,7 +79,12 @@ function find_reductions!(graph::DAG, node::Node)
|
||||
nr = NodeReduction(reductionVector)
|
||||
push!(graph.possibleOperations.nodeReductions, nr)
|
||||
for node in reductionVector
|
||||
push!(node.operations, nr)
|
||||
if !ismissing(node.nodeReduction)
|
||||
# it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now
|
||||
# this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations
|
||||
invalidate_caches!(graph, node.nodeReduction)
|
||||
end
|
||||
node.nodeReduction = nr
|
||||
end
|
||||
end
|
||||
|
||||
@ -85,10 +92,14 @@ function find_reductions!(graph::DAG, node::Node)
|
||||
end
|
||||
|
||||
function find_splits!(graph::DAG, node::Node)
|
||||
if !ismissing(node.nodeSplit)
|
||||
return nothing
|
||||
end
|
||||
|
||||
if (can_split(node))
|
||||
ns = NodeSplit(node)
|
||||
push!(graph.possibleOperations.nodeSplits, ns)
|
||||
push!(node.operations, ns)
|
||||
node.nodeSplit = ns
|
||||
end
|
||||
|
||||
return nothing
|
||||
|
@ -2,49 +2,28 @@
|
||||
|
||||
using Base.Threads
|
||||
|
||||
function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks::Dict{Node, SpinLock})
|
||||
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
|
||||
n1 = nf.input[1]; n2 = nf.input[2]; n3 = nf.input[3]
|
||||
|
||||
lock(locks[n1]) do; push!(nf.input[1].operations, nf); end
|
||||
lock(locks[n2]) do; push!(nf.input[2].operations, nf); end
|
||||
lock(locks[n3]) do; push!(nf.input[3].operations, nf); end
|
||||
lock(locks[n1]) do; push!(nf.input[1].nodeFusions, nf); end
|
||||
nf.input[2].nodeFusion = nf
|
||||
lock(locks[n3]) do; push!(nf.input[3].nodeFusions, nf); end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock})
|
||||
# since node parents were sorted before, the NodeReductions contain elements in a known order
|
||||
# this, together with the locking, means that we can safely do the following without inserting duplicates
|
||||
first = true
|
||||
function insert_operation!(nr::NodeReduction)
|
||||
for n in nr.input
|
||||
skip_duplicate = false
|
||||
# careful here, this is a manual lock (because of the break)
|
||||
lock(locks[n])
|
||||
if first
|
||||
first = false
|
||||
for op in n.operations
|
||||
if typeof(op) <: NodeReduction
|
||||
skip_duplicate = true
|
||||
break
|
||||
end
|
||||
end
|
||||
if skip_duplicate
|
||||
unlock(locks[n])
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
push!(n.operations, nr)
|
||||
unlock(locks[n])
|
||||
n.nodeReduction = nr
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock})
|
||||
lock(locks[ns.input]) do; push!(ns.input.operations, ns); end
|
||||
function insert_operation!(ns::NodeSplit)
|
||||
ns.input.nodeSplit = ns
|
||||
return nothing
|
||||
end
|
||||
|
||||
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock})
|
||||
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
|
||||
total_len = 0
|
||||
for vec in nodeReductions
|
||||
total_len += length(vec)
|
||||
@ -58,7 +37,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
|
||||
|
||||
@threads for vec in nodeReductions
|
||||
for op in vec
|
||||
insert_operation!(operations, op, locks)
|
||||
insert_operation!(op)
|
||||
end
|
||||
end
|
||||
|
||||
@ -67,7 +46,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
|
||||
return nothing
|
||||
end
|
||||
|
||||
function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock})
|
||||
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
|
||||
total_len = 0
|
||||
for vec in nodeFusions
|
||||
total_len += length(vec)
|
||||
@ -79,9 +58,16 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto
|
||||
end
|
||||
schedule(t)
|
||||
|
||||
locks = Dict{ComputeTaskNode, SpinLock}()
|
||||
for n in graph.nodes
|
||||
if (typeof(n) <: ComputeTaskNode)
|
||||
locks[n] = SpinLock()
|
||||
end
|
||||
end
|
||||
|
||||
@threads for vec in nodeFusions
|
||||
for op in vec
|
||||
insert_operation!(operations, op, locks)
|
||||
insert_operation!(op, locks)
|
||||
end
|
||||
end
|
||||
|
||||
@ -90,7 +76,7 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto
|
||||
return nothing
|
||||
end
|
||||
|
||||
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock})
|
||||
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
|
||||
total_len = 0
|
||||
for vec in nodeSplits
|
||||
total_len += length(vec)
|
||||
@ -104,7 +90,7 @@ function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector
|
||||
|
||||
@threads for vec in nodeSplits
|
||||
for op in vec
|
||||
insert_operation!(operations, op, locks)
|
||||
insert_operation!(op)
|
||||
end
|
||||
end
|
||||
|
||||
@ -115,11 +101,6 @@ end
|
||||
|
||||
# function to generate all possible operations on the graph
|
||||
function generate_options(graph::DAG)
|
||||
locks = Dict{Node, SpinLock}()
|
||||
for n in graph.nodes
|
||||
locks[n] = SpinLock()
|
||||
end
|
||||
|
||||
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
|
||||
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
|
||||
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
|
||||
@ -174,7 +155,7 @@ function generate_options(graph::DAG)
|
||||
|
||||
# launch thread for node reduction insertion
|
||||
# remove duplicates
|
||||
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks)
|
||||
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
|
||||
schedule(nr_task)
|
||||
|
||||
# --- find possible node fusions ---
|
||||
@ -200,7 +181,7 @@ function generate_options(graph::DAG)
|
||||
end
|
||||
|
||||
# launch thread for node fusion insertion
|
||||
nf_task = @task nf_insertion!(graph.possibleOperations, generatedFusions, locks)
|
||||
nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
|
||||
schedule(nf_task)
|
||||
|
||||
# find possible node splits
|
||||
@ -211,7 +192,7 @@ function generate_options(graph::DAG)
|
||||
end
|
||||
|
||||
# launch thread for node split insertion
|
||||
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits, locks)
|
||||
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
|
||||
schedule(ns_task)
|
||||
|
||||
empty!(graph.dirtyNodes)
|
||||
|
61
src/operations/validate.jl
Normal file
61
src/operations/validate.jl
Normal file
@ -0,0 +1,61 @@
|
||||
# functions to throw assertion errors for inconsistent or wrong node operations
|
||||
# should be called with @assert
|
||||
# the functions throw their own errors though, to still have helpful error messages
|
||||
|
||||
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
|
||||
throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
|
||||
end
|
||||
|
||||
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
|
||||
throw(AssertionError("[Node Fusion] The given nodes are not connected by edges which is required for node fusion"))
|
||||
end
|
||||
|
||||
if length(n2.parents) > 1
|
||||
throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
|
||||
end
|
||||
if length(n2.children) > 1
|
||||
throw(AssertionError("[Node Fusion] The given data node has more than one child"))
|
||||
end
|
||||
if length(n1.parents) > 1
|
||||
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
|
||||
for n in nodes
|
||||
if n ∉ graph
|
||||
throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
|
||||
end
|
||||
end
|
||||
|
||||
t = typeof(nodes[1].task)
|
||||
for n in nodes
|
||||
if typeof(n.task) != t
|
||||
throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
|
||||
end
|
||||
end
|
||||
|
||||
n1_children = nodes[1].children
|
||||
for n in nodes
|
||||
if Set(n1_children) != Set(n.children)
|
||||
throw(AssertionError("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction"))
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function is_valid_node_split_input(graph::DAG, n1::Node)
|
||||
if n1 ∉ graph
|
||||
throw(AssertionError("[Node Split] The given node is not part of the given graph"))
|
||||
end
|
||||
|
||||
if length(n1.parents) <= 1
|
||||
throw(AssertionError("[Node Split] The given node does not have multiple parents which is required for node split"))
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
@ -1,7 +1,6 @@
|
||||
import MetagraphOptimization.insert_node!
|
||||
import MetagraphOptimization.insert_edge!
|
||||
import MetagraphOptimization.make_node
|
||||
import MetagraphOptimization.make_edge
|
||||
|
||||
@testset "Unit Tests Node Reduction" begin
|
||||
graph = MetagraphOptimization.DAG()
|
||||
@ -30,27 +29,27 @@ import MetagraphOptimization.make_edge
|
||||
BD = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
CD = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
|
||||
insert_edge!(graph, make_edge(s0, d_exit), false)
|
||||
insert_edge!(graph, make_edge(ED, s0), false)
|
||||
insert_edge!(graph, make_edge(FD, s0), false)
|
||||
insert_edge!(graph, make_edge(EC, ED), false)
|
||||
insert_edge!(graph, make_edge(FC, FD), false)
|
||||
insert_edge!(graph, s0, d_exit, false)
|
||||
insert_edge!(graph, ED, s0, false)
|
||||
insert_edge!(graph, FD, s0, false)
|
||||
insert_edge!(graph, EC, ED, false)
|
||||
insert_edge!(graph, FC, FD, false)
|
||||
|
||||
insert_edge!(graph, make_edge(A1D, EC), false)
|
||||
insert_edge!(graph, make_edge(B1D_1, EC), false)
|
||||
insert_edge!(graph, A1D, EC, false)
|
||||
insert_edge!(graph, B1D_1, EC, false)
|
||||
|
||||
insert_edge!(graph, make_edge(B1D_2, FC), false)
|
||||
insert_edge!(graph, make_edge(C1D, FC), false)
|
||||
insert_edge!(graph, B1D_2, FC, false)
|
||||
insert_edge!(graph, C1D, FC, false)
|
||||
|
||||
insert_edge!(graph, make_edge(A1C, A1D), false)
|
||||
insert_edge!(graph, make_edge(B1C_1, B1D_1), false)
|
||||
insert_edge!(graph, make_edge(B1C_2, B1D_2), false)
|
||||
insert_edge!(graph, make_edge(C1C, C1D), false)
|
||||
insert_edge!(graph, A1C, A1D, false)
|
||||
insert_edge!(graph, B1C_1, B1D_1, false)
|
||||
insert_edge!(graph, B1C_2, B1D_2, false)
|
||||
insert_edge!(graph, C1C, C1D, false)
|
||||
|
||||
insert_edge!(graph, make_edge(AD, A1C), false)
|
||||
insert_edge!(graph, make_edge(BD, B1C_1), false)
|
||||
insert_edge!(graph, make_edge(BD, B1C_2), false)
|
||||
insert_edge!(graph, make_edge(CD, C1C), false)
|
||||
insert_edge!(graph, AD, A1C, false)
|
||||
insert_edge!(graph, BD, B1C_1, false)
|
||||
insert_edge!(graph, BD, B1C_2, false)
|
||||
insert_edge!(graph, CD, C1C, false)
|
||||
|
||||
@test is_exit_node(d_exit)
|
||||
@test is_entry_node(AD)
|
||||
|
@ -1,7 +1,6 @@
|
||||
import MetagraphOptimization.insert_node!
|
||||
import MetagraphOptimization.insert_edge!
|
||||
import MetagraphOptimization.make_node
|
||||
import MetagraphOptimization.make_edge
|
||||
import MetagraphOptimization.siblings
|
||||
import MetagraphOptimization.partners
|
||||
|
||||
@ -69,38 +68,38 @@ import MetagraphOptimization.partners
|
||||
@test length(graph.dirtyNodes) == 26
|
||||
|
||||
# now for all the edgese
|
||||
insert_edge!(graph, make_edge(d_PB, PB), false)
|
||||
insert_edge!(graph, make_edge(d_PA, PA), false)
|
||||
insert_edge!(graph, make_edge(d_PBp, PBp), false)
|
||||
insert_edge!(graph, make_edge(d_PAp, PAp), false)
|
||||
insert_edge!(graph, d_PB, PB, false)
|
||||
insert_edge!(graph, d_PA, PA, false)
|
||||
insert_edge!(graph, d_PBp, PBp, false)
|
||||
insert_edge!(graph, d_PAp, PAp, false)
|
||||
|
||||
insert_edge!(graph, make_edge(PB, d_PB_uB), false)
|
||||
insert_edge!(graph, make_edge(PA, d_PA_uA), false)
|
||||
insert_edge!(graph, make_edge(PBp, d_PBp_uBp), false)
|
||||
insert_edge!(graph, make_edge(PAp, d_PAp_uAp), false)
|
||||
insert_edge!(graph, PB, d_PB_uB, false)
|
||||
insert_edge!(graph, PA, d_PA_uA, false)
|
||||
insert_edge!(graph, PBp, d_PBp_uBp, false)
|
||||
insert_edge!(graph, PAp, d_PAp_uAp, false)
|
||||
|
||||
insert_edge!(graph, make_edge(d_PB_uB, uB), false)
|
||||
insert_edge!(graph, make_edge(d_PA_uA, uA), false)
|
||||
insert_edge!(graph, make_edge(d_PBp_uBp, uBp), false)
|
||||
insert_edge!(graph, make_edge(d_PAp_uAp, uAp), false)
|
||||
insert_edge!(graph, d_PB_uB, uB, false)
|
||||
insert_edge!(graph, d_PA_uA, uA, false)
|
||||
insert_edge!(graph, d_PBp_uBp, uBp, false)
|
||||
insert_edge!(graph, d_PAp_uAp, uAp, false)
|
||||
|
||||
insert_edge!(graph, make_edge(uB, d_uB_v0), false)
|
||||
insert_edge!(graph, make_edge(uA, d_uA_v0), false)
|
||||
insert_edge!(graph, make_edge(uBp, d_uBp_v1), false)
|
||||
insert_edge!(graph, make_edge(uAp, d_uAp_v1), false)
|
||||
insert_edge!(graph, uB, d_uB_v0, false)
|
||||
insert_edge!(graph, uA, d_uA_v0, false)
|
||||
insert_edge!(graph, uBp, d_uBp_v1, false)
|
||||
insert_edge!(graph, uAp, d_uAp_v1, false)
|
||||
|
||||
insert_edge!(graph, make_edge(d_uB_v0, v0), false)
|
||||
insert_edge!(graph, make_edge(d_uA_v0, v0), false)
|
||||
insert_edge!(graph, make_edge(d_uBp_v1, v1), false)
|
||||
insert_edge!(graph, make_edge(d_uAp_v1, v1), false)
|
||||
insert_edge!(graph, d_uB_v0, v0, false)
|
||||
insert_edge!(graph, d_uA_v0, v0, false)
|
||||
insert_edge!(graph, d_uBp_v1, v1, false)
|
||||
insert_edge!(graph, d_uAp_v1, v1, false)
|
||||
|
||||
insert_edge!(graph, make_edge(v0, d_v0_s0), false)
|
||||
insert_edge!(graph, make_edge(v1, d_v1_s0), false)
|
||||
insert_edge!(graph, v0, d_v0_s0, false)
|
||||
insert_edge!(graph, v1, d_v1_s0, false)
|
||||
|
||||
insert_edge!(graph, make_edge(d_v0_s0, s0), false)
|
||||
insert_edge!(graph, make_edge(d_v1_s0, s0), false)
|
||||
insert_edge!(graph, d_v0_s0, s0, false)
|
||||
insert_edge!(graph, d_v1_s0, s0, false)
|
||||
|
||||
insert_edge!(graph, make_edge(s0, d_exit), false)
|
||||
insert_edge!(graph, s0, d_exit, false)
|
||||
|
||||
@test length(graph.nodes) == 26
|
||||
@test length(graph.appliedOperations) == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user