Merge pull request 'Performance Improvements and Multi-Threading' (#2) from performance into main
Reviewed-on: Rubydragon/MetagraphOptimization.jl#2
This commit is contained in:
commit
569949d5c7
@ -27,7 +27,7 @@ jobs:
|
||||
run: julia --project -e 'import Pkg; Pkg.instantiate()'
|
||||
|
||||
- name: Run tests
|
||||
run: julia --project -e 'import Pkg; Pkg.test()'
|
||||
run: julia --project -t 4 -e 'import Pkg; Pkg.test()'
|
||||
|
||||
- name: Run examples
|
||||
run: julia --project=examples/ -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); include("examples/import_bench.jl")'
|
||||
run: julia --project=examples/ -t 4 -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); include("examples/import_bench.jl")'
|
||||
|
@ -4,6 +4,8 @@ Directed Acyclic Graph optimization for QED
|
||||
|
||||
## Usage
|
||||
|
||||
For all the julia calls, use `-t n` to give julia `n` threads.
|
||||
|
||||
Instantiate the project first:
|
||||
|
||||
`julia --project -e 'import Pkg; Pkg.instantiate()'`
|
||||
|
@ -3,3 +3,4 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
|
||||
MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4"
|
||||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
||||
ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"
|
||||
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
|
||||
|
@ -16,12 +16,15 @@ function bench_txt(filepath::String, bench::Bool = true)
|
||||
println(name, ":")
|
||||
g = parse_abc(filepath)
|
||||
print(g)
|
||||
println(" Graph size in memory: ", bytes_to_human_readable(Base.summarysize(g)))
|
||||
#println(" Graph size in memory: ", bytes_to_human_readable(Base.summarysize(g)))
|
||||
|
||||
if (bench)
|
||||
@btime parse_abc($filepath)
|
||||
println()
|
||||
end
|
||||
|
||||
println(" Get Operations: ")
|
||||
@time get_operations(g)
|
||||
println()
|
||||
end
|
||||
|
||||
function import_bench()
|
||||
@ -29,7 +32,7 @@ function import_bench()
|
||||
bench_txt("AB->ABBB.txt")
|
||||
bench_txt("AB->ABBBBB.txt")
|
||||
bench_txt("AB->ABBBBBBB.txt")
|
||||
#bench_txt("AB->ABBBBBBBBB.txt", false)
|
||||
#bench_txt("AB->ABBBBBBBBB.txt")
|
||||
bench_txt("ABAB->ABAB.txt")
|
||||
bench_txt("ABAB->ABC.txt")
|
||||
end
|
||||
|
164
results/FWKHIP8999
Normal file
164
results/FWKHIP8999
Normal file
@ -0,0 +1,164 @@
|
||||
Commit Hash: a7fb15c95b63eee40eb7b9324d83b748053c5e13
|
||||
|
||||
Run with 32 Threads
|
||||
|
||||
AB->AB:
|
||||
Graph:
|
||||
Nodes: Total: 34, ComputeTaskS2: 2, ComputeTaskU: 4,
|
||||
ComputeTaskSum: 1, ComputeTaskV: 4, ComputeTaskP: 4,
|
||||
DataTask: 19
|
||||
Edges: 37
|
||||
Total Compute Effort: 185
|
||||
Total Data Transfer: 104
|
||||
Total Compute Intensity: 1.7788461538461537
|
||||
28.171 μs (515 allocations: 52.06 KiB)
|
||||
Get Operations:
|
||||
Sorting...
|
||||
0.218136 seconds (155.59 k allocations: 10.433 MiB, 3.34% gc time, 3175.93% compilation time)
|
||||
Node Reductions...
|
||||
0.299127 seconds (257.04 k allocations: 16.853 MiB, 2827.94% compilation time)
|
||||
Node Fusions...
|
||||
0.046983 seconds (16.70 k allocations: 1.120 MiB, 3048.15% compilation time)
|
||||
Node Splits...
|
||||
0.033681 seconds (14.09 k allocations: 958.144 KiB, 3166.45% compilation time)
|
||||
Waiting...
|
||||
0.000001 seconds
|
||||
1.096006 seconds (581.46 k allocations: 38.180 MiB, 0.66% gc time, 1677.26% compilation time)
|
||||
rvim umount
|
||||
AB->ABBB:
|
||||
Graph:
|
||||
Nodes: Total: 280, ComputeTaskS2: 24, ComputeTaskU: 6,
|
||||
ComputeTaskV: 64, ComputeTaskSum: 1, ComputeTaskP: 6,
|
||||
ComputeTaskS1: 36, DataTask: 143
|
||||
Edges: 385
|
||||
Total Compute Effort: 2007
|
||||
Total Data Transfer: 1176
|
||||
Total Compute Intensity: 1.7066326530612246
|
||||
207.236 μs (4324 allocations: 296.87 KiB)
|
||||
Get Operations:
|
||||
Sorting...
|
||||
0.000120 seconds (167 allocations: 16.750 KiB)
|
||||
Node Reductions...
|
||||
0.000550 seconds (1.98 k allocations: 351.234 KiB)
|
||||
Node Fusions...
|
||||
0.000168 seconds (417 allocations: 83.797 KiB)
|
||||
Node Splits...
|
||||
0.000150 seconds (478 allocations: 36.406 KiB)
|
||||
Waiting...
|
||||
0.000000 seconds
|
||||
0.039897 seconds (16.19 k allocations: 1.440 MiB, 95.31% compilation time)
|
||||
|
||||
AB->ABBBBB:
|
||||
Graph:
|
||||
Nodes: Total: 7854, ComputeTaskS2: 720, ComputeTaskU: 8,
|
||||
ComputeTaskV: 1956, ComputeTaskSum: 1, ComputeTaskP: 8,
|
||||
ComputeTaskS1: 1230, DataTask: 3931
|
||||
Edges: 11241
|
||||
Total Compute Effort: 58789
|
||||
Total Data Transfer: 34826
|
||||
Total Compute Intensity: 1.6880778728536152
|
||||
5.787 ms (121839 allocations: 7.72 MiB)
|
||||
Get Operations:
|
||||
Sorting...
|
||||
0.000499 seconds (175 allocations: 17.000 KiB)
|
||||
Node Reductions...
|
||||
0.002126 seconds (45.76 k allocations: 4.477 MiB)
|
||||
Node Fusions...
|
||||
0.000949 seconds (7.09 k allocations: 1.730 MiB)
|
||||
Node Splits...
|
||||
0.000423 seconds (8.06 k allocations: 544.031 KiB)
|
||||
Waiting...
|
||||
0.000000 seconds
|
||||
0.015005 seconds (100.12 k allocations: 13.161 MiB)
|
||||
|
||||
AB->ABBBBBBB:
|
||||
Graph:
|
||||
Nodes: Total: 438436, ComputeTaskS2: 40320, ComputeTaskU: 10,
|
||||
ComputeTaskV: 109600, ComputeTaskSum: 1, ComputeTaskP: 10,
|
||||
ComputeTaskS1: 69272, DataTask: 219223
|
||||
Edges: 628665
|
||||
Total Compute Effort: 3288131
|
||||
Total Data Transfer: 1949004
|
||||
Total Compute Intensity: 1.687082735592128
|
||||
1.309 s (6826397 allocations: 430.63 MiB)
|
||||
Get Operations:
|
||||
Sorting...
|
||||
0.011898 seconds (197 allocations: 17.688 KiB)
|
||||
Node Reductions...
|
||||
0.110569 seconds (2.78 M allocations: 225.675 MiB)
|
||||
Node Fusions...
|
||||
0.022475 seconds (380.91 k allocations: 108.982 MiB)
|
||||
Node Splits...
|
||||
0.011369 seconds (438.80 k allocations: 28.743 MiB)
|
||||
Waiting...
|
||||
0.000001 seconds
|
||||
2.503065 seconds (5.77 M allocations: 683.968 MiB, 48.27% gc time)
|
||||
|
||||
AB->ABBBBBBBBB:
|
||||
Graph:
|
||||
Nodes: Total: 39456442, ComputeTaskS2: 3628800, ComputeTaskU: 12,
|
||||
ComputeTaskV: 9864100, ComputeTaskSum: 1, ComputeTaskP: 12,
|
||||
ComputeTaskS1: 6235290, DataTask: 19728227
|
||||
Edges: 56578129
|
||||
Total Compute Effort: 295923153
|
||||
Total Data Transfer: 175407750
|
||||
Total Compute Intensity: 1.6870585991782006
|
||||
389.495 s (626095682 allocations: 37.80 GiB)
|
||||
Get Operations:
|
||||
Sorting...
|
||||
1.181713 seconds (197 allocations: 17.688 KiB)
|
||||
Node Reductions...
|
||||
10.057358 seconds (251.09 M allocations: 19.927 GiB)
|
||||
Node Fusions...
|
||||
1.288635 seconds (34.24 M allocations: 6.095 GiB)
|
||||
Node Splits...
|
||||
0.719345 seconds (39.46 M allocations: 2.522 GiB)
|
||||
Waiting...
|
||||
0.000001 seconds
|
||||
904.138951 seconds (519.47 M allocations: 54.494 GiB, 25.03% gc time)
|
||||
|
||||
ABAB->ABAB:
|
||||
Graph:
|
||||
Nodes: Total: 3218, ComputeTaskS2: 288, ComputeTaskU: 8,
|
||||
ComputeTaskV: 796, ComputeTaskSum: 1, ComputeTaskP: 8,
|
||||
ComputeTaskS1: 504, DataTask: 1613
|
||||
Edges: 4581
|
||||
Total Compute Effort: 24009
|
||||
Total Data Transfer: 14144
|
||||
Total Compute Intensity: 1.697468891402715
|
||||
2.691 ms (49557 allocations: 3.17 MiB)
|
||||
Get Operations:
|
||||
Sorting...
|
||||
0.000246 seconds (171 allocations: 16.875 KiB)
|
||||
Node Reductions...
|
||||
0.001037 seconds (19.42 k allocations: 1.751 MiB)
|
||||
Node Fusions...
|
||||
0.001512 seconds (3.04 k allocations: 1.027 MiB)
|
||||
Node Splits...
|
||||
0.000197 seconds (3.41 k allocations: 231.078 KiB)
|
||||
Waiting...
|
||||
0.000000 seconds
|
||||
0.007492 seconds (42.20 k allocations: 5.399 MiB)
|
||||
|
||||
ABAB->ABC:
|
||||
Graph:
|
||||
Nodes: Total: 817, ComputeTaskS2: 72, ComputeTaskU: 7,
|
||||
ComputeTaskV: 198, ComputeTaskSum: 1, ComputeTaskP: 7,
|
||||
ComputeTaskS1: 120, DataTask: 412
|
||||
Edges: 1151
|
||||
Total Compute Effort: 6028
|
||||
Total Data Transfer: 3538
|
||||
Total Compute Intensity: 1.7037874505370265
|
||||
602.767 μs (12544 allocations: 843.16 KiB)
|
||||
Get Operations:
|
||||
Sorting...
|
||||
0.000127 seconds (171 allocations: 16.875 KiB)
|
||||
Node Reductions...
|
||||
0.000440 seconds (5.33 k allocations: 494.047 KiB)
|
||||
Node Fusions...
|
||||
0.001761 seconds (939 allocations: 280.797 KiB)
|
||||
Node Splits...
|
||||
0.000123 seconds (1.00 k allocations: 72.109 KiB)
|
||||
Waiting...
|
||||
0.000000 seconds
|
||||
0.003831 seconds (11.74 k allocations: 1.451 MiB)
|
25
scripts/bench_threads.fish
Executable file
25
scripts/bench_threads.fish
Executable file
@ -0,0 +1,25 @@
|
||||
#!/bin/fish
|
||||
set minthreads 1
|
||||
set maxthreads 8
|
||||
|
||||
julia --project=./examples -t 4 -e 'import Pkg; Pkg.instantiate()'
|
||||
|
||||
#for i in $(seq $minthreads $maxthreads)
|
||||
# printf "(AB->AB, $i) "
|
||||
# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->AB.txt"))'
|
||||
#end
|
||||
|
||||
#for i in $(seq $minthreads $maxthreads)
|
||||
# printf "(AB->ABBB, $i) "
|
||||
# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBB.txt"))'
|
||||
#end
|
||||
|
||||
#for i in $(seq $minthreads $maxthreads)
|
||||
# printf "(AB->ABBBBB, $i) "
|
||||
# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBBBB.txt"))'
|
||||
#end
|
||||
|
||||
for i in $(seq $minthreads $maxthreads)
|
||||
printf "(AB->ABBBBBBB, $i) "
|
||||
julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBBBBBB.txt"))'
|
||||
end
|
@ -17,17 +17,28 @@ import Base.in
|
||||
import Base.copy
|
||||
import Base.isempty
|
||||
import Base.delete!
|
||||
import Base.insert!
|
||||
import Base.collect
|
||||
|
||||
|
||||
include("tasks.jl")
|
||||
include("nodes.jl")
|
||||
include("graph.jl")
|
||||
|
||||
include("trie.jl")
|
||||
include("utility.jl")
|
||||
|
||||
include("task_functions.jl")
|
||||
include("node_functions.jl")
|
||||
include("graph_functions.jl")
|
||||
include("graph_operations.jl")
|
||||
include("utility.jl")
|
||||
|
||||
include("operations/utility.jl")
|
||||
include("operations/apply.jl")
|
||||
include("operations/clean.jl")
|
||||
include("operations/find.jl")
|
||||
include("operations/get.jl")
|
||||
|
||||
include("graph_interface.jl")
|
||||
|
||||
include("abc_model/tasks.jl")
|
||||
include("abc_model/task_functions.jl")
|
||||
|
@ -40,9 +40,9 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
if (verbose) println("Estimating ", estimate_no_nodes, " Nodes") end
|
||||
sizehint!(graph.nodes, estimate_no_nodes)
|
||||
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false)
|
||||
global_data_out = insert_node!(graph, make_node(DataTask(10)), false)
|
||||
insert_edge!(graph, make_edge(sum_node, global_data_out), false)
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false)
|
||||
global_data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
|
||||
insert_edge!(graph, make_edge(sum_node, global_data_out), false, false)
|
||||
|
||||
# remember the data out nodes for connection
|
||||
dataOutNodes = Dict()
|
||||
@ -58,16 +58,16 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
end
|
||||
if occursin(regex_a, node)
|
||||
# add nodes and edges for the state reading to u(P(Particle))
|
||||
data_in = insert_node!(graph, make_node(DataTask(4)), false) # read particle data node
|
||||
compute_P = insert_node!(graph, make_node(ComputeTaskP()), false) # compute P node
|
||||
data_Pu = insert_node!(graph, make_node(DataTask(6)), false) # transfer data from P to u
|
||||
compute_u = insert_node!(graph, make_node(ComputeTaskU()), false) # compute U node
|
||||
data_out = insert_node!(graph, make_node(DataTask(3)), false) # transfer data out from u
|
||||
data_in = insert_node!(graph, make_node(DataTask(4)), false, false) # read particle data node
|
||||
compute_P = insert_node!(graph, make_node(ComputeTaskP()), false, false) # compute P node
|
||||
data_Pu = insert_node!(graph, make_node(DataTask(6)), false, false) # transfer data from P to u
|
||||
compute_u = insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node
|
||||
data_out = insert_node!(graph, make_node(DataTask(3)), false, false) # transfer data out from u
|
||||
|
||||
insert_edge!(graph, make_edge(data_in, compute_P), false)
|
||||
insert_edge!(graph, make_edge(compute_P, data_Pu), false)
|
||||
insert_edge!(graph, make_edge(data_Pu, compute_u), false)
|
||||
insert_edge!(graph, make_edge(compute_u, data_out), false)
|
||||
insert_edge!(graph, make_edge(data_in, compute_P), false, false)
|
||||
insert_edge!(graph, make_edge(compute_P, data_Pu), false, false)
|
||||
insert_edge!(graph, make_edge(data_Pu, compute_u), false, false)
|
||||
insert_edge!(graph, make_edge(compute_u, data_out), false, false)
|
||||
|
||||
# remember the data_out node for future edges
|
||||
dataOutNodes[node] = data_out
|
||||
@ -77,37 +77,37 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
in1 = capt.captures[1]
|
||||
in2 = capt.captures[2]
|
||||
|
||||
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false)
|
||||
data_out = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
|
||||
data_out = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
if (occursin(regex_c, capt.captures[1]))
|
||||
# put an S node after this input
|
||||
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false)
|
||||
data_S_v = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
|
||||
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_S), false)
|
||||
insert_edge!(graph, make_edge(compute_S, data_S_v), false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_S), false, false)
|
||||
insert_edge!(graph, make_edge(compute_S, data_S_v), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(data_S_v, compute_v), false)
|
||||
insert_edge!(graph, make_edge(data_S_v, compute_v), false, false)
|
||||
else
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_v), false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_v), false, false)
|
||||
end
|
||||
|
||||
if (occursin(regex_c, capt.captures[2]))
|
||||
# i think the current generator only puts the combined particles in the first space, so this case might never be entered
|
||||
# put an S node after this input
|
||||
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false)
|
||||
data_S_v = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
|
||||
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_S), false)
|
||||
insert_edge!(graph, make_edge(compute_S, data_S_v), false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_S), false, false)
|
||||
insert_edge!(graph, make_edge(compute_S, data_S_v), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(data_S_v, compute_v), false)
|
||||
insert_edge!(graph, make_edge(data_S_v, compute_v), false, false)
|
||||
else
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_v), false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_v), false, false)
|
||||
end
|
||||
|
||||
insert_edge!(graph, make_edge(compute_v, data_out), false)
|
||||
insert_edge!(graph, make_edge(compute_v, data_out), false, false)
|
||||
dataOutNodes[node] = data_out
|
||||
|
||||
elseif occursin(regex_m, node)
|
||||
@ -118,22 +118,22 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
in3 = capt.captures[3]
|
||||
|
||||
# in2 + in3 with a v
|
||||
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false)
|
||||
data_v = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
|
||||
data_v = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(dataOutNodes[in2], compute_v), false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[in3], compute_v), false)
|
||||
insert_edge!(graph, make_edge(compute_v, data_v), false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[in2], compute_v), false, false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[in3], compute_v), false, false)
|
||||
insert_edge!(graph, make_edge(compute_v, data_v), false, false)
|
||||
|
||||
# combine with the v of the combined other input
|
||||
compute_S2 = insert_node!(graph, make_node(ComputeTaskS2()), false)
|
||||
data_out = insert_node!(graph, make_node(DataTask(10)), false)
|
||||
compute_S2 = insert_node!(graph, make_node(ComputeTaskS2()), false, false)
|
||||
data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(data_v, compute_S2), false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[in1], compute_S2), false)
|
||||
insert_edge!(graph, make_edge(compute_S2, data_out), false)
|
||||
insert_edge!(graph, make_edge(data_v, compute_S2), false, false)
|
||||
insert_edge!(graph, make_edge(dataOutNodes[in1], compute_S2), false, false)
|
||||
insert_edge!(graph, make_edge(compute_S2, data_out), false, false)
|
||||
|
||||
insert_edge!(graph, make_edge(data_out, sum_node), false)
|
||||
insert_edge!(graph, make_edge(data_out, sum_node), false, false)
|
||||
elseif occursin(regex_plus, node)
|
||||
if (verbose)
|
||||
println("\rReading Nodes Complete ")
|
||||
@ -144,6 +144,9 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
end
|
||||
end
|
||||
|
||||
#put all nodes into dirty nodes set
|
||||
graph.dirtyNodes = copy(graph.nodes)
|
||||
|
||||
# don't actually need to read the edges
|
||||
return graph
|
||||
end
|
||||
|
@ -1,27 +1,29 @@
|
||||
struct DataTask <: AbstractDataTask
|
||||
data::UInt64
|
||||
end
|
||||
|
||||
# S task with 1 child
|
||||
struct ComputeTaskS1 <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# S task with 2 children
|
||||
struct ComputeTaskS2 <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# P task with 0 children
|
||||
struct ComputeTaskP <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# v task with 2 children
|
||||
struct ComputeTaskV <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# u task with 1 child
|
||||
struct ComputeTaskU <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# task that sums all its inputs, n children
|
||||
struct ComputeTaskSum <: AbstractComputeTask
|
||||
end
|
||||
end
|
||||
|
||||
# S task with 1 child
|
||||
struct ComputeTaskS1 <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# S task with 2 children
|
||||
struct ComputeTaskS2 <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# P task with 0 children
|
||||
struct ComputeTaskP <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# v task with 2 children
|
||||
struct ComputeTaskV <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# u task with 1 child
|
||||
struct ComputeTaskU <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# task that sums all its inputs, n children
|
||||
struct ComputeTaskSum <: AbstractComputeTask
|
||||
end
|
||||
|
||||
ABC_TASKS = [DataTask, ComputeTaskS1, ComputeTaskS2, ComputeTaskP, ComputeTaskV, ComputeTaskU, ComputeTaskSum]
|
||||
|
@ -33,6 +33,10 @@ end
|
||||
|
||||
struct NodeReduction <: Operation
|
||||
input::Vector{Node}
|
||||
|
||||
# these inputs can (and do) get very large in large graphs, so we need a better way to compare equality between them
|
||||
# only node reductions with the same id will be considered equal (the id can be copied)
|
||||
id::UUID
|
||||
end
|
||||
|
||||
struct AppliedNodeReduction <: AppliedOperation
|
||||
@ -49,7 +53,6 @@ struct AppliedNodeSplit <: AppliedOperation
|
||||
diff::Diff
|
||||
end
|
||||
|
||||
|
||||
mutable struct PossibleOperations
|
||||
nodeFusions::Set{NodeFusion}
|
||||
nodeReductions::Set{NodeReduction}
|
||||
@ -64,7 +67,6 @@ function PossibleOperations()
|
||||
)
|
||||
end
|
||||
|
||||
|
||||
# The actual state of the DAG is the initial state given by the set of nodes
|
||||
# but with all the operations in appliedChain applied in order
|
||||
mutable struct DAG
|
||||
|
@ -3,31 +3,6 @@ using DataStructures
|
||||
in(node::Node, graph::DAG) = node in graph.nodes
|
||||
in(edge::Edge, graph::DAG) = edge in graph.edges
|
||||
|
||||
function isempty(operations::PossibleOperations)
|
||||
return isempty(operations.nodeFusions) &&
|
||||
isempty(operations.nodeReductions) &&
|
||||
isempty(operations.nodeSplits)
|
||||
end
|
||||
|
||||
function length(operations::PossibleOperations)
|
||||
return (nodeFusions = length(operations.nodeFusions),
|
||||
nodeReductions = length(operations.nodeReductions),
|
||||
nodeSplits = length(operations.nodeSplits))
|
||||
end
|
||||
|
||||
function delete!(operations::PossibleOperations, op::NodeFusion)
|
||||
delete!(operations.nodeFusions, op)
|
||||
return operations
|
||||
end
|
||||
function delete!(operations::PossibleOperations, op::NodeReduction)
|
||||
delete!(operations.nodeReductions, op)
|
||||
return operations
|
||||
end
|
||||
function delete!(operations::PossibleOperations, op::NodeSplit)
|
||||
delete!(operations.nodeSplits, op)
|
||||
return operations
|
||||
end
|
||||
|
||||
function is_parent(potential_parent, node)
|
||||
return potential_parent in node.parents
|
||||
end
|
||||
@ -57,34 +32,38 @@ function parents(node::Node)
|
||||
return copy(node.parents)
|
||||
end
|
||||
|
||||
# siblings = all children of any parents, no duplicates, does not include the node itself
|
||||
# siblings = all children of any parents, no duplicates, includes the node itself
|
||||
function siblings(node::Node)
|
||||
result = Set{Node}()
|
||||
push!(result, node)
|
||||
for parent in node.parents
|
||||
for sibling in parent.children
|
||||
if (sibling != node)
|
||||
push!(result, sibling)
|
||||
end
|
||||
end
|
||||
union!(result, parent.children)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
# partners = all parents of any children, no duplicates, does not include the node itself
|
||||
# partners = all parents of any children, no duplicates, includes the node itself
|
||||
function partners(node::Node)
|
||||
result = Set{Node}()
|
||||
push!(result, node)
|
||||
for child in node.children
|
||||
for partner in child.parents
|
||||
if (partner != node)
|
||||
push!(result, partner)
|
||||
end
|
||||
end
|
||||
union!(result, child.parents)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
# alternative version to partners(Node), avoiding allocation of a new set
|
||||
# works on the given set and returns nothing
|
||||
function partners(node::Node, set::Set{Node})
|
||||
push!(set, node)
|
||||
for child in node.children
|
||||
union!(set, child.parents)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
is_entry_node(node::Node) = length(node.children) == 0
|
||||
is_exit_node(node::Node) = length(node.parents) == 0
|
||||
|
||||
@ -117,7 +96,7 @@ end
|
||||
# 2: keep track of what was changed for the diff (if track == true)
|
||||
# 3: invalidate operation caches
|
||||
|
||||
function insert_node!(graph::DAG, node::Node, track=true)
|
||||
function insert_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
|
||||
# 1: mute
|
||||
push!(graph.nodes, node)
|
||||
|
||||
@ -125,12 +104,13 @@ function insert_node!(graph::DAG, node::Node, track=true)
|
||||
if (track) push!(graph.diff.addedNodes, node) end
|
||||
|
||||
# 3: invalidate caches
|
||||
if (!invalidate_cache) return node end
|
||||
push!(graph.dirtyNodes, node)
|
||||
|
||||
return node
|
||||
end
|
||||
|
||||
function insert_edge!(graph::DAG, edge::Edge, track=true)
|
||||
function insert_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true)
|
||||
node1 = edge.edge[1]
|
||||
node2 = edge.edge[2]
|
||||
|
||||
@ -150,6 +130,8 @@ function insert_edge!(graph::DAG, edge::Edge, track=true)
|
||||
if (track) push!(graph.diff.addedEdges, edge) end
|
||||
|
||||
# 3: invalidate caches
|
||||
if (!invalidate_cache) return edge end
|
||||
|
||||
while !isempty(node1.operations)
|
||||
invalidate_caches!(graph, first(node1.operations))
|
||||
end
|
||||
@ -162,7 +144,7 @@ function insert_edge!(graph::DAG, edge::Edge, track=true)
|
||||
return edge
|
||||
end
|
||||
|
||||
function remove_node!(graph::DAG, node::Node, track=true)
|
||||
function remove_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
|
||||
# 1: mute
|
||||
#=if !(node in graph.nodes)
|
||||
error("Trying to remove a node that's not in the graph")
|
||||
@ -173,6 +155,8 @@ function remove_node!(graph::DAG, node::Node, track=true)
|
||||
if (track) push!(graph.diff.removedNodes, node) end
|
||||
|
||||
# 3: invalidate caches
|
||||
if (!invalidate_cache) return node end
|
||||
|
||||
while !isempty(node.operations)
|
||||
invalidate_caches!(graph, first(node.operations))
|
||||
end
|
||||
@ -181,7 +165,7 @@ function remove_node!(graph::DAG, node::Node, track=true)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_edge!(graph::DAG, edge::Edge, track=true)
|
||||
function remove_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true)
|
||||
node1 = edge.edge[1]
|
||||
node2 = edge.edge[2]
|
||||
|
||||
@ -205,6 +189,8 @@ function remove_edge!(graph::DAG, edge::Edge, track=true)
|
||||
if (track) push!(graph.diff.removedEdges, edge) end
|
||||
|
||||
# 3: invalidate caches
|
||||
if (!invalidate_cache) return nothing end
|
||||
|
||||
while !isempty(node1.operations)
|
||||
invalidate_caches!(graph, first(node1.operations))
|
||||
end
|
||||
@ -258,30 +244,6 @@ function get_exit_node(graph::DAG)
|
||||
error("The given graph has no exit node! It is either empty or not acyclic!")
|
||||
end
|
||||
|
||||
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
if !is_child(n1, n2) || !is_child(n2, n3)
|
||||
# the checks are redundant but maybe a good sanity check
|
||||
return false
|
||||
end
|
||||
|
||||
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function can_reduce(n1::Node, n2::Node)
|
||||
if (n1.task != n2.task)
|
||||
return false
|
||||
end
|
||||
return Set(n1.children) == Set(n2.children)
|
||||
end
|
||||
|
||||
function can_split(n::Node)
|
||||
return length(parents(n)) > 1
|
||||
end
|
||||
|
||||
# check whether the given graph is connected
|
||||
function is_valid(graph::DAG)
|
||||
nodeQueue = Deque{Node}()
|
||||
|
34
src/graph_interface.jl
Normal file
34
src/graph_interface.jl
Normal file
@ -0,0 +1,34 @@
|
||||
# user interface on the DAG
|
||||
|
||||
# applies a new operation to the end of the graph
|
||||
function push_operation!(graph::DAG, operation::Operation)
|
||||
# 1.: Add the operation to the DAG
|
||||
push!(graph.operationsToApply, operation)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# reverts the latest applied operation, essentially like a ctrl+z for
|
||||
function pop_operation!(graph::DAG)
|
||||
# 1.: Remove the operation from the appliedChain of the DAG
|
||||
if !isempty(graph.operationsToApply)
|
||||
pop!(graph.operationsToApply)
|
||||
elseif !isempty(graph.appliedOperations)
|
||||
appliedOp = pop!(graph.appliedOperations)
|
||||
revert_operation!(graph, appliedOp)
|
||||
else
|
||||
error("No more operations to pop!")
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations)
|
||||
|
||||
# reset the graph to its initial state with no operations applied
|
||||
function reset_graph!(graph::DAG)
|
||||
while (can_pop(graph))
|
||||
pop_operation!(graph)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
@ -1,485 +0,0 @@
|
||||
# outside interface
|
||||
|
||||
# applies a new operation to the end of the graph
|
||||
function push_operation!(graph::DAG, operation::Operation)
|
||||
# 1.: Add the operation to the DAG
|
||||
push!(graph.operationsToApply, operation)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# reverts the latest applied operation, essentially like a ctrl+z for
|
||||
function pop_operation!(graph::DAG)
|
||||
# 1.: Remove the operation from the appliedChain of the DAG
|
||||
if !isempty(graph.operationsToApply)
|
||||
pop!(graph.operationsToApply)
|
||||
elseif !isempty(graph.appliedOperations)
|
||||
appliedOp = pop!(graph.appliedOperations)
|
||||
revert_operation!(graph, appliedOp)
|
||||
else
|
||||
error("No more operations to pop!")
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations)
|
||||
|
||||
# reset the graph to its initial state with no operations applied
|
||||
function reset_graph!(graph::DAG)
|
||||
while (can_pop(graph))
|
||||
pop_operation!(graph)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# implementation detail functions, don't export
|
||||
|
||||
# applies all unapplied operations in the DAG
|
||||
function apply_all!(graph::DAG)
|
||||
while !isempty(graph.operationsToApply)
|
||||
# get next operation to apply from front of the deque
|
||||
op = popfirst!(graph.operationsToApply)
|
||||
|
||||
# apply it
|
||||
appliedOp = apply_operation!(graph, op)
|
||||
|
||||
# push to the end of the appliedOperations deque
|
||||
push!(graph.appliedOperations, appliedOp)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
|
||||
function apply_operation!(graph::DAG, operation::Operation)
|
||||
error("Unknown operation type!")
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::NodeFusion)
|
||||
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
|
||||
return AppliedNodeFusion(operation, diff)
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::NodeReduction)
|
||||
diff = node_reduction!(graph, operation.input[1], operation.input[2])
|
||||
return AppliedNodeReduction(operation, diff)
|
||||
end
|
||||
|
||||
function apply_operation!(graph::DAG, operation::NodeSplit)
|
||||
diff = node_split!(graph, operation.input)
|
||||
return AppliedNodeSplit(operation, diff)
|
||||
end
|
||||
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedOperation)
|
||||
error("Unknown operation type!")
|
||||
end
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
end
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
end
|
||||
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
end
|
||||
|
||||
|
||||
# Undo a recorded graph diff: re-remove what was added, re-insert what was
# removed. The order matters — edges go before their nodes on removal, and
# nodes before their edges on re-insertion.
function revert_diff!(graph::DAG, diff)
    for addedEdge in diff.addedEdges
        remove_edge!(graph, addedEdge, false)
    end
    for addedNode in diff.addedNodes
        remove_node!(graph, addedNode, false)
    end

    for removedNode in diff.removedNodes
        insert_node!(graph, removedNode, false)
    end
    for removedEdge in diff.removedEdges
        insert_edge!(graph, removedEdge, false)
    end
    return nothing
end
|
||||
|
||||
"""
    node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Fuse the chain `n1 -> n2 -> n3` into a single compute node and return the
snapshot diff describing the applied changes.

Fix: the original computed `n3_children = children(n3)` twice; the first,
early computation was dead (unconditionally overwritten after the node
removals) and has been removed.
"""
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    # clear the snapshot so the returned diff only contains this operation
    get_snapshot_diff(graph)

    if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
        error("[Node Fusion] The given nodes are not part of the given graph")
    end

    if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
        # the checks are redundant but maybe a good sanity check
        error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion")
    end

    # save children and parents before mutating the graph
    n1_children = children(n1)
    n3_parents = parents(n3)

    if length(n2.parents) > 1
        error("[Node Fusion] The given data node has more than one parent")
    end
    if length(n2.children) > 1
        error("[Node Fusion] The given data node has more than one child")
    end
    if length(n1.parents) > 1
        error("[Node Fusion] The given n1 has more than one parent")
    end

    required_edge1 = make_edge(n1, n2)
    required_edge2 = make_edge(n2, n3)

    # remove the edges and nodes that will be replaced by the fused node
    remove_edge!(graph, required_edge1)
    remove_edge!(graph, required_edge2)
    remove_node!(graph, n1)
    remove_node!(graph, n2)

    # get n3's children only now so the list automatically excludes n2
    n3_children = children(n3)
    remove_node!(graph, n3)

    # create new node with the fused compute task
    new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
    insert_node!(graph, new_node)

    # use a set for combined children of n1 and n3 to not get duplicates
    n1and3_children = Set{Node}()

    # remove edges from n1's children to n1
    for child in n1_children
        remove_edge!(graph, make_edge(child, n1))
        push!(n1and3_children, child)
    end

    # remove edges from n3's children to n3
    for child in n3_children
        remove_edge!(graph, make_edge(child, n3))
        push!(n1and3_children, child)
    end

    # wire every former child of n1/n3 into the fused node
    for child in n1and3_children
        insert_edge!(graph, make_edge(child, new_node))
    end

    # "repoint" parents of n3 to the new node
    for parent in n3_parents
        remove_edge!(graph, make_edge(n3, parent))
        insert_edge!(graph, make_edge(new_node, parent))
    end

    return get_snapshot_diff(graph)
end
|
||||
|
||||
# Reduce the equivalent nodes n1 and n2 into one (keeping n1) and return
# the snapshot diff of the applied changes.
# NOTE(review): the validity checks (graph membership, matching types,
# equal child sets) are intentionally disabled here, presumably for
# performance — confirm callers guarantee these preconditions.
function node_reduction!(graph::DAG, n1::Node, n2::Node)
    # clear the snapshot so the returned diff only contains this operation
    get_snapshot_diff(graph)

    # remember n2's neighborhood before detaching it
    n2_children = children(n2)
    n2_parents = Set(n2.parents)

    # detach n2 from its children and parents
    for child in n2_children
        remove_edge!(graph, make_edge(child, n2))
    end
    for parent in n2_parents
        remove_edge!(graph, make_edge(n2, parent))
    end

    # drop parents n1 already has so no duplicate edges get created
    for parent in n1.parents
        delete!(n2_parents, parent)
    end

    # reattach the remaining parents of n2 to n1
    for parent in n2_parents
        insert_edge!(graph, make_edge(n1, parent))
    end

    remove_node!(graph, n2)

    return get_snapshot_diff(graph)
end
|
||||
|
||||
# Split node n1 into one copy per parent (each copy wired to all of n1's
# former children) and return the snapshot diff of the applied changes.
# NOTE(review): the graph-membership and multiple-parent checks are
# intentionally disabled here — confirm callers guarantee them.
function node_split!(graph::DAG, n1::Node)
    # clear the snapshot so the returned diff only contains this operation
    get_snapshot_diff(graph)

    n1_parents = parents(n1)
    n1_children = children(n1)

    # fully detach n1, then remove it
    for parent in n1_parents
        remove_edge!(graph, make_edge(n1, parent))
    end
    for child in n1_children
        remove_edge!(graph, make_edge(child, n1))
    end
    remove_node!(graph, n1)

    # insert one copy of n1 per former parent
    for parent in n1_parents
        n_copy = copy(n1)
        insert_node!(graph, n_copy)
        insert_edge!(graph, make_edge(n_copy, parent))

        for child in n1_children
            insert_edge!(graph, make_edge(child, n_copy))
        end
    end

    return get_snapshot_diff(graph)
end
|
||||
|
||||
# Look for a node fusion centered on the given data node. If one is
# possible, register it in the graph's possible operations and on every
# participating node. Returns nothing either way.
function find_fusions!(graph::DAG, node::DataTaskNode)
    # a fusable data node has exactly one child and one parent
    if length(node.parents) != 1 || length(node.children) != 1
        return nothing
    end

    child_node = first(node.children)
    parent_node = first(node.parents)

    # the compute node below must feed only this data node
    if length(child_node.parents) != 1
        return nothing
    end

    nf = NodeFusion((child_node, node, parent_node))
    push!(graph.possibleOperations.nodeFusions, nf)
    push!(child_node.operations, nf)
    push!(node.operations, nf)
    push!(parent_node.operations, nf)

    return nothing
end
|
||||
|
||||
# Look for node fusions where the given compute node is either the child
# or the parent of a fusable compute -> data -> compute chain; registers
# any fusions found. (The original used `for _ in 1:1` + `break` as a
# scoped block; rewritten as plain nested guards — same behavior.)
function find_fusions!(graph::DAG, node::ComputeTaskNode)
    # --- case 1: node is the child of a chain node -> node2 -> node3 ---
    if length(node.parents) == 1
        node2 = first(node.parents)
        if length(node2.parents) == 1 && length(node2.children) == 1
            node3 = first(node2.parents)

            nf = NodeFusion((node, node2, node3))
            push!(graph.possibleOperations.nodeFusions, nf)
            push!(node.operations, nf)
            push!(node2.operations, nf)
            push!(node3.operations, nf)
        end
    end

    # --- case 2: node is the parent of a chain node1 -> node2 -> node ---
    if length(node.children) >= 1
        node2 = first(node.children)
        if length(node2.parents) == 1 && length(node2.children) == 1
            node1 = first(node2.children)
            if length(node1.parents) <= 1
                nf = NodeFusion((node1, node2, node))
                push!(graph.possibleOperations.nodeFusions, nf)
                push!(node1.operations, nf)
                push!(node2.operations, nf)
                push!(node.operations, nf)
            end
        end
    end

    return nothing
end
|
||||
|
||||
# Look for a node reduction involving the given node and register it in
# the graph's possible operations and on every participating node.
#
# Fix: skip the node itself when iterating its partners. The other
# find_reductions! definition in this codebase explicitly removes the node
# from partners(node) before iterating, so this variant must do the same
# to avoid a spurious self-reduction; also renamed the inner loop variable
# that shadowed `node`.
function find_reductions!(graph::DAG, node::Node)
    reductionVector = nothing
    # possible reductions are with nodes that are partners, i.e. parents of children
    for partner in partners(node)
        if partner === node
            # a node cannot reduce with itself
            continue
        end
        if can_reduce(node, partner)
            if reductionVector === nothing
                # only when there's at least one reduction partner, allocate the vector
                reductionVector = Vector{Node}()
                push!(reductionVector, node)
            end

            push!(reductionVector, partner)
        end
    end

    if reductionVector !== nothing
        nr = NodeReduction(reductionVector)
        push!(graph.possibleOperations.nodeReductions, nr)
        for reduced in reductionVector
            push!(reduced.operations, nr)
        end
    end

    return nothing
end
|
||||
|
||||
# Register a node split for the given node if it qualifies.
function find_splits!(graph::DAG, node::Node)
    if can_split(node)
        ns = NodeSplit(node)
        push!(graph.possibleOperations.nodeSplits, ns)
        push!(node.operations, ns)
    end

    return nothing
end
|
||||
|
||||
# "clean" the operations on a dirty node
|
||||
function clean_node!(graph::DAG, node::Node)
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
find_splits!(graph, node)
|
||||
|
||||
delete!(graph.dirtyNodes, node)
|
||||
end
|
||||
|
||||
# Generate all possible optimizations (node fusions, reductions and splits)
# on the graph from scratch, store them in graph.possibleOperations, and
# mark every node clean.
function generate_options(graph::DAG)
    options = PossibleOperations()

    # make sure the graph is fully generated through
    apply_all!(graph)

    # find possible node fusions: every data node with exactly one parent,
    # exactly one child, and a child that feeds only this node
    for node in graph.nodes
        if (typeof(node) <: DataTaskNode)
            if length(node.parents) != 1
                # data node can only have a single parent
                continue
            end
            parent_node = first(node.parents)

            if length(node.children) != 1
                # this node is an entry node or has multiple children which should not be possible
                continue
            end
            child_node = first(node.children)
            if (length(child_node.parents) != 1)
                continue
            end

            nf = NodeFusion((child_node, node, parent_node))
            push!(options.nodeFusions, nf)
            push!(child_node.operations, nf)
            push!(node.operations, nf)
            push!(parent_node.operations, nf)
        end
    end

    # find possible node reductions; visitedNodes prevents emitting the
    # same reduction group more than once
    visitedNodes = Set{Node}()

    for node in graph.nodes
        if (node in visitedNodes)
            continue
        end

        push!(visitedNodes, node)

        reductionVector = nothing
        # possible reductions are with nodes that are partners, i.e. parents of children
        for partner in partners(node)
            if can_reduce(node, partner)
                if reductionVector === nothing
                    # only when there's at least one reduction partner, insert the vector
                    reductionVector = Vector{Node}()
                    push!(reductionVector, node)
                end

                push!(reductionVector, partner)
                push!(visitedNodes, partner)
            end
        end

        if reductionVector !== nothing
            nr = NodeReduction(reductionVector)
            push!(options.nodeReductions, nr)
            # register the reduction on every node it involves
            for node in reductionVector
                push!(node.operations, nr)
            end
        end
    end

    # find possible node splits
    for node in graph.nodes
        if (can_split(node))
            ns = NodeSplit(node)
            push!(options.nodeSplits, ns)
            push!(node.operations, ns)
        end
    end

    # publish the freshly generated options and mark all nodes clean
    graph.possibleOperations = options
    empty!(graph.dirtyNodes)
end
|
||||
|
||||
# Return the set of currently possible operations, regenerating from
# scratch or incrementally cleaning dirty nodes as needed beforehand.
function get_operations(graph::DAG)
    apply_all!(graph)

    # a fully empty operation set means we must generate from scratch
    if isempty(graph.possibleOperations)
        generate_options(graph)
    end

    # refresh operations for nodes touched since the last call; cleaning
    # removes each node from dirtyNodes, so this loop terminates
    while !isempty(graph.dirtyNodes)
        clean_node!(graph, first(graph.dirtyNodes))
    end

    return graph.possibleOperations
end
|
@ -46,5 +46,5 @@ function ==(n1::DataTaskNode, n2::DataTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
||||
# Copy a node: duplicates task, parent/child lists and operations, but
# assigns a fresh UUID drawn from this thread's RNG.
# Fix: the diff residue defined each method twice; the superseded versions
# called `UUIDs.uuid1(rng)` with the now-Vector `rng` and are removed.
# NOTE(review): indexing rng by threadid() assumes the task does not
# migrate between threads mid-call — confirm for newer Julia versions.
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations))
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations))
|
||||
|
@ -1,7 +1,8 @@
|
||||
using Random
using UUIDs
using Base.Threads

# One independent RNG per thread so node-UUID generation needs no locking.
# Fix: the vector is sized from the actual thread count instead of a
# hard-coded 32 (which would fail for julia -t >32), the superseded scalar
# `rng` assignment is removed, and the binding is made `const` so the
# global is concretely typed.
const rng = [Random.MersenneTwister(0) for _ in 1:nthreads()]
|
||||
|
||||
abstract type Node end
|
||||
|
||||
@ -33,8 +34,8 @@ struct ComputeTaskNode <: Node
|
||||
operations::Vector{Operation}
|
||||
end
|
||||
|
||||
# Convenience constructors: build fresh, unconnected task nodes with a new
# UUID from the per-thread RNG and no pending operations.
# Fix: the diff residue defined each constructor twice; the superseded
# versions called `UUIDs.uuid1(rng)` with the now-Vector `rng` and are removed.
DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}())
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}())
|
||||
|
||||
struct Edge
|
||||
# edge points from child to parent
|
||||
|
229
src/operations/apply.jl
Normal file
229
src/operations/apply.jl
Normal file
@ -0,0 +1,229 @@
|
||||
# functions that apply graph operations
|
||||
|
||||
# Drain the DAG's queue of pending operations, applying each in FIFO order.
function apply_all!(graph::DAG)
    while !isempty(graph.operationsToApply)
        # oldest queued operation first
        op = popfirst!(graph.operationsToApply)

        # apply and archive the result for potential later reverts
        push!(graph.appliedOperations, apply_operation!(graph, op))
    end
    return nothing
end
|
||||
|
||||
# Fallback for operation subtypes without their own method.
apply_operation!(graph::DAG, operation::Operation) = error("Unknown operation type!")

# Apply a node fusion; returns the operation paired with its graph diff.
function apply_operation!(graph::DAG, operation::NodeFusion)
    return AppliedNodeFusion(operation, node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3]))
end

# Apply a node reduction; returns the operation paired with its graph diff.
function apply_operation!(graph::DAG, operation::NodeReduction)
    return AppliedNodeReduction(operation, node_reduction!(graph, operation.input[1], operation.input[2]))
end

# Apply a node split; returns the operation paired with its graph diff.
function apply_operation!(graph::DAG, operation::NodeSplit)
    return AppliedNodeSplit(operation, node_split!(graph, operation.input))
end
|
||||
|
||||
|
||||
# Fallback for applied-operation subtypes without their own method.
revert_operation!(graph::DAG, operation::AppliedOperation) = error("Unknown operation type!")

# Revert an applied node fusion; returns the original operation.
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
    revert_diff!(graph, operation.diff)
    return operation.operation
end

# Revert an applied node reduction; returns the original operation.
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
    revert_diff!(graph, operation.diff)
    return operation.operation
end

# Revert an applied node split; returns the original operation.
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
    revert_diff!(graph, operation.diff)
    return operation.operation
end
|
||||
|
||||
|
||||
# Invert a recorded diff against the graph. Removal order is edges first,
# then nodes; re-insertion order is nodes first, then edges.
function revert_diff!(graph::DAG, diff)
    for e in diff.addedEdges
        remove_edge!(graph, e, false)
    end
    for n in diff.addedNodes
        remove_node!(graph, n, false)
    end

    for n in diff.removedNodes
        insert_node!(graph, n, false)
    end
    for e in diff.removedEdges
        insert_edge!(graph, e, false)
    end
    return nothing
end
|
||||
|
||||
"""
    node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Fuse the chain `n1 -> n2 -> n3` into a single compute node and return the
snapshot diff describing the applied changes.

Fix: removed the dead early computation of `n3_children` — it was
unconditionally overwritten after the node removals below.
"""
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    # clear the snapshot so the returned diff only contains this operation
    get_snapshot_diff(graph)

    if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
        error("[Node Fusion] The given nodes are not part of the given graph")
    end

    if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
        # the checks are redundant but maybe a good sanity check
        error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion")
    end

    # save children and parents before mutating the graph
    n1_children = children(n1)
    n3_parents = parents(n3)

    if length(n2.parents) > 1
        error("[Node Fusion] The given data node has more than one parent")
    end
    if length(n2.children) > 1
        error("[Node Fusion] The given data node has more than one child")
    end
    if length(n1.parents) > 1
        error("[Node Fusion] The given n1 has more than one parent")
    end

    required_edge1 = make_edge(n1, n2)
    required_edge2 = make_edge(n2, n3)

    # remove the edges and nodes that will be replaced by the fused node
    remove_edge!(graph, required_edge1)
    remove_edge!(graph, required_edge2)
    remove_node!(graph, n1)
    remove_node!(graph, n2)

    # get n3's children only now so the list automatically excludes n2
    n3_children = children(n3)
    remove_node!(graph, n3)

    # create new node with the fused compute task
    new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
    insert_node!(graph, new_node)

    # use a set for combined children of n1 and n3 to not get duplicates
    n1and3_children = Set{Node}()

    # remove edges from n1's children to n1
    for child in n1_children
        remove_edge!(graph, make_edge(child, n1))
        push!(n1and3_children, child)
    end

    # remove edges from n3's children to n3
    for child in n3_children
        remove_edge!(graph, make_edge(child, n3))
        push!(n1and3_children, child)
    end

    # wire every former child of n1/n3 into the fused node
    for child in n1and3_children
        insert_edge!(graph, make_edge(child, new_node))
    end

    # "repoint" parents of n3 to the new node
    for parent in n3_parents
        remove_edge!(graph, make_edge(n3, parent))
        insert_edge!(graph, make_edge(new_node, parent))
    end

    return get_snapshot_diff(graph)
end
|
||||
|
||||
# Reduce the equivalent nodes n1 and n2 into one (n1 survives) and return
# the snapshot diff of the applied changes.
# NOTE(review): validity checks (graph membership, matching types, equal
# child sets) are intentionally disabled — confirm callers guarantee them.
function node_reduction!(graph::DAG, n1::Node, n2::Node)
    # clear the snapshot so the returned diff only contains this operation
    get_snapshot_diff(graph)

    # capture n2's neighborhood before detaching it
    n2_children = children(n2)
    n2_parents = Set(n2.parents)

    # detach n2 from children, then from parents
    for child in n2_children
        remove_edge!(graph, make_edge(child, n2))
    end
    for parent in n2_parents
        remove_edge!(graph, make_edge(n2, parent))
    end

    # parents n1 already has must not produce duplicate edges
    for parent in n1.parents
        delete!(n2_parents, parent)
    end

    # hand n2's remaining parents over to n1
    for parent in n2_parents
        insert_edge!(graph, make_edge(n1, parent))
    end

    remove_node!(graph, n2)

    return get_snapshot_diff(graph)
end
|
||||
|
||||
# Split n1 into one copy per parent, each copy wired to all of n1's former
# children; return the snapshot diff of the applied changes.
# NOTE(review): the membership and multiple-parent checks are intentionally
# disabled — confirm callers guarantee them.
function node_split!(graph::DAG, n1::Node)
    # clear the snapshot so the returned diff only contains this operation
    get_snapshot_diff(graph)

    n1_parents = parents(n1)
    n1_children = children(n1)

    # detach n1 on both sides, then drop it
    for parent in n1_parents
        remove_edge!(graph, make_edge(n1, parent))
    end
    for child in n1_children
        remove_edge!(graph, make_edge(child, n1))
    end
    remove_node!(graph, n1)

    # one fresh copy per former parent
    for parent in n1_parents
        n_copy = copy(n1)
        insert_node!(graph, n_copy)
        insert_edge!(graph, make_edge(n_copy, parent))

        for child in n1_children
            insert_edge!(graph, make_edge(child, n_copy))
        end
    end

    return get_snapshot_diff(graph)
end
|
127
src/operations/clean.jl
Normal file
127
src/operations/clean.jl
Normal file
@ -0,0 +1,127 @@
|
||||
# functions for "cleaning" nodes, i.e. regenerating the possible operations for a node
|
||||
|
||||
# Look for a node fusion centered on the given data node; if one exists,
# register it in the possible operations and on all three participants.
function find_fusions!(graph::DAG, node::DataTaskNode)
    # fusable data nodes have exactly one parent and one child
    if length(node.parents) != 1 || length(node.children) != 1
        return nothing
    end

    child_node = first(node.children)
    parent_node = first(node.parents)

    # the producing compute node must feed only this data node
    if length(child_node.parents) != 1
        return nothing
    end

    nf = NodeFusion((child_node, node, parent_node))
    push!(graph.possibleOperations.nodeFusions, nf)
    push!(child_node.operations, nf)
    push!(node.operations, nf)
    push!(parent_node.operations, nf)

    return nothing
end
|
||||
|
||||
# Look for node fusions with the given compute node at either end of a
# compute -> data -> compute chain. (Rewritten from the original
# `for _ in 1:1` + `break` scoped-block hack into nested guards;
# behavior is unchanged.)
function find_fusions!(graph::DAG, node::ComputeTaskNode)
    # case 1: node is the chain's child (node -> node2 -> node3)
    if length(node.parents) == 1
        node2 = first(node.parents)
        if length(node2.parents) == 1 && length(node2.children) == 1
            node3 = first(node2.parents)

            nf = NodeFusion((node, node2, node3))
            push!(graph.possibleOperations.nodeFusions, nf)
            push!(node.operations, nf)
            push!(node2.operations, nf)
            push!(node3.operations, nf)
        end
    end

    # case 2: node is the chain's parent (node1 -> node2 -> node)
    if length(node.children) >= 1
        node2 = first(node.children)
        if length(node2.parents) == 1 && length(node2.children) == 1
            node1 = first(node2.children)
            if length(node1.parents) <= 1
                nf = NodeFusion((node1, node2, node))
                push!(graph.possibleOperations.nodeFusions, nf)
                push!(node1.operations, nf)
                push!(node2.operations, nf)
                push!(node.operations, nf)
            end
        end
    end

    return nothing
end
|
||||
|
||||
# Look for a node reduction involving the given node and register it in
# the possible operations and on every participating node.
function find_reductions!(graph::DAG, node::Node)
    reductionVector = nothing
    # candidates are partner nodes (parents of children), minus the node itself
    candidates = partners(node)
    delete!(candidates, node)
    for partner in candidates
        if can_reduce(node, partner)
            if reductionVector === nothing
                # allocate lazily, only once a partner is actually found
                reductionVector = Vector{Node}()
                push!(reductionVector, node)
            end

            push!(reductionVector, partner)
        end
    end

    if reductionVector !== nothing
        nr = NodeReduction(reductionVector)
        push!(graph.possibleOperations.nodeReductions, nr)
        for reduced in reductionVector
            push!(reduced.operations, nr)
        end
    end

    return nothing
end
|
||||
|
||||
# Register a node split for the node if splitting it is possible.
function find_splits!(graph::DAG, node::Node)
    if can_split(node)
        ns = NodeSplit(node)
        push!(graph.possibleOperations.nodeSplits, ns)
        push!(node.operations, ns)
    end

    return nothing
end
|
||||
|
||||
# "clean" the operations on a dirty node
|
||||
function clean_node!(graph::DAG, node::Node)
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
find_splits!(graph, node)
|
||||
end
|
224
src/operations/find.jl
Normal file
224
src/operations/find.jl
Normal file
@ -0,0 +1,224 @@
|
||||
# functions that find operations on the inital graph
|
||||
|
||||
using Base.Threads
|
||||
|
||||
# Register a NodeFusion on each of its three input nodes, holding that
# node's lock around the push.
function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks::Dict{Node, SpinLock})
    for n in nf.input
        lock(locks[n]) do
            push!(n.operations, nf)
        end
    end
    return nothing
end

# Register a NodeReduction on each of its input nodes under their locks.
# Node parents were sorted beforehand, so the inputs arrive in a known
# order; the first node is checked under its lock for an already-present
# NodeReduction, which safely skips duplicates.
function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock})
    isfirst = true
    for n in nr.input
        # manual lock/unlock because of the early break below
        lock(locks[n])
        if isfirst
            isfirst = false
            if any(op -> op isa NodeReduction, n.operations)
                # another thread already registered this reduction
                unlock(locks[n])
                break
            end
        end

        push!(n.operations, nr)
        unlock(locks[n])
    end
    return nothing
end

# Register a NodeSplit on its single input node under that node's lock.
function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock})
    lock(locks[ns.input]) do
        push!(ns.input.operations, ns)
    end
    return nothing
end
|
||||
|
||||
# Shared implementation of nr_insertion!/nf_insertion!/ns_insertion!
# (previously three copy-pasted bodies): pre-size the target collection,
# merge the per-thread vectors into it on a concurrent task, and in
# parallel register every operation on its input nodes.
function _op_insertion!(target, operations::PossibleOperations, generated, locks::Dict{Node, SpinLock})
    # pre-size the target to the total number of generated operations
    total_len = 0
    for vec in generated
        total_len += length(vec)
    end
    sizehint!(target, total_len)

    # merge into the possible-operations collection concurrently with the
    # per-node registration below; union! also removes duplicates
    t = @task for vec in generated
        union!(target, Set(vec))
    end
    schedule(t)

    @threads for vec in generated
        for op in vec
            insert_operation!(operations, op, locks)
        end
    end

    wait(t)

    return nothing
end

# Insert all generated node reductions (one vector per thread).
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock})
    return _op_insertion!(operations.nodeReductions, operations, nodeReductions, locks)
end

# Insert all generated node fusions (one vector per thread).
function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock})
    return _op_insertion!(operations.nodeFusions, operations, nodeFusions, locks)
end

# Insert all generated node splits (one vector per thread).
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock})
    return _op_insertion!(operations.nodeSplits, operations, nodeSplits, locks)
end
|
||||
|
||||
# Generate all possible operations (reductions, fusions, splits) on the
# graph in parallel and insert them into graph.possibleOperations via the
# three *_insertion! helper tasks. Clears dirtyNodes when done.
# NOTE(review): per-thread result buffers are indexed by threadid(), which
# assumes @threads iterations do not migrate between threads — confirm
# for newer Julia versions (task migration).
function generate_options(graph::DAG)
    # one SpinLock per node to guard its operations vector
    locks = Dict{Node, SpinLock}()
    for n in graph.nodes
        locks[n] = SpinLock()
    end

    # per-thread buffers to collect results without contention
    generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
    generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
    generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]

    # make sure the graph is fully generated through
    apply_all!(graph)

    nodeArray = collect(graph.nodes)

    # sort all nodes (gives parent sets a deterministic order, which the
    # duplicate handling below relies on)
    @threads for node in nodeArray
        sort_node!(node)
    end

    checkedNodes = Set{Node}()
    checkedNodesLock = SpinLock()
    # --- find possible node reductions ---
    @threads for node in nodeArray
        # we're looking for nodes with multiple parents, those parents can then potentially reduce with one another
        if (length(node.parents) <= 1)
            continue
        end

        candidates = node.parents

        # sort into equivalence classes
        trie = NodeTrie()

        for candidate in candidates
            # insert into trie
            insert!(trie, candidate)
        end

        nodeReductions = collect(trie)

        for nrVec in nodeReductions
            # parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is uniquely identifiable by its first element
            # this prevents duplicate nodeReductions being generated
            lock(checkedNodesLock)
            if (nrVec[1] in checkedNodes)
                unlock(checkedNodesLock)
                continue
            else
                push!(checkedNodes, nrVec[1])
            end
            unlock(checkedNodesLock)

            push!(generatedReductions[threadid()], NodeReduction(nrVec))
        end
    end


    # launch thread for node reduction insertion
    # remove duplicates
    nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks)
    schedule(nr_task)

    # --- find possible node fusions ---
    @threads for node in nodeArray
        if (typeof(node) <: DataTaskNode)
            if length(node.parents) != 1
                # data node can only have a single parent
                continue
            end
            parent_node = first(node.parents)

            if length(node.children) != 1
                # this node is an entry node or has multiple children which should not be possible
                continue
            end
            child_node = first(node.children)
            if (length(child_node.parents) != 1)
                continue
            end

            push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
        end
    end

    # launch thread for node fusion insertion
    nf_task = @task nf_insertion!(graph.possibleOperations, generatedFusions, locks)
    schedule(nf_task)

    # find possible node splits
    @threads for node in nodeArray
        if (can_split(node))
            push!(generatedSplits[threadid()], NodeSplit(node))
        end
    end

    # launch thread for node split insertion
    ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits, locks)
    schedule(ns_task)

    empty!(graph.dirtyNodes)

    # wait for all insertion tasks to finish before returning
    wait(nr_task)
    wait(nf_task)
    wait(ns_task)

    return nothing
end
|
18
src/operations/get.jl
Normal file
18
src/operations/get.jl
Normal file
@ -0,0 +1,18 @@
|
||||
# function to return the possible operations of a graph
|
||||
|
||||
using Base.Threads
|
||||
|
||||
"""
    get_operations(graph::DAG)

Apply all outstanding operations on `graph`, then return its cached set of
possible operations, regenerating the cache first if it is empty.
"""
function get_operations(graph::DAG)
    apply_all!(graph)

    # an empty cache means it was invalidated; rebuild it from scratch
    if isempty(graph.possibleOperations)
        generate_options(graph)
    end

    # refresh the cached operations of every node touched since the last call
    foreach(node -> clean_node!(graph, node), graph.dirtyNodes)
    empty!(graph.dirtyNodes)

    return graph.possibleOperations
end
|
109
src/operations/utility.jl
Normal file
109
src/operations/utility.jl
Normal file
@ -0,0 +1,109 @@
|
||||
|
||||
"""
    isempty(operations::PossibleOperations)

Return `true` iff no node fusions, node reductions or node splits are available.
"""
function isempty(operations::PossibleOperations)
    return all(
        isempty,
        (operations.nodeFusions, operations.nodeReductions, operations.nodeSplits),
    )
end
|
||||
|
||||
"""
    length(operations::PossibleOperations)

Return a named tuple with the number of available operations per kind
(`nodeFusions`, `nodeReductions`, `nodeSplits`).

NOTE(review): returning a NamedTuple from `length` instead of an `Int` is
unconventional, but the test suite compares against named tuples, so the
return shape is kept.
"""
function length(operations::PossibleOperations)
    counts = (
        nodeFusions = length(operations.nodeFusions),
        nodeReductions = length(operations.nodeReductions),
        nodeSplits = length(operations.nodeSplits),
    )
    return counts
end
|
||||
|
||||
"""
    delete!(operations::PossibleOperations, op::NodeFusion)

Remove the node fusion `op` from `operations` and return `operations`.
"""
delete!(operations::PossibleOperations, op::NodeFusion) =
    (delete!(operations.nodeFusions, op); operations)
|
||||
|
||||
"""
    delete!(operations::PossibleOperations, op::NodeReduction)

Remove the node reduction `op` from `operations` and return `operations`.
"""
delete!(operations::PossibleOperations, op::NodeReduction) =
    (delete!(operations.nodeReductions, op); operations)
|
||||
|
||||
"""
    delete!(operations::PossibleOperations, op::NodeSplit)

Remove the node split `op` from `operations` and return `operations`.
"""
delete!(operations::PossibleOperations, op::NodeSplit) =
    (delete!(operations.nodeSplits, op); operations)
|
||||
|
||||
|
||||
"""
    can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Return `true` iff the chain `n1 -> n2 -> n3` can be fused into a single node.
"""
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    # redundant with how fusion candidates are generated, but a cheap sanity check
    if !(is_child(n1, n2) && is_child(n2, n3))
        return false
    end

    # the intermediate data node (and n1's output) must be used by this chain only
    return length(n2.parents) == 1 && length(n2.children) == 1 && length(n1.parents) == 1
end
|
||||
|
||||
"""
    can_reduce(n1::Node, n2::Node)

Return `true` iff `n1` and `n2` have the same task and the same set of
children, and can therefore be reduced into a single node.
"""
function can_reduce(n1::Node, n2::Node)
    if n1.task != n2.task
        return false
    end

    c1 = n1.children
    c2 = n2.children
    if length(c1) != length(c2)
        return false
    end

    # fast path for the most common case: two children. Comparing the vectors
    # manually is much faster than building Sets for the general solution.
    if length(c1) == 2
        return (c1[1] == c2[1] && c1[2] == c2[2]) ||
               (c1[1] == c2[2] && c1[2] == c2[1])
    end

    # single child is a trivial comparison
    if length(c1) == 1
        return c1[1] == c2[1]
    end

    # general (slow) case: compare as unordered sets
    return Set(c1) == Set(c2)
end
|
||||
|
||||
"""
    can_split(n::Node)

Return `true` iff `n` can be split, i.e. it has more than one parent.
"""
can_split(n::Node) = length(parents(n)) > 1
|
||||
|
||||
"""
    ==(op1::Operation, op2::Operation)

Fallback: operations of different concrete types are never equal.

NOTE(review): `==` is specialized for operations but `hash` is not —
Set/Dict membership uses `hash` first; confirm that operations which
compare equal also hash equally.
"""
==(op1::Operation, op2::Operation) = false
|
||||
|
||||
"""
    ==(op1::NodeFusion, op2::NodeFusion)

Two node fusions are equal iff they act on the same data task node
(`input[2]`); there can only be one fusion per data task node.
"""
==(op1::NodeFusion, op2::NodeFusion) = op1.input[2] == op2.input[2]
|
||||
|
||||
"""
    ==(op1::NodeReduction, op2::NodeReduction)

Node reductions are identified by their unique id, so comparing the ids
is sufficient.
"""
==(op1::NodeReduction, op2::NodeReduction) = op1.id == op2.id
|
||||
|
||||
"""
    ==(op1::NodeSplit, op2::NodeSplit)

Two node splits are equal iff they split the same input node.
"""
==(op1::NodeSplit, op2::NodeSplit) = op1.input == op2.input
|
||||
|
||||
# outer constructor: tag each NodeReduction with a fresh UUID drawn from the
# calling thread's RNG, avoiding contention on a single shared RNG
# NOTE(review): indexing `rng` by threadid() assumes the task does not migrate
# between threads mid-call — confirm, as task migration can occur on Julia >= 1.7
NodeReduction(input::Vector{Node}) = NodeReduction(input, UUIDs.uuid1(rng[threadid()]))
|
||||
|
||||
"""
    copy(id::UUID)

Return a copy of `id`. `UUID` is an immutable stdlib struct wrapping a
`UInt128`, so the value itself can be returned directly — reconstructing it
from the private `value` field is unnecessary and fragile.
"""
copy(id::UUID) = id
|
65
src/trie.jl
Normal file
65
src/trie.jl
Normal file
@ -0,0 +1,65 @@
|
||||
|
||||
# helper struct for NodeTrie: one trie level, keyed by the UUID of the child
# node at the current depth
mutable struct NodeIdTrie
    value::Vector{Node}               # nodes whose ordered child list ends at this level
    children::Dict{UUID, NodeIdTrie}  # sub-tries, one per distinct id of the next child
end
|
||||
|
||||
# Trie data structure for node reduction: nodes are inserted keyed by their children
# Assumes that given nodes have ordered vectors of children (see sort_node!)
# The first level is keyed by the node's task type and thus stores no value itself
# Should be constructed with all Types that will be used
mutable struct NodeTrie
    children::Dict{DataType, NodeIdTrie}
end
|
||||
|
||||
"""
    NodeTrie()

Construct an empty `NodeTrie` with no task-type entries.
"""
NodeTrie() = NodeTrie(Dict{DataType, NodeIdTrie}())
|
||||
|
||||
"""
    NodeIdTrie()

Construct an empty `NodeIdTrie` level with no stored nodes and no sub-tries.
"""
NodeIdTrie() = NodeIdTrie(Vector{Node}(), Dict{UUID, NodeIdTrie}())
|
||||
|
||||
"""
    insert_helper!(trie::NodeIdTrie, node::Node, depth::Int)

Recursively insert `node` into `trie`, descending one level per child id.
Once the node's child list is exhausted, the node is stored at the current
level, so nodes sharing an identical ordered child list accumulate in the
same leaf.
"""
function insert_helper!(trie::NodeIdTrie, node::Node, depth::Int)
    # all children consumed: this is the node's leaf
    if length(node.children) == depth
        push!(trie.value, node)
        return nothing
    end

    nextDepth = depth + 1
    childId = node.children[nextDepth].id

    # fetch (or lazily create) the sub-trie for this child id and descend
    subtrie = get!(trie.children, childId) do
        NodeIdTrie()
    end
    return insert_helper!(subtrie, node, nextDepth)
end
|
||||
|
||||
"""
    insert!(trie::NodeTrie, node::Node)

Insert `node` into `trie`, bucketing first by the node's task type and then
by the ids of its (ordered) children.
"""
function insert!(trie::NodeTrie, node::Node)
    t = typeof(node.task)
    # fixed: reuse `t` instead of recomputing typeof(node.task), and use get!
    # to avoid the haskey + getindex double lookup
    subtrie = get!(() -> NodeIdTrie(), trie.children, t)
    return insert_helper!(subtrie, node, 0)
end
|
||||
|
||||
"""
    collect_helper(trie::NodeIdTrie, acc::Set{Vector{Node}})

Recursively walk `trie`, pushing every level holding two or more nodes
(i.e. a node-reduction candidate) into `acc`.
"""
function collect_helper(trie::NodeIdTrie, acc::Set{Vector{Node}})
    # only groups of at least two identical-children nodes can be reduced
    if length(trie.value) >= 2
        push!(acc, trie.value)
    end

    for child in values(trie.children)
        collect_helper(child, acc)
    end
    return nothing
end
|
||||
|
||||
# returns all sets of multiple nodes that have accumulated in leaves
|
||||
"""
    collect(trie::NodeTrie)

Return the set of all node groups (two or more nodes) that accumulated in
the trie's leaves, i.e. the node-reduction candidates.
"""
function collect(trie::NodeTrie)
    acc = Set{Vector{Node}}()
    foreach(child -> collect_helper(child, acc), values(trie.children))
    return acc
end
|
@ -7,3 +7,12 @@ function bytes_to_human_readable(bytes::Int64)
|
||||
end
|
||||
return string(round(bytes, sigdigits=4), " ", units[unit_index])
|
||||
end
|
||||
|
||||
"""
    lt_nodes(n1::Node, n2::Node)

Ordering predicate for nodes: compare by their unique id.
"""
lt_nodes(n1::Node, n2::Node) = n1.id < n2.id
|
||||
|
||||
"""
    sort_node!(node::Node)

Sort `node`'s children and parents vectors by node id, establishing the
canonical order that `NodeTrie` insertion assumes.
"""
function sort_node!(node::Node)
    sort!(node.children; lt = lt_nodes)
    return sort!(node.parents; lt = lt_nodes)
end
|
||||
|
@ -127,8 +127,8 @@ import MetagraphOptimization.partners
|
||||
|
||||
@test MetagraphOptimization.get_exit_node(graph) == d_exit
|
||||
|
||||
@test length(partners(s0)) == 0
|
||||
@test length(siblings(s0)) == 0
|
||||
@test length(partners(s0)) == 1
|
||||
@test length(siblings(s0)) == 1
|
||||
|
||||
operations = get_operations(graph)
|
||||
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user