From 62d572adbfdf0d708d2835251a2d35b59c130e12 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Fri, 24 Nov 2023 16:55:22 +0100 Subject: [PATCH] Add propagation results and tests --- src/models/qed/compute.jl | 15 ++-- src/models/qed/particle.jl | 38 ++++++++-- test/unit_tests_execution.jl | 132 ++++++++++++++++++----------------- test/unit_tests_qedmodel.jl | 19 ++++- 4 files changed, 128 insertions(+), 76 deletions(-) diff --git a/src/models/qed/compute.jl b/src/models/qed/compute.jl index 844e3f1..48151f0 100644 --- a/src/models/qed/compute.jl +++ b/src/models/qed/compute.jl @@ -45,13 +45,11 @@ For valid inputs, both input particles should have the same momenta at this poin 12 FLOP. """ -function compute(::ComputeTaskQED_S2, data1::ParticleValue{P}, data2::ParticleValue{P})::ComplexF64 - #= - @assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)" - @assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)" - @assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001, atol = sqrt(eps())) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)" - @assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001, atol = sqrt(eps())) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)" - =# +function compute( + ::ComputeTaskQED_S2, + data1::ParticleValue{P}, + data2::ParticleValue{P}, +)::ComplexF64 where {P <: QEDParticle} inner = QED_inner_edge(data1.p) return data1.v * inner * data2.v end @@ -65,7 +63,8 @@ Compute inner edge (1 input particle, 1 output particle). """ function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P})::QEDParticleValue where {P <: QEDParticle} # TODO invert P for result (incoming becomes outgoing, outgoing becomes incoming) - return QEDParticleValue{P}(data.p, data.v * QED_inner_edge(data.p)) + newP = propagation_result(P) + return QEDParticleValue{newP}(newP(data.p), data.v * QED_inner_edge(data.p)) end """ diff --git a/src/models/qed/particle.jl b/src/models/qed/particle.jl index 673e649..4e8a18d 100644 --- a/src/models/qed/particle.jl +++ b/src/models/qed/particle.jl @@ -57,6 +57,12 @@ struct PhotonStateful{Direction <: ParticleDirection} <: QEDParticle{Direction} polarization::AbstractPolarization end +PhotonStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} = + PhotonStateful{Direction}(mom, AllPolarization()) + +PhotonStateful{Dir1}(ph::PhotonStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} = + PhotonStateful{Dir1}(ph.momentum, ph.polarization) + """ FermionStateful <: QEDParticle @@ -67,6 +73,12 @@ struct FermionStateful{Direction <: ParticleDirection} <: QEDParticle{Direction} spin::AbstractSpin end +FermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} = + FermionStateful{Direction}(mom, AllSpin()) + +FermionStateful{Dir1}(f::FermionStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} = + FermionStateful{Dir1}(f.momentum, f.spin) + """ AntiFermionStateful <: QEDParticle @@ -77,10 +89,16 @@ struct AntiFermionStateful{Direction <: ParticleDirection} <: QEDParticle{Direct spin::AbstractSpin end +AntiFermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} = + AntiFermionStateful{Direction}(mom, AllSpin()) + +AntiFermionStateful{Dir1}(f::AntiFermionStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} = + AntiFermionStateful{Dir1}(f.momentum, f.spin) + """ interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle} -For 2 given (non-equal) particle types, return the third. +For two given particle types that can interact, return the third. """ function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle} @assert false "Invalid interaction between particles of types $t1 and $t2" @@ -108,6 +126,18 @@ function interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:PhotonStatefu @assert false "Invalid interaction between particles of types $t1 and $t2" end +""" + propagation_result(t1::Type{T}) where {T <: QEDParticle} + +Return the type of the inverted direction. E.g. +""" +propagation_result(::Type{FermionStateful{Incoming}}) = FermionStateful{Outgoing} +propagation_result(::Type{FermionStateful{Outgoing}}) = FermionStateful{Incoming} +propagation_result(::Type{AntiFermionStateful{Incoming}}) = AntiFermionStateful{Outgoing} +propagation_result(::Type{AntiFermionStateful{Outgoing}}) = AntiFermionStateful{Incoming} +propagation_result(::Type{PhotonStateful{Incoming}}) = PhotonStateful{Outgoing} +propagation_result(::Type{PhotonStateful{Outgoing}}) = PhotonStateful{Incoming} + """ types(::QEDModel) @@ -143,9 +173,9 @@ end @inline spin_or_pol(p::FermionStateful)::AbstractSpin = p.spin @inline spin_or_pol(p::AntiFermionStateful)::AbstractSpin = p.spin -@inline direction(::PhotonStateful{Dir})::ParticleDirection = Dir() -@inline direction(::FermionStateful{Dir})::ParticleDirection = Dir() -@inline direction(::AntiFermionStateful{Dir})::ParticleDirection = Dir() +@inline direction(::PhotonStateful{Dir}) where {Dir <: ParticleDirection} = Dir() +@inline direction(::FermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir() +@inline direction(::AntiFermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir() """ QED_vertex() diff --git a/test/unit_tests_execution.jl b/test/unit_tests_execution.jl index c2046d1..dea4072 100644 --- a/test/unit_tests_execution.jl +++ b/test/unit_tests_execution.jl @@ -122,95 +122,101 @@ end end end -@testset "AB->AB large sum fusion" for _ in 1:20 - graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel()) +@testset "AB->AB large sum fusion" begin + for _ in 1:20 + graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel()) - # push a fusion with the sum node - ops = get_operations(graph) - for fusion in ops.nodeFusions - if isa(fusion.input[3].task, ComputeTaskABC_Sum) - push_operation!(graph, fusion) - break - end - end - - # push two more fusions with the fused node - for _ in 1:15 + # push a fusion with the sum node ops = get_operations(graph) for fusion in ops.nodeFusions - if isa(fusion.input[3].task, FusedComputeTask) + if isa(fusion.input[3].task, ComputeTaskABC_Sum) push_operation!(graph, fusion) break end end - end - # try execute - @test is_valid(graph) - expected_result = ground_truth_graph_result(particles_2_2) - @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL) + # push two more fusions with the fused node + for _ in 1:15 + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[3].task, FusedComputeTask) + push_operation!(graph, fusion) + break + end + end + end + + # try execute + @test is_valid(graph) + expected_result = ground_truth_graph_result(particles_2_2) + @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL) + end end -@testset "AB->AB large sum fusion" for _ in 1:20 - graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel()) +@testset "AB->AB large sum fusion" begin + for _ in 1:20 + graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel()) - # push a fusion with the sum node - ops = get_operations(graph) - for fusion in ops.nodeFusions - if isa(fusion.input[3].task, ComputeTaskABC_Sum) - push_operation!(graph, fusion) - break - end - end - - # push two more fusions with the fused node - for _ in 1:15 + # push a fusion with the sum node ops = get_operations(graph) for fusion in ops.nodeFusions - if isa(fusion.input[3].task, FusedComputeTask) + if isa(fusion.input[3].task, ComputeTaskABC_Sum) push_operation!(graph, fusion) break end end - end - # try execute - @test is_valid(graph) - expected_result = ground_truth_graph_result(particles_2_2) - @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL) + # push two more fusions with the fused node + for _ in 1:15 + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[3].task, FusedComputeTask) + push_operation!(graph, fusion) + break + end + end + end + + # try execute + @test is_valid(graph) + expected_result = ground_truth_graph_result(particles_2_2) + @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL) + end end -@testset "AB->AB fusion edge case" for _ in 1:20 - graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel()) +@testset "AB->AB fusion edge case" begin + for _ in 1:20 + graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel()) - # push two fusions with ComputeTaskABC_V - for _ in 1:2 - ops = get_operations(graph) - for fusion in ops.nodeFusions - if isa(fusion.input[1].task, ComputeTaskABC_V) - push_operation!(graph, fusion) - break + # push two fusions with ComputeTaskABC_V + for _ in 1:2 + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[1].task, ComputeTaskABC_V) + push_operation!(graph, fusion) + break + end end end - end - # push fusions until the end - cont = true - while cont - cont = false - ops = get_operations(graph) - for fusion in ops.nodeFusions - if isa(fusion.input[1].task, FusedComputeTask) - push_operation!(graph, fusion) - cont = true - break + # push fusions until the end + cont = true + while cont + cont = false + ops = get_operations(graph) + for fusion in ops.nodeFusions + if isa(fusion.input[1].task, FusedComputeTask) + push_operation!(graph, fusion) + cont = true + break + end end end - end - # try execute - @test is_valid(graph) - expected_result = ground_truth_graph_result(particles_2_2) - @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL) + # try execute + @test is_valid(graph) + expected_result = ground_truth_graph_result(particles_2_2) + @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL) + end end diff --git a/test/unit_tests_qedmodel.jl b/test/unit_tests_qedmodel.jl index ead6388..36e5f41 100644 --- a/test/unit_tests_qedmodel.jl +++ b/test/unit_tests_qedmodel.jl @@ -2,6 +2,8 @@ using MetagraphOptimization using QEDbase import MetagraphOptimization.interaction_result +import MetagraphOptimization.propagation_result +import MetagraphOptimization.direction def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0) @@ -13,7 +15,15 @@ testparticleTypes = [ AntiFermionStateful{Incoming}, AntiFermionStateful{Outgoing}, ] -#testparticles = [ParticleA(def_momentum), ParticleB(def_momentum), ParticleC(def_momentum)] + +testparticleTypesPropagated = [ + PhotonStateful{Outgoing}, + PhotonStateful{Incoming}, + FermionStateful{Outgoing}, + FermionStateful{Incoming}, + AntiFermionStateful{Outgoing}, + AntiFermionStateful{Incoming}, +] function caninteract(t1::Type{<:QEDParticle}, t2::Type{<:QEDParticle}) if (t1 == t2) @@ -49,3 +59,10 @@ end end end end + +@testset "Propagation Result" begin + for (p, propResult) in zip(testparticleTypes, testparticleTypesPropagated) + @test issame(propagation_result(p), propResult) + @test direction(propagation_result(p)(def_momentum)) != direction(p(def_momentum)) + end +end