Add propagation results and tests

This commit is contained in:
Anton Reinhard 2023-11-24 16:55:22 +01:00
parent c2687cdc01
commit 62d572adbf
4 changed files with 128 additions and 76 deletions

View File

@ -45,13 +45,11 @@ For valid inputs, both input particles should have the same momenta at this poin
12 FLOP.
"""
function compute(::ComputeTaskQED_S2, data1::ParticleValue{P}, data2::ParticleValue{P})::ComplexF64
#=
@assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
@assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001, atol = sqrt(eps())) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)"
@assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001, atol = sqrt(eps())) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)"
=#
function compute(
::ComputeTaskQED_S2,
data1::ParticleValue{P},
data2::ParticleValue{P},
)::ComplexF64 where {P <: QEDParticle}
inner = QED_inner_edge(data1.p)
return data1.v * inner * data2.v
end
@ -65,7 +63,8 @@ Compute inner edge (1 input particle, 1 output particle).
"""
function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P})::QEDParticleValue where {P <: QEDParticle}
# TODO invert P for result (incoming becomes outgoing, outgoing becomes incoming)
return QEDParticleValue{P}(data.p, data.v * QED_inner_edge(data.p))
newP = propagation_result(P)
return QEDParticleValue{newP}(newP(data.p), data.v * QED_inner_edge(data.p))
end
"""

View File

@ -57,6 +57,12 @@ struct PhotonStateful{Direction <: ParticleDirection} <: QEDParticle{Direction}
polarization::AbstractPolarization
end
PhotonStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
PhotonStateful{Direction}(mom, AllPolarization())
PhotonStateful{Dir1}(ph::PhotonStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} =
PhotonStateful{Dir1}(ph.momentum, ph.polarization)
"""
FermionStateful <: QEDParticle
@ -67,6 +73,12 @@ struct FermionStateful{Direction <: ParticleDirection} <: QEDParticle{Direction}
spin::AbstractSpin
end
FermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
FermionStateful{Direction}(mom, AllSpin())
FermionStateful{Dir1}(f::FermionStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} =
FermionStateful{Dir1}(f.momentum, f.spin)
"""
AntiFermionStateful <: QEDParticle
@ -77,10 +89,16 @@ struct AntiFermionStateful{Direction <: ParticleDirection} <: QEDParticle{Direct
spin::AbstractSpin
end
AntiFermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
AntiFermionStateful{Direction}(mom, AllSpin())
AntiFermionStateful{Dir1}(f::AntiFermionStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} =
AntiFermionStateful{Dir1}(f.momentum, f.spin)
"""
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
For 2 given (non-equal) particle types, return the third.
For two given particle types that can interact, return the third.
"""
function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
@assert false "Invalid interaction between particles of types $t1 and $t2"
@ -108,6 +126,18 @@ function interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:PhotonStatefu
@assert false "Invalid interaction between particles of types $t1 and $t2"
end
"""
propagation_result(t1::Type{T}) where {T <: QEDParticle}
Return the type of the inverted direction. E.g.
"""
propagation_result(::Type{FermionStateful{Incoming}}) = FermionStateful{Outgoing}
propagation_result(::Type{FermionStateful{Outgoing}}) = FermionStateful{Incoming}
propagation_result(::Type{AntiFermionStateful{Incoming}}) = AntiFermionStateful{Outgoing}
propagation_result(::Type{AntiFermionStateful{Outgoing}}) = AntiFermionStateful{Incoming}
propagation_result(::Type{PhotonStateful{Incoming}}) = PhotonStateful{Outgoing}
propagation_result(::Type{PhotonStateful{Outgoing}}) = PhotonStateful{Incoming}
"""
types(::QEDModel)
@ -143,9 +173,9 @@ end
@inline spin_or_pol(p::FermionStateful)::AbstractSpin = p.spin
@inline spin_or_pol(p::AntiFermionStateful)::AbstractSpin = p.spin
@inline direction(::PhotonStateful{Dir})::ParticleDirection = Dir()
@inline direction(::FermionStateful{Dir})::ParticleDirection = Dir()
@inline direction(::AntiFermionStateful{Dir})::ParticleDirection = Dir()
@inline direction(::PhotonStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::FermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::AntiFermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
"""
QED_vertex()

View File

@ -122,95 +122,101 @@ end
end
end
@testset "AB->AB large sum fusion" for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
@testset "AB->AB large sum fusion" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "AB->AB large sum fusion" for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
@testset "AB->AB large sum fusion" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "AB->AB fusion edge case" for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
@testset "AB->AB fusion edge case" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push two fusions with ComputeTaskABC_V
for _ in 1:2
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, ComputeTaskABC_V)
push_operation!(graph, fusion)
break
# push two fusions with ComputeTaskABC_V
for _ in 1:2
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, ComputeTaskABC_V)
push_operation!(graph, fusion)
break
end
end
end
end
# push fusions until the end
cont = true
while cont
cont = false
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, FusedComputeTask)
push_operation!(graph, fusion)
cont = true
break
# push fusions until the end
cont = true
while cont
cont = false
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, FusedComputeTask)
push_operation!(graph, fusion)
cont = true
break
end
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end

View File

@ -2,6 +2,8 @@ using MetagraphOptimization
using QEDbase
import MetagraphOptimization.interaction_result
import MetagraphOptimization.propagation_result
import MetagraphOptimization.direction
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
@ -13,7 +15,15 @@ testparticleTypes = [
AntiFermionStateful{Incoming},
AntiFermionStateful{Outgoing},
]
#testparticles = [ParticleA(def_momentum), ParticleB(def_momentum), ParticleC(def_momentum)]
testparticleTypesPropagated = [
PhotonStateful{Outgoing},
PhotonStateful{Incoming},
FermionStateful{Outgoing},
FermionStateful{Incoming},
AntiFermionStateful{Outgoing},
AntiFermionStateful{Incoming},
]
function caninteract(t1::Type{<:QEDParticle}, t2::Type{<:QEDParticle})
if (t1 == t2)
@ -49,3 +59,10 @@ end
end
end
end
@testset "Propagation Result" begin
for (p, propResult) in zip(testparticleTypes, testparticleTypesPropagated)
@test issame(propagation_result(p), propResult)
@test direction(propagation_result(p)(def_momentum)) != direction(p(def_momentum))
end
end