Actually fix the rare execution error this time

This commit is contained in:
Anton Reinhard 2023-10-10 21:49:31 +02:00
parent 140a954d01
commit 3267daadfd
4 changed files with 107 additions and 9 deletions

@ -5,7 +5,7 @@ A named tuple representing a difference of added and removed nodes and edges on
"""
const Diff = NamedTuple{
(:addedNodes, :removedNodes, :addedEdges, :removedEdges, :updatedChildren),
Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, Symbol, Symbol}}},
Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, AbstractTask}}},
}
function Diff()
@ -15,7 +15,7 @@ function Diff()
addedEdges = Vector{Edge}(),
removedEdges = Vector{Edge}(),
# children were updated from updatedChildren[2] to updatedChildren[3] in node updatedChildren[1]
updatedChildren = Vector{Tuple{Node, Symbol, Symbol}}(),
# children were updated in the task, updatedChildren[x][2] is the task before the update
updatedChildren = Vector{Tuple{Node, AbstractTask}}(),
)::Diff
end

@ -189,6 +189,8 @@ function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::S
return nothing
end
taskBefore = copy(n.task)
if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
println("------------------ Nothing to replace!! ------------------")
child_ids = Vector{String}()
@ -213,7 +215,7 @@ function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::S
# keep track
if (track)
push!(graph.diff.updatedChildren, (n, child_before, child_after))
push!(graph.diff.updatedChildren, (n, taskBefore))
end
end

@ -132,11 +132,11 @@ function revert_diff!(graph::DAG, diff::Diff)
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for (node, before, after) in diff.updatedChildren
for (node, task) in diff.updatedChildren
# node must be fused compute task at this point
@assert typeof(node.task) <: FusedComputeTask
update_child!(graph, node, after, before, track = false)
node.task = task
end
graph.properties -= GraphProperties(diff)
@ -234,6 +234,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
# set of the new parents of n1
newParents = Set{Node}()
# names of the previous children that n1 now replaces per parent
newParentsChildNames = Dict{Node, Symbol}()
@ -255,12 +256,14 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
remove_node!(graph, n)
end
setdiff!(newParents, n1Parents)
for parent in newParents
# now add parents of all input nodes to n1 without duplicates
insert_edge!(graph, n1, parent)
if !(parent in n1Parents)
# don't double insert edges
insert_edge!(graph, n1, parent)
end
# this has to be done for all parents, even the ones of n1 because they can be duplicate
prevChild = newParentsChildNames[parent]
update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
end

@ -73,5 +73,98 @@ include("../examples/profiling_utilities.jl")
end
end
@testset "AB->AB large sum fusion" for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskSum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = 0.00013916495566048735
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
end
@testset "AB->AB large sum fusion" for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskSum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = 0.00013916495566048735
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
end
@testset "AB->AB fusion edge case" for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push two fusions with ComputeTaskV
for _ in 1:2
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, ComputeTaskV)
push_operation!(graph, fusion)
break
end
end
end
# push fusions until the end
cont = true
while cont
cont = false
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, FusedComputeTask)
push_operation!(graph, fusion)
cont = true
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = 0.00013916495566048735
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
end
end
println("Execution Unit Tests Complete!")