Actually fix the rare execution error this time
This commit is contained in:
parent
140a954d01
commit
3267daadfd
@ -5,7 +5,7 @@ A named tuple representing a difference of added and removed nodes and edges on
|
||||
"""
|
||||
const Diff = NamedTuple{
|
||||
(:addedNodes, :removedNodes, :addedEdges, :removedEdges, :updatedChildren),
|
||||
Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, Symbol, Symbol}}},
|
||||
Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, AbstractTask}}},
|
||||
}
|
||||
|
||||
function Diff()
|
||||
@ -15,7 +15,7 @@ function Diff()
|
||||
addedEdges = Vector{Edge}(),
|
||||
removedEdges = Vector{Edge}(),
|
||||
|
||||
# children were updated from updatedChildren[2] to updatedChildren[3] in node updatedChildren[1]
|
||||
updatedChildren = Vector{Tuple{Node, Symbol, Symbol}}(),
|
||||
# children were updated in the task, updatedChildren[x][2] is the task before the update
|
||||
updatedChildren = Vector{Tuple{Node, AbstractTask}}(),
|
||||
)::Diff
|
||||
end
|
||||
|
@ -189,6 +189,8 @@ function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::S
|
||||
return nothing
|
||||
end
|
||||
|
||||
taskBefore = copy(n.task)
|
||||
|
||||
if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
|
||||
println("------------------ Nothing to replace!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
@ -213,7 +215,7 @@ function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::S
|
||||
|
||||
# keep track
|
||||
if (track)
|
||||
push!(graph.diff.updatedChildren, (n, child_before, child_after))
|
||||
push!(graph.diff.updatedChildren, (n, taskBefore))
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -132,11 +132,11 @@ function revert_diff!(graph::DAG, diff::Diff)
|
||||
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
|
||||
end
|
||||
|
||||
for (node, before, after) in diff.updatedChildren
|
||||
for (node, task) in diff.updatedChildren
|
||||
# node must be fused compute task at this point
|
||||
@assert typeof(node.task) <: FusedComputeTask
|
||||
|
||||
update_child!(graph, node, after, before, track = false)
|
||||
node.task = task
|
||||
end
|
||||
|
||||
graph.properties -= GraphProperties(diff)
|
||||
@ -234,6 +234,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
# set of the new parents of n1
|
||||
newParents = Set{Node}()
|
||||
|
||||
# names of the previous children that n1 now replaces per parent
|
||||
newParentsChildNames = Dict{Node, Symbol}()
|
||||
|
||||
@ -255,12 +256,14 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
remove_node!(graph, n)
|
||||
end
|
||||
|
||||
setdiff!(newParents, n1Parents)
|
||||
|
||||
for parent in newParents
|
||||
# now add parents of all input nodes to n1 without duplicates
|
||||
insert_edge!(graph, n1, parent)
|
||||
if !(parent in n1Parents)
|
||||
# don't double insert edges
|
||||
insert_edge!(graph, n1, parent)
|
||||
end
|
||||
|
||||
# this has to be done for all parents, even the ones of n1 because they can be duplicate
|
||||
prevChild = newParentsChildNames[parent]
|
||||
update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
|
||||
end
|
||||
|
@ -73,5 +73,98 @@ include("../examples/profiling_utilities.jl")
|
||||
end
|
||||
end
|
||||
|
||||
@testset "AB->AB large sum fusion" for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, ComputeTaskSum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = 0.00013916495566048735
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
end
|
||||
|
||||
|
||||
@testset "AB->AB large sum fusion" for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, ComputeTaskSum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = 0.00013916495566048735
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
end
|
||||
|
||||
@testset "AB->AB fusion edge case" for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push two fusions with ComputeTaskV
|
||||
for _ in 1:2
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, ComputeTaskV)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# push fusions until the end
|
||||
cont = true
|
||||
while cont
|
||||
cont = false
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
cont = true
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = 0.00013916495566048735
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
end
|
||||
|
||||
end
|
||||
println("Execution Unit Tests Complete!")
|
||||
|
Loading…
x
Reference in New Issue
Block a user