From 3267daadfd1ae9ff9f68fbce7c51dee6f2cb1ddd Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Tue, 10 Oct 2023 21:49:31 +0200 Subject: [PATCH] Actually fix the rare execution error this time --- src/diff/type.jl | 6 +-- src/graph/mute.jl | 4 +- src/operation/apply.jl | 13 +++-- test/unit_tests_execution.jl | 93 ++++++++++++++++++++++++++++++++++++ 4 files changed, 107 insertions(+), 9 deletions(-) diff --git a/src/diff/type.jl b/src/diff/type.jl index 86f9ec8..be6d8b9 100644 --- a/src/diff/type.jl +++ b/src/diff/type.jl @@ -5,7 +5,7 @@ A named tuple representing a difference of added and removed nodes and edges on """ const Diff = NamedTuple{ (:addedNodes, :removedNodes, :addedEdges, :removedEdges, :updatedChildren), - Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, Symbol, Symbol}}}, + Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, AbstractTask}}}, } function Diff() @@ -15,7 +15,7 @@ function Diff() addedEdges = Vector{Edge}(), removedEdges = Vector{Edge}(), - # children were updated from updatedChildren[2] to updatedChildren[3] in node updatedChildren[1] - updatedChildren = Vector{Tuple{Node, Symbol, Symbol}}(), + # children were updated in the task, updatedChildren[x][2] is the task before the update + updatedChildren = Vector{Tuple{Node, AbstractTask}}(), )::Diff end diff --git a/src/graph/mute.jl b/src/graph/mute.jl index 9a01866..d23611f 100644 --- a/src/graph/mute.jl +++ b/src/graph/mute.jl @@ -189,6 +189,8 @@ function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::S return nothing end + taskBefore = copy(n.task) + if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs)) println("------------------ Nothing to replace!! ------------------") child_ids = Vector{String}() @@ -213,7 +215,7 @@ function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::S # keep track if (track) - push!(graph.diff.updatedChildren, (n, child_before, child_after)) + push!(graph.diff.updatedChildren, (n, taskBefore)) end end diff --git a/src/operation/apply.jl b/src/operation/apply.jl index a2b6db0..dfe9a0b 100644 --- a/src/operation/apply.jl +++ b/src/operation/apply.jl @@ -132,11 +132,11 @@ function revert_diff!(graph::DAG, diff::Diff) insert_edge!(graph, edge.edge[1], edge.edge[2], track = false) end - for (node, before, after) in diff.updatedChildren + for (node, task) in diff.updatedChildren # node must be fused compute task at this point @assert typeof(node.task) <: FusedComputeTask - update_child!(graph, node, after, before, track = false) + node.task = task end graph.properties -= GraphProperties(diff) @@ -234,6 +234,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node}) # set of the new parents of n1 newParents = Set{Node}() + # names of the previous children that n1 now replaces per parent newParentsChildNames = Dict{Node, Symbol}() @@ -255,12 +256,14 @@ function node_reduction!(graph::DAG, nodes::Vector{Node}) remove_node!(graph, n) end - setdiff!(newParents, n1Parents) - for parent in newParents # now add parents of all input nodes to n1 without duplicates - insert_edge!(graph, n1, parent) + if !(parent in n1Parents) + # don't double insert edges + insert_edge!(graph, n1, parent) + end + # this has to be done for all parents, even the ones of n1 because they can be duplicate prevChild = newParentsChildNames[parent] update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id))) end diff --git a/test/unit_tests_execution.jl b/test/unit_tests_execution.jl index 3f5603d..37897e2 100644 --- a/test/unit_tests_execution.jl +++ b/test/unit_tests_execution.jl @@ -73,5 +73,98 @@ include("../examples/profiling_utilities.jl") end end + @testset "AB->AB large sum fusion" for _ in 1:20 + graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel()) + + # push a fusion with the sum node + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[3].task, ComputeTaskSum) + push_operation!(graph, fusion) + break + end + end + + # push two more fusions with the fused node + for _ in 1:15 + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[3].task, FusedComputeTask) + push_operation!(graph, fusion) + break + end + end + end + + # try execute + @test is_valid(graph) + expected_result = 0.00013916495566048735 + @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001) + end + + + @testset "AB->AB large sum fusion" for _ in 1:20 + graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel()) + + # push a fusion with the sum node + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[3].task, ComputeTaskSum) + push_operation!(graph, fusion) + break + end + end + + # push two more fusions with the fused node + for _ in 1:15 + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[3].task, FusedComputeTask) + push_operation!(graph, fusion) + break + end + end + end + + # try execute + @test is_valid(graph) + expected_result = 0.00013916495566048735 + @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001) + end + + @testset "AB->AB fusion edge case" for _ in 1:20 + graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel()) + + # push two fusions with ComputeTaskV + for _ in 1:2 + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[1].task, ComputeTaskV) + push_operation!(graph, fusion) + break + end + end + end + + # push fusions until the end + cont = true + while cont + cont = false + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[1].task, FusedComputeTask) + push_operation!(graph, fusion) + cont = true + break + end + end + end + + # try execute + @test is_valid(graph) + expected_result = 0.00013916495566048735 + @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001) + end + end println("Execution Unit Tests Complete!")