Add propagation results and tests

This commit is contained in:
2023-11-24 16:55:22 +01:00
parent c2687cdc01
commit 62d572adbf
4 changed files with 128 additions and 76 deletions

View File

@@ -122,95 +122,101 @@ end
end
end
@testset "AB->AB large sum fusion" for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
@testset "AB->AB large sum fusion" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "AB->AB large sum fusion" for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
@testset "AB->AB large sum fusion" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "AB->AB fusion edge case" for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
@testset "AB->AB fusion edge case" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push two fusions with ComputeTaskABC_V
for _ in 1:2
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, ComputeTaskABC_V)
push_operation!(graph, fusion)
break
# push two fusions with ComputeTaskABC_V
for _ in 1:2
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, ComputeTaskABC_V)
push_operation!(graph, fusion)
break
end
end
end
end
# push fusions until the end
cont = true
while cont
cont = false
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, FusedComputeTask)
push_operation!(graph, fusion)
cont = true
break
# push fusions until the end
cont = true
while cont
cont = false
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, FusedComputeTask)
push_operation!(graph, fusion)
cont = true
break
end
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end

View File

@@ -2,6 +2,8 @@ using MetagraphOptimization
using QEDbase
import MetagraphOptimization.interaction_result
import MetagraphOptimization.propagation_result
import MetagraphOptimization.direction
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
@@ -13,7 +15,15 @@ testparticleTypes = [
AntiFermionStateful{Incoming},
AntiFermionStateful{Outgoing},
]
#testparticles = [ParticleA(def_momentum), ParticleB(def_momentum), ParticleC(def_momentum)]
testparticleTypesPropagated = [
PhotonStateful{Outgoing},
PhotonStateful{Incoming},
FermionStateful{Outgoing},
FermionStateful{Incoming},
AntiFermionStateful{Outgoing},
AntiFermionStateful{Incoming},
]
function caninteract(t1::Type{<:QEDParticle}, t2::Type{<:QEDParticle})
if (t1 == t2)
@@ -49,3 +59,10 @@ end
end
end
end
@testset "Propagation Result" begin
for (p, propResult) in zip(testparticleTypes, testparticleTypesPropagated)
@test issame(propagation_result(p), propResult)
@test direction(propagation_result(p)(def_momentum)) != direction(p(def_momentum))
end
end