diff --git a/.gitea/workflows/julia-package-ci.yml b/.gitea/workflows/julia-package-ci.yml index 203d0cf..0f6d916 100644 --- a/.gitea/workflows/julia-package-ci.yml +++ b/.gitea/workflows/julia-package-ci.yml @@ -24,10 +24,10 @@ jobs: version: '1.9.1' - name: Install dependencies - run: julia --project -e 'import Pkg; Pkg.instantiate()' + run: julia --project=./ -e 'import Pkg; Pkg.instantiate()' - name: Run tests - run: julia --project -t 4 -e 'import Pkg; Pkg.test()' + run: julia --project=./ -t 4 -e 'import Pkg; Pkg.test()' -O0 - name: Run examples - run: julia --project=examples/ -t 4 -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); include("examples/import_bench.jl")' + run: julia --project=examples/ -t 4 -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); include("examples/import_bench.jl")' -O3 diff --git a/README.md b/README.md index 364bf68..1cb9a8b 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,13 @@ For all the julia calls, use `-t n` to give julia `n` threads. Instantiate the project first: -`julia --project -e 'import Pkg; Pkg.instantiate()'` +`julia --project=./ -e 'import Pkg; Pkg.instantiate()'` ### Run Tests To run all tests, run -`julia --project=. -e 'import Pkg; Pkg.test()'` +`julia --project=./ -e 'import Pkg; Pkg.test()' -O0` ### Run Examples @@ -24,7 +24,7 @@ Get the correct environment for the examples folder: Then execute a specific example: -`julia --project=examples examples/.jl` +`julia --project=examples examples/.jl -O3` ## Concepts diff --git a/examples/Project.toml b/examples/Project.toml index 295195a..68e6047 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -1,6 +1,7 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4" +PProf = "e4faabce-9ead-11e9-39d9-4379958e3056" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/results/temp.md b/results/temp.md index 06c38f0..42bc3ce 100644 --- a/results/temp.md +++ b/results/temp.md @@ -15,3 +15,16 @@ (AB->ABBBBBBB, 6) 887.160 ms (5596691 allocations: 763.42 MiB) (AB->ABBBBBBB, 7) 898.757 ms (5596762 allocations: 789.91 MiB) (AB->ABBBBBBB, 8) 497.545 ms (5596820 allocations: 759.66 MiB) + + +Initial: + +$ julia --project=examples/ -e 'using BenchmarkTools; using MetagraphOptimization; parse_abc("input/AB->AB.txt"); @time g = parse_abc("input/AB->ABBBBBBBBB.txt")' + 65.370947 seconds (626.10 M allocations: 37.381 GiB, 53.59% gc time, 0.01% compilation time) + +Removing make_edge from calls in parse: + 50.053920 seconds (593.41 M allocations: 32.921 GiB, 49.70% gc time, 0.09% compilation time) + +Nodes operation storage rework (and O3): + 31.997128 seconds (450.66 M allocations: 25.294 GiB, 31.56% gc time, 0.14% compilation time) + \ No newline at end of file diff --git a/scripts/bench_threads.fish b/scripts/bench_threads.fish index 13c5582..28df8c3 100755 --- a/scripts/bench_threads.fish +++ b/scripts/bench_threads.fish @@ -6,20 +6,20 @@ julia --project=./examples -t 4 -e 'import Pkg; Pkg.instantiate()' #for i in $(seq $minthreads $maxthreads) # printf "(AB->AB, $i) " -# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->AB.txt"))' +# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->AB.txt"))' #end #for i in $(seq $minthreads $maxthreads) # printf "(AB->ABBB, $i) " -# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBB.txt"))' +# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBB.txt"))' #end #for i in $(seq $minthreads $maxthreads) # printf "(AB->ABBBBB, $i) " -# julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBBBB.txt"))' +# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBB.txt"))' #end for i in $(seq $minthreads $maxthreads) printf "(AB->ABBBBBBB, $i) " - julia --project=./examples -t $i -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("examples/AB->ABBBBBBB.txt"))' + julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBBBB.txt"))' end diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl index b51f43c..5909227 100644 --- a/src/MetagraphOptimization.jl +++ b/src/MetagraphOptimization.jl @@ -38,6 +38,7 @@ include("operations/clean.jl") include("operations/find.jl") include("operations/get.jl") include("operations/print.jl") +include("operations/validate.jl") include("graph_interface.jl") diff --git a/src/abc_model/parse.jl b/src/abc_model/parse.jl index 9334c9b..1d50df9 100644 --- a/src/abc_model/parse.jl +++ b/src/abc_model/parse.jl @@ -42,7 +42,7 @@ function parse_abc(filename::String, verbose::Bool = false) sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false) global_data_out = insert_node!(graph, make_node(DataTask(10)), false, false) - insert_edge!(graph, make_edge(sum_node, global_data_out), false, false) + insert_edge!(graph, sum_node, global_data_out, false, false) # remember the data out nodes for connection dataOutNodes = Dict() @@ -64,10 +64,10 @@ function parse_abc(filename::String, verbose::Bool = false) compute_u = insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node data_out = insert_node!(graph, make_node(DataTask(3)), false, false) # transfer data out from u - insert_edge!(graph, make_edge(data_in, compute_P), false, false) - insert_edge!(graph, make_edge(compute_P, data_Pu), false, false) - insert_edge!(graph, make_edge(data_Pu, compute_u), false, false) - insert_edge!(graph, make_edge(compute_u, data_out), false, false) + insert_edge!(graph, data_in, compute_P, false, false) + insert_edge!(graph, compute_P, data_Pu, false, false) + insert_edge!(graph, data_Pu, compute_u, false, false) + insert_edge!(graph, compute_u, data_out, false, false) # remember the data_out node for future edges dataOutNodes[node] = data_out @@ -80,34 +80,34 @@ function parse_abc(filename::String, verbose::Bool = false) compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false) data_out = insert_node!(graph, make_node(DataTask(5)), false, false) - if (occursin(regex_c, capt.captures[1])) + if (occursin(regex_c, in1)) # put an S node after this input compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false) data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false) - insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_S), false, false) - insert_edge!(graph, make_edge(compute_S, data_S_v), false, false) + insert_edge!(graph, dataOutNodes[in1], compute_S, false, false) + insert_edge!(graph, compute_S, data_S_v, false, false) - insert_edge!(graph, make_edge(data_S_v, compute_v), false, false) + insert_edge!(graph, data_S_v, compute_v, false, false) else - insert_edge!(graph, make_edge(dataOutNodes[capt.captures[1]], compute_v), false, false) + insert_edge!(graph, dataOutNodes[in1], compute_v, false, false) end - if (occursin(regex_c, capt.captures[2])) + if (occursin(regex_c, in2)) # i think the current generator only puts the combined particles in the first space, so this case might never be entered # put an S node after this input compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false) data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false) - insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_S), false, false) - insert_edge!(graph, make_edge(compute_S, data_S_v), false, false) + insert_edge!(graph, dataOutNodes[in2], compute_S, false, false) + insert_edge!(graph, compute_S, data_S_v, false, false) - insert_edge!(graph, make_edge(data_S_v, compute_v), false, false) + insert_edge!(graph, data_S_v, compute_v, false, false) else - insert_edge!(graph, make_edge(dataOutNodes[capt.captures[2]], compute_v), false, false) + insert_edge!(graph, dataOutNodes[in2], compute_v, false, false) end - insert_edge!(graph, make_edge(compute_v, data_out), false, false) + insert_edge!(graph, compute_v, data_out, false, false) dataOutNodes[node] = data_out elseif occursin(regex_m, node) @@ -121,26 +121,26 @@ function parse_abc(filename::String, verbose::Bool = false) compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false) data_v = insert_node!(graph, make_node(DataTask(5)), false, false) - insert_edge!(graph, make_edge(dataOutNodes[in2], compute_v), false, false) - insert_edge!(graph, make_edge(dataOutNodes[in3], compute_v), false, false) - insert_edge!(graph, make_edge(compute_v, data_v), false, false) + insert_edge!(graph, dataOutNodes[in2], compute_v, false, false) + insert_edge!(graph, dataOutNodes[in3], compute_v, false, false) + insert_edge!(graph, compute_v, data_v, false, false) # combine with the v of the combined other input compute_S2 = insert_node!(graph, make_node(ComputeTaskS2()), false, false) data_out = insert_node!(graph, make_node(DataTask(10)), false, false) - insert_edge!(graph, make_edge(data_v, compute_S2), false, false) - insert_edge!(graph, make_edge(dataOutNodes[in1], compute_S2), false, false) - insert_edge!(graph, make_edge(compute_S2, data_out), false, false) + insert_edge!(graph, data_v, compute_S2, false, false) + insert_edge!(graph, dataOutNodes[in1], compute_S2, false, false) + insert_edge!(graph, compute_S2, data_out, false, false) - insert_edge!(graph, make_edge(data_out, sum_node), false, false) + insert_edge!(graph, data_out, sum_node, false, false) elseif occursin(regex_plus, node) if (verbose) println("\rReading Nodes Complete ") println("Added ", length(graph.nodes), " nodes") end else - error("Unknown node '", node, "' while reading from file ", filename) + @assert false ("Unknown node '$node' while reading from file $filename") end end diff --git a/src/graph_functions.jl b/src/graph_functions.jl index 5b26c8a..58aac5b 100644 --- a/src/graph_functions.jl +++ b/src/graph_functions.jl @@ -67,27 +67,66 @@ end is_entry_node(node::Node) = length(node.children) == 0 is_exit_node(node::Node) = length(node.parents) == 0 -# function to invalidate the operation caches for a given operation -function invalidate_caches!(graph::DAG, operation::Operation) +# function to invalidate the operation caches for a given NodeFusion +function invalidate_caches!(graph::DAG, operation::NodeFusion) delete!(graph.possibleOperations, operation) # delete the operation from all caches of nodes involved in the operation - # (we can iterate over tuples and vectors just fine) + filter!(!=(operation), operation.input[1].nodeFusions) + filter!(!=(operation), operation.input[3].nodeFusions) + + operation.input[2].nodeFusion = missing + + return nothing +end + +# function to invalidate the operation caches for a given NodeReduction +function invalidate_caches!(graph::DAG, operation::NodeReduction) + delete!(graph.possibleOperations, operation) + for node in operation.input - filter!(!=(operation), node.operations) + node.nodeReduction = missing end return nothing end -# function to invalidate the operation caches for a given Node Split specifically +# function to invalidate the operation caches for a given NodeSplit function invalidate_caches!(graph::DAG, operation::NodeSplit) delete!(graph.possibleOperations, operation) # delete the operation from all caches of nodes involved in the operation # for node split there is only one node - filter!(x -> x != operation, operation.input.operations) + operation.input.nodeSplit = missing + + return nothing +end +# function to invalidate the operation caches of a ComputeTaskNode +function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode) + if !ismissing(node.nodeReduction) + invalidate_caches!(graph, node.nodeReduction) + end + if !ismissing(node.nodeSplit) + invalidate_caches!(graph, node.nodeSplit) + end + while !isempty(node.nodeFusions) + invalidate_caches!(graph, pop!(node.nodeFusions)) + end + return nothing +end + +# function to invalidate the operation caches of a DataTaskNode +function invalidate_operation_caches!(graph::DAG, node::DataTaskNode) + if !ismissing(node.nodeReduction) + invalidate_caches!(graph, node.nodeReduction) + end + if !ismissing(node.nodeSplit) + invalidate_caches!(graph, node.nodeSplit) + end + if !ismissing(node.nodeFusion) + invalidate_caches!(graph, node.nodeFusion) + end return nothing end @@ -110,93 +149,72 @@ function insert_node!(graph::DAG, node::Node, track=true, invalidate_cache=true) return node end -function insert_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true) - node1 = edge.edge[1] - node2 = edge.edge[2] +function insert_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true) + # @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists" # 1: mute - #=if (node2 in node1.parents) || (node1 in node2.children) - if !(node2 in node1.parents && node1 in node2.children) - error("One-sided edge") - end - error("Edge to insert already exists") - end=# - # edge points from child to parent push!(node1.parents, node2) push!(node2.children, node1) # 2: keep track - if (track) push!(graph.diff.addedEdges, edge) end + if (track) push!(graph.diff.addedEdges, make_edge(node1, node2)) end # 3: invalidate caches - if (!invalidate_cache) return edge end + if (!invalidate_cache) return nothing end + + invalidate_operation_caches!(graph, node1) + invalidate_operation_caches!(graph, node2) - while !isempty(node1.operations) - invalidate_caches!(graph, first(node1.operations)) - end - while !isempty(node2.operations) - invalidate_caches!(graph, first(node2.operations)) - end push!(graph.dirtyNodes, node1) push!(graph.dirtyNodes, node2) - return edge + return nothing end function remove_node!(graph::DAG, node::Node, track=true, invalidate_cache=true) + # @assert node in graph.nodes "Trying to remove a node that's not in the graph" + # 1: mute - #=if !(node in graph.nodes) - error("Trying to remove a node that's not in the graph") - end=# delete!(graph.nodes, node) # 2: keep track if (track) push!(graph.diff.removedNodes, node) end # 3: invalidate caches - if (!invalidate_cache) return node end + if (!invalidate_cache) return nothing end - while !isempty(node.operations) - invalidate_caches!(graph, first(node.operations)) - end + invalidate_operation_caches!(graph, node) delete!(graph.dirtyNodes, node) return nothing end -function remove_edge!(graph::DAG, edge::Edge, track=true, invalidate_cache=true) - node1 = edge.edge[1] - node2 = edge.edge[2] - +function remove_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true) # 1: mute pre_length1 = length(node1.parents) pre_length2 = length(node2.children) filter!(x -> x != node2, node1.parents) filter!(x -> x != node1, node2.children) - #=removed = pre_length1 - length(node1.parents) - if (removed > 1) - error("removed $removed from node1's parents") - end + #=@assert begin + removed = pre_length1 - length(node1.parents) + removed <= 1 + end "removed more than one node from node1's parents"=# - removed = pre_length2 - length(node2.children) - if (removed > 1) - error("removed $removed from node2's children") - end=# + #=@assert begin + removed = pre_length2 - length(node2.children) + removed <= 1 + end "removed more than one node from node2's children"=# # 2: keep track - if (track) push!(graph.diff.removedEdges, edge) end + if (track) push!(graph.diff.removedEdges, make_edge(node1, node2)) end # 3: invalidate caches if (!invalidate_cache) return nothing end - while !isempty(node1.operations) - invalidate_caches!(graph, first(node1.operations)) - end - while !isempty(node2.operations) - invalidate_caches!(graph, first(node2.operations)) - end + invalidate_operation_caches!(graph, node1) + invalidate_operation_caches!(graph, node2) if (node1 in graph) push!(graph.dirtyNodes, node1) end @@ -241,7 +259,7 @@ function get_exit_node(graph::DAG) return node end end - error("The given graph has no exit node! It is either empty or not acyclic!") + @assert false "The given graph has no exit node! It is either empty or not acyclic!" end # check whether the given graph is connected diff --git a/src/node_functions.jl b/src/node_functions.jl index e005a54..1d53f54 100644 --- a/src/node_functions.jl +++ b/src/node_functions.jl @@ -46,5 +46,6 @@ function ==(n1::DataTaskNode, n2::DataTaskNode) return n1.id == n2.id end -copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations)) -copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.operations)) +copy(m::Missing) = missing +copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusions)) +copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusion)) diff --git a/src/nodes.jl b/src/nodes.jl index 6e92f26..c1e6b80 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -2,6 +2,7 @@ using Random using UUIDs using Base.Threads +# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/) rng = [Random.MersenneTwister(0) for _ in 1:32] abstract type Node end @@ -10,7 +11,7 @@ abstract type Node end # the specific operations are declared in graph.jl abstract type Operation end -struct DataTaskNode <: Node +mutable struct DataTaskNode <: Node task::AbstractDataTask # use vectors as sets have way too much memory overhead @@ -21,21 +22,33 @@ struct DataTaskNode <: Node # however, it can be copied when splitting a node id::Base.UUID - # a vector holding references to the graph operations involving this node - operations::Vector{Operation} + # the NodeReduction involving this node, if it exists + # Can't use the NodeReduction type here because it's not yet defined + nodeReduction::Union{Operation, Missing} + + # the NodeSplit involving this node, if it exists + nodeSplit::Union{Operation, Missing} + + # the node fusion involving this node, if it exists + nodeFusion::Union{Operation, Missing} end # same as DataTaskNode -struct ComputeTaskNode <: Node +mutable struct ComputeTaskNode <: Node task::AbstractComputeTask parents::Vector{Node} children::Vector{Node} id::Base.UUID - operations::Vector{Operation} + + nodeReduction::Union{Operation, Missing} + nodeSplit::Union{Operation, Missing} + + # for ComputeTasks there can be multiple fusions, unlike the DataTasks + nodeFusions::Vector{Operation} end -DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}()) -ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), Vector{Operation}()) +DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing) +ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, Vector{NodeFusion}()) struct Edge # edge points from child to parent diff --git a/src/operations/apply.jl b/src/operations/apply.jl index 3e1e61c..9be20ad 100644 --- a/src/operations/apply.jl +++ b/src/operations/apply.jl @@ -2,237 +2,197 @@ # applies all unapplied operations in the DAG function apply_all!(graph::DAG) - while !isempty(graph.operationsToApply) - # get next operation to apply from front of the deque - op = popfirst!(graph.operationsToApply) + while !isempty(graph.operationsToApply) + # get next operation to apply from front of the deque + op = popfirst!(graph.operationsToApply) - # apply it - appliedOp = apply_operation!(graph, op) + # apply it + appliedOp = apply_operation!(graph, op) - # push to the end of the appliedOperations deque - push!(graph.appliedOperations, appliedOp) - end - return nothing + # push to the end of the appliedOperations deque + push!(graph.appliedOperations, appliedOp) + end + return nothing end function apply_operation!(graph::DAG, operation::Operation) - error("Unknown operation type!") + error("Unknown operation type!") end function apply_operation!(graph::DAG, operation::NodeFusion) - diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3]) - return AppliedNodeFusion(operation, diff) + diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3]) + return AppliedNodeFusion(operation, diff) end function apply_operation!(graph::DAG, operation::NodeReduction) - diff = node_reduction!(graph, operation.input) - return AppliedNodeReduction(operation, diff) + diff = node_reduction!(graph, operation.input) + return AppliedNodeReduction(operation, diff) end function apply_operation!(graph::DAG, operation::NodeSplit) - diff = node_split!(graph, operation.input) - return AppliedNodeSplit(operation, diff) + diff = node_split!(graph, operation.input) + return AppliedNodeSplit(operation, diff) end function revert_operation!(graph::DAG, operation::AppliedOperation) - error("Unknown operation type!") + error("Unknown operation type!") end function revert_operation!(graph::DAG, operation::AppliedNodeFusion) - revert_diff!(graph, operation.diff) - return operation.operation + revert_diff!(graph, operation.diff) + return operation.operation end function revert_operation!(graph::DAG, operation::AppliedNodeReduction) - revert_diff!(graph, operation.diff) - return operation.operation + revert_diff!(graph, operation.diff) + return operation.operation end function revert_operation!(graph::DAG, operation::AppliedNodeSplit) - revert_diff!(graph, operation.diff) - return operation.operation + revert_diff!(graph, operation.diff) + return operation.operation end -function revert_diff!(graph::DAG, diff) - # add removed nodes, remove added nodes, same for edges - # note the order - for edge in diff.addedEdges - remove_edge!(graph, edge, false) - end - for node in diff.addedNodes - remove_node!(graph, node, false) - end +function revert_diff!(graph::DAG, diff::Diff) + # add removed nodes, remove added nodes, same for edges + # note the order + for edge in diff.addedEdges + remove_edge!(graph, edge.edge[1], edge.edge[2], false) + end + for node in diff.addedNodes + remove_node!(graph, node, false) + end - for node in diff.removedNodes - insert_node!(graph, node, false) - end - for edge in diff.removedEdges - insert_edge!(graph, edge, false) - end + for node in diff.removedNodes + insert_node!(graph, node, false) + end + for edge in diff.removedEdges + insert_edge!(graph, edge.edge[1], edge.edge[2], false) + end end # Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) - # clear snapshot - get_snapshot_diff(graph) + # @assert is_valid_node_fusion_input(graph, n1, n2, n3) - if !(n1 in graph) || !(n2 in graph) || !(n3 in graph) - error("[Node Fusion] The given nodes are not part of the given graph") - end + # clear snapshot + get_snapshot_diff(graph) - if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1) - # the checks are redundant but maybe a good sanity check - error("[Node Fusion] The given nodes are not connected by edges which is required for node fusion") - end - # save children and parents - n1_children = children(n1) - n3_parents = parents(n3) - n3_children = children(n3) + # save children and parents + n1_children = children(n1) + n3_parents = parents(n3) + n3_children = children(n3) - if length(n2.parents) > 1 - error("[Node Fusion] The given data node has more than one parent") - end - if length(n2.children) > 1 - error("[Node Fusion] The given data node has more than one child") - end - if length(n1.parents) > 1 - error("[Node Fusion] The given n1 has more than one parent") - end + # remove the edges and nodes that will be replaced by the fused node + remove_edge!(graph, n1, n2) + remove_edge!(graph, n2, n3) + remove_node!(graph, n1) + remove_node!(graph, n2) - required_edge1 = make_edge(n1, n2) - required_edge2 = make_edge(n2, n3) + # get n3's children now so it automatically excludes n2 + n3_children = children(n3) + remove_node!(graph, n3) - # remove the edges and nodes that will be replaced by the fused node - remove_edge!(graph, required_edge1) - remove_edge!(graph, required_edge2) - remove_node!(graph, n1) - remove_node!(graph, n2) + # create new node with the fused compute task + new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}()) + insert_node!(graph, new_node) - # get n3's children now so it automatically excludes n2 - n3_children = children(n3) - remove_node!(graph, n3) + # use a set for combined children of n1 and n3 to not get duplicates + n1and3_children = Set{Node}() - # create new node with the fused compute task - new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task),typeof(n3.task)}()) - insert_node!(graph, new_node) + # remove edges from n1 children to n1 + for child in n1_children + remove_edge!(graph, child, n1) + push!(n1and3_children, child) + end - # use a set for combined children of n1 and n3 to not get duplicates - n1and3_children = Set{Node}() + # remove edges from n3 children to n3 + for child in n3_children + remove_edge!(graph, child, n3) + push!(n1and3_children, child) + end - # remove edges from n1 children to n1 - for child in n1_children - remove_edge!(graph, make_edge(child, n1)) - push!(n1and3_children, child) - end + for child in n1and3_children + insert_edge!(graph, child, new_node) + end - # remove edges from n3 children to n3 - for child in n3_children - remove_edge!(graph, make_edge(child, n3)) - push!(n1and3_children, child) - end + # "repoint" parents of n3 from new node + for parent in n3_parents + remove_edge!(graph, n3, parent) + insert_edge!(graph, new_node, parent) + end - for child in n1and3_children - insert_edge!(graph, make_edge(child, new_node)) - end - - # "repoint" parents of n3 from new node - for parent in n3_parents - remove_edge!(graph, make_edge(n3, parent)) - insert_edge!(graph, make_edge(new_node, parent)) - end - - return get_snapshot_diff(graph) + return get_snapshot_diff(graph) end function node_reduction!(graph::DAG, nodes::Vector{Node}) - # clear snapshot - get_snapshot_diff(graph) + # @assert is_valid_node_reduction_input(graph, nodes) - t = typeof(nodes[1].task) - for n in nodes - if n ∉ graph - error("[Node Reduction] The given nodes are not part of the given graph") - end + # clear snapshot + get_snapshot_diff(graph) - if typeof(n.task) != t - error("[Node Reduction] The given nodes are not of the same type") - end - end + n1 = nodes[1] + n1_children = children(n1) + + n1_parents = Set(n1.parents) + new_parents = Set{Node}() + # remove all of the nodes' parents and children and the nodes themselves (except for first node) + for i in 2:length(nodes) + n = nodes[i] + for child in n1_children + remove_edge!(graph, child, n) + end - n1 = nodes[1] - n1_children = children(n1) - for n in nodes - if Set(n1_children) != Set(n.children) - error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction") - end - end + for parent in parents(n) + remove_edge!(graph, n, parent) - n1_parents = Set(n1.parents) - new_parents = Set{Node}() + # collect all parents + push!(new_parents, parent) + end - # remove all of the nodes' parents and children and the nodes themselves (except for first node) - for i in 2:length(nodes) - n = nodes[i] - for child in n1_children - remove_edge!(graph, make_edge(child, n)) - end + remove_node!(graph, n) + end - for parent in parents(n) - remove_edge!(graph, make_edge(n, parent)) + setdiff!(new_parents, n1_parents) - # collect all parents - push!(new_parents, parent) - end + for parent in new_parents + # now add parents of all input nodes to n1 without duplicates + insert_edge!(graph, n1, parent) + end - remove_node!(graph, n) - end - - setdiff!(new_parents, n1_parents) - - for parent in new_parents - # now add parents of n2 to n1 without duplicates - insert_edge!(graph, make_edge(n1, parent)) - end - - return get_snapshot_diff(graph) + return get_snapshot_diff(graph) end function node_split!(graph::DAG, n1::Node) - # clear snapshot - get_snapshot_diff(graph) + # @assert is_valid_node_split_input(graph, n1) - #=if !(n1 in graph) - error("[Node Split] The given node is not part of the given graph") - end=# + # clear snapshot + get_snapshot_diff(graph) - n1_parents = parents(n1) - n1_children = children(n1) + n1_parents = parents(n1) + n1_children = children(n1) - #=if length(n1_parents) <= 1 - error("[Node Split] The given node does not have multiple parents which is required for node split") - end=# + for parent in n1_parents + remove_edge!(graph, n1, parent) + end + for child in n1_children + remove_edge!(graph, child, n1) + end + remove_node!(graph, n1) - for parent in n1_parents - remove_edge!(graph, make_edge(n1, parent)) - end - for child in n1_children - remove_edge!(graph, make_edge(child, n1)) - end - remove_node!(graph, n1) + for parent in n1_parents + n_copy = copy(n1) + insert_node!(graph, n_copy) + insert_edge!(graph, n_copy, parent) - for parent in n1_parents - n_copy = copy(n1) - insert_node!(graph, n_copy) - insert_edge!(graph, make_edge(n_copy, parent)) + for child in n1_children + insert_edge!(graph, child, n_copy) + end + end - for child in n1_children - insert_edge!(graph, make_edge(child, n_copy)) - end - end - - return get_snapshot_diff(graph) + return get_snapshot_diff(graph) end diff --git a/src/operations/clean.jl b/src/operations/clean.jl index 56fb249..902012d 100644 --- a/src/operations/clean.jl +++ b/src/operations/clean.jl @@ -4,10 +4,8 @@ # pushes the found fusion everywhere it needs to be and returns nothing function find_fusions!(graph::DAG, node::DataTaskNode) # if there is already a fusion here, skip - for op in node.operations - if typeof(op) <: NodeFusion - return nothing - end + if !ismissing(node.nodeFusion) + return nothing end if length(node.parents) != 1 || length(node.children) != 1 @@ -17,9 +15,9 @@ function find_fusions!(graph::DAG, node::DataTaskNode) child_node = first(node.children) parent_node = first(node.parents) - #=if !(child_node in graph) || !(parent_node in graph) + if !(child_node in graph) || !(parent_node in graph) error("Parents/Children that are not in the graph!!!") - end=# + end if length(child_node.parents) != 1 return nothing @@ -27,9 +25,9 @@ function find_fusions!(graph::DAG, node::DataTaskNode) nf = NodeFusion((child_node, node, parent_node)) push!(graph.possibleOperations.nodeFusions, nf) - push!(child_node.operations, nf) - push!(node.operations, nf) - push!(parent_node.operations, nf) + push!(child_node.nodeFusions, nf) + node.nodeFusion = nf + push!(parent_node.nodeFusions, nf) return nothing end @@ -37,7 +35,6 @@ end function find_fusions!(graph::DAG, node::ComputeTaskNode) # just find fusions in neighbouring DataTaskNodes - for child in node.children find_fusions!(graph, child) end @@ -51,10 +48,8 @@ end function find_reductions!(graph::DAG, node::Node) # there can only be one reduction per node, avoid adding duplicates - for op in node.operations - if typeof(op) <: NodeReduction - return nothing - end + if !ismissing(node.nodeReduction) + return nothing end reductionVector = nothing @@ -62,7 +57,14 @@ function find_reductions!(graph::DAG, node::Node) partners_ = partners(node) delete!(partners_, node) for partner in partners_ + if partner ∉ graph.nodes + error("Partner is not part of the graph") + end + if can_reduce(node, partner) + if Set(node.children) != Set(partner.children) + error("Not equal children") + end if reductionVector === nothing # only when there's at least one reduction partner, insert the vector reductionVector = Vector{Node}() @@ -77,7 +79,12 @@ function find_reductions!(graph::DAG, node::Node) nr = NodeReduction(reductionVector) push!(graph.possibleOperations.nodeReductions, nr) for node in reductionVector - push!(node.operations, nr) + if !ismissing(node.nodeReduction) + # it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now + # this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations + invalidate_caches!(graph, node.nodeReduction) + end + node.nodeReduction = nr end end @@ -85,10 +92,14 @@ function find_reductions!(graph::DAG, node::Node) end function find_splits!(graph::DAG, node::Node) + if !ismissing(node.nodeSplit) + return nothing + end + if (can_split(node)) ns = NodeSplit(node) push!(graph.possibleOperations.nodeSplits, ns) - push!(node.operations, ns) + node.nodeSplit = ns end return nothing diff --git a/src/operations/find.jl b/src/operations/find.jl index 948d73c..1fa0624 100644 --- a/src/operations/find.jl +++ b/src/operations/find.jl @@ -2,49 +2,28 @@ using Base.Threads -function insert_operation!(operations::PossibleOperations, nf::NodeFusion, locks::Dict{Node, SpinLock}) +function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock}) n1 = nf.input[1]; n2 = nf.input[2]; n3 = nf.input[3] - lock(locks[n1]) do; push!(nf.input[1].operations, nf); end - lock(locks[n2]) do; push!(nf.input[2].operations, nf); end - lock(locks[n3]) do; push!(nf.input[3].operations, nf); end + lock(locks[n1]) do; push!(nf.input[1].nodeFusions, nf); end + nf.input[2].nodeFusion = nf + lock(locks[n3]) do; push!(nf.input[3].nodeFusions, nf); end return nothing end -function insert_operation!(operations::PossibleOperations, nr::NodeReduction, locks::Dict{Node, SpinLock}) - # since node parents were sorted before, the NodeReductions contain elements in a known order - # this, together with the locking, means that we can safely do the following without inserting duplicates - first = true +function insert_operation!(nr::NodeReduction) for n in nr.input - skip_duplicate = false - # careful here, this is a manual lock (because of the break) - lock(locks[n]) - if first - first = false - for op in n.operations - if typeof(op) <: NodeReduction - skip_duplicate = true - break - end - end - if skip_duplicate - unlock(locks[n]) - break - end - end - - push!(n.operations, nr) - unlock(locks[n]) + n.nodeReduction = nr end return nothing end -function insert_operation!(operations::PossibleOperations, ns::NodeSplit, locks::Dict{Node, SpinLock}) - lock(locks[ns.input]) do; push!(ns.input.operations, ns); end +function insert_operation!(ns::NodeSplit) + ns.input.nodeSplit = ns return nothing end -function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}, locks::Dict{Node, SpinLock}) +function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}) total_len = 0 for vec in nodeReductions total_len += length(vec) @@ -58,7 +37,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve @threads for vec in nodeReductions for op in vec - insert_operation!(operations, op, locks) + insert_operation!(op) end end @@ -67,7 +46,7 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve return nothing end -function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}, locks::Dict{Node, SpinLock}) +function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}}) total_len = 0 for vec in nodeFusions total_len += length(vec) @@ -79,9 +58,16 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto end schedule(t) + locks = Dict{ComputeTaskNode, SpinLock}() + for n in graph.nodes + if (typeof(n) <: ComputeTaskNode) + locks[n] = SpinLock() + end + end + @threads for vec in nodeFusions for op in vec - insert_operation!(operations, op, locks) + insert_operation!(op, locks) end end @@ -90,7 +76,7 @@ function nf_insertion!(operations::PossibleOperations, nodeFusions::Vector{Vecto return nothing end -function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}, locks::Dict{Node, SpinLock}) +function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}) total_len = 0 for vec in nodeSplits total_len += length(vec) @@ -104,7 +90,7 @@ function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector @threads for vec in nodeSplits for op in vec - insert_operation!(operations, op, locks) + insert_operation!(op) end end @@ -115,11 +101,6 @@ end # function to generate all possible operations on the graph function generate_options(graph::DAG) - locks = Dict{Node, SpinLock}() - for n in graph.nodes - locks[n] = SpinLock() - end - generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()] generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()] generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()] @@ -174,7 +155,7 @@ function generate_options(graph::DAG) # launch thread for node reduction insertion # remove duplicates - nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions, locks) + nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions) schedule(nr_task) # --- find possible node fusions --- @@ -200,7 +181,7 @@ function generate_options(graph::DAG) end # launch thread for node fusion insertion - nf_task = @task nf_insertion!(graph.possibleOperations, generatedFusions, locks) + nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions) schedule(nf_task) # find possible node splits @@ -211,7 +192,7 @@ function generate_options(graph::DAG) end # launch thread for node split insertion - ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits, locks) + ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits) schedule(ns_task) empty!(graph.dirtyNodes) diff --git a/src/operations/validate.jl b/src/operations/validate.jl new file mode 100644 index 0000000..95f9929 --- /dev/null +++ b/src/operations/validate.jl @@ -0,0 +1,61 @@ +# functions to throw assertion errors for inconsistent or wrong node operations +# should be called with @assert +# the functions throw their own errors though, to still have helpful error messages + +function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) + if !(n1 in graph) || !(n2 in graph) || !(n3 in graph) + throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph")) + end + + if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1) + throw(AssertionError("[Node Fusion] The given nodes are not connected by edges which is required for node fusion")) + end + + if length(n2.parents) > 1 + throw(AssertionError("[Node Fusion] The given data node has more than one parent")) + end + if length(n2.children) > 1 + throw(AssertionError("[Node Fusion] The given data node has more than one child")) + end + if length(n1.parents) > 1 + throw(AssertionError("[Node Fusion] The given n1 has more than one parent")) + end + + return true +end + +function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node}) + for n in nodes + if n ∉ graph + throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph")) + end + end + + t = typeof(nodes[1].task) + for n in nodes + if typeof(n.task) != t + throw(AssertionError("[Node Reduction] The given nodes are not of the same type")) + end + end + + n1_children = nodes[1].children + for n in nodes + if Set(n1_children) != Set(n.children) + throw(AssertionError("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")) + end + end + + return true +end + +function is_valid_node_split_input(graph::DAG, n1::Node) + if n1 ∉ graph + throw(AssertionError("[Node Split] The given node is not part of the given graph")) + end + + if length(n1.parents) <= 1 + throw(AssertionError("[Node Split] The given node does not have multiple parents which is required for node split")) + end + + return true +end diff --git a/test/node_reduction.jl b/test/node_reduction.jl index decb724..646035f 100644 --- a/test/node_reduction.jl +++ b/test/node_reduction.jl @@ -1,7 +1,6 @@ import MetagraphOptimization.insert_node! import MetagraphOptimization.insert_edge! import MetagraphOptimization.make_node -import MetagraphOptimization.make_edge @testset "Unit Tests Node Reduction" begin graph = MetagraphOptimization.DAG() @@ -30,27 +29,27 @@ import MetagraphOptimization.make_edge BD = insert_node!(graph, make_node(DataTask(5)), false) CD = insert_node!(graph, make_node(DataTask(5)), false) - insert_edge!(graph, make_edge(s0, d_exit), false) - insert_edge!(graph, make_edge(ED, s0), false) - insert_edge!(graph, make_edge(FD, s0), false) - insert_edge!(graph, make_edge(EC, ED), false) - insert_edge!(graph, make_edge(FC, FD), false) + insert_edge!(graph, s0, d_exit, false) + insert_edge!(graph, ED, s0, false) + insert_edge!(graph, FD, s0, false) + insert_edge!(graph, EC, ED, false) + insert_edge!(graph, FC, FD, false) - insert_edge!(graph, make_edge(A1D, EC), false) - insert_edge!(graph, make_edge(B1D_1, EC), false) + insert_edge!(graph, A1D, EC, false) + insert_edge!(graph, B1D_1, EC, false) - insert_edge!(graph, make_edge(B1D_2, FC), false) - insert_edge!(graph, make_edge(C1D, FC), false) + insert_edge!(graph, B1D_2, FC, false) + insert_edge!(graph, C1D, FC, false) - insert_edge!(graph, make_edge(A1C, A1D), false) - insert_edge!(graph, make_edge(B1C_1, B1D_1), false) - insert_edge!(graph, make_edge(B1C_2, B1D_2), false) - insert_edge!(graph, make_edge(C1C, C1D), false) + insert_edge!(graph, A1C, A1D, false) + insert_edge!(graph, B1C_1, B1D_1, false) + insert_edge!(graph, B1C_2, B1D_2, false) + insert_edge!(graph, C1C, C1D, false) - insert_edge!(graph, make_edge(AD, A1C), false) - insert_edge!(graph, make_edge(BD, B1C_1), false) - insert_edge!(graph, make_edge(BD, B1C_2), false) - insert_edge!(graph, make_edge(CD, C1C), false) + insert_edge!(graph, AD, A1C, false) + insert_edge!(graph, BD, B1C_1, false) + insert_edge!(graph, BD, B1C_2, false) + insert_edge!(graph, CD, C1C, false) @test is_exit_node(d_exit) @test is_entry_node(AD) diff --git a/test/unit_tests_graph.jl b/test/unit_tests_graph.jl index db166c5..20dd5fa 100644 --- a/test/unit_tests_graph.jl +++ b/test/unit_tests_graph.jl @@ -1,7 +1,6 @@ import MetagraphOptimization.insert_node! import MetagraphOptimization.insert_edge! import MetagraphOptimization.make_node -import MetagraphOptimization.make_edge import MetagraphOptimization.siblings import MetagraphOptimization.partners @@ -69,38 +68,38 @@ import MetagraphOptimization.partners @test length(graph.dirtyNodes) == 26 # now for all the edgese - insert_edge!(graph, make_edge(d_PB, PB), false) - insert_edge!(graph, make_edge(d_PA, PA), false) - insert_edge!(graph, make_edge(d_PBp, PBp), false) - insert_edge!(graph, make_edge(d_PAp, PAp), false) + insert_edge!(graph, d_PB, PB, false) + insert_edge!(graph, d_PA, PA, false) + insert_edge!(graph, d_PBp, PBp, false) + insert_edge!(graph, d_PAp, PAp, false) - insert_edge!(graph, make_edge(PB, d_PB_uB), false) - insert_edge!(graph, make_edge(PA, d_PA_uA), false) - insert_edge!(graph, make_edge(PBp, d_PBp_uBp), false) - insert_edge!(graph, make_edge(PAp, d_PAp_uAp), false) + insert_edge!(graph, PB, d_PB_uB, false) + insert_edge!(graph, PA, d_PA_uA, false) + insert_edge!(graph, PBp, d_PBp_uBp, false) + insert_edge!(graph, PAp, d_PAp_uAp, false) - insert_edge!(graph, make_edge(d_PB_uB, uB), false) - insert_edge!(graph, make_edge(d_PA_uA, uA), false) - insert_edge!(graph, make_edge(d_PBp_uBp, uBp), false) - insert_edge!(graph, make_edge(d_PAp_uAp, uAp), false) + insert_edge!(graph, d_PB_uB, uB, false) + insert_edge!(graph, d_PA_uA, uA, false) + insert_edge!(graph, d_PBp_uBp, uBp, false) + insert_edge!(graph, d_PAp_uAp, uAp, false) - insert_edge!(graph, make_edge(uB, d_uB_v0), false) - insert_edge!(graph, make_edge(uA, d_uA_v0), false) - insert_edge!(graph, make_edge(uBp, d_uBp_v1), false) - insert_edge!(graph, make_edge(uAp, d_uAp_v1), false) + insert_edge!(graph, uB, d_uB_v0, false) + insert_edge!(graph, uA, d_uA_v0, false) + insert_edge!(graph, uBp, d_uBp_v1, false) + insert_edge!(graph, uAp, d_uAp_v1, false) - insert_edge!(graph, make_edge(d_uB_v0, v0), false) - insert_edge!(graph, make_edge(d_uA_v0, v0), false) - insert_edge!(graph, make_edge(d_uBp_v1, v1), false) - insert_edge!(graph, make_edge(d_uAp_v1, v1), false) + insert_edge!(graph, d_uB_v0, v0, false) + insert_edge!(graph, d_uA_v0, v0, false) + insert_edge!(graph, d_uBp_v1, v1, false) + insert_edge!(graph, d_uAp_v1, v1, false) - insert_edge!(graph, make_edge(v0, d_v0_s0), false) - insert_edge!(graph, make_edge(v1, d_v1_s0), false) + insert_edge!(graph, v0, d_v0_s0, false) + insert_edge!(graph, v1, d_v1_s0, false) - insert_edge!(graph, make_edge(d_v0_s0, s0), false) - insert_edge!(graph, make_edge(d_v1_s0, s0), false) + insert_edge!(graph, d_v0_s0, s0, false) + insert_edge!(graph, d_v1_s0, s0, false) - insert_edge!(graph, make_edge(s0, d_exit), false) + insert_edge!(graph, s0, d_exit, false) @test length(graph.nodes) == 26 @test length(graph.appliedOperations) == 0