Add propagation results and tests
This commit is contained in:
parent
c2687cdc01
commit
62d572adbf
@ -45,13 +45,11 @@ For valid inputs, both input particles should have the same momenta at this poin
|
||||
|
||||
12 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskQED_S2, data1::ParticleValue{P}, data2::ParticleValue{P})::ComplexF64
|
||||
#=
|
||||
@assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
|
||||
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
|
||||
@assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001, atol = sqrt(eps())) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)"
|
||||
@assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001, atol = sqrt(eps())) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)"
|
||||
=#
|
||||
function compute(
|
||||
::ComputeTaskQED_S2,
|
||||
data1::ParticleValue{P},
|
||||
data2::ParticleValue{P},
|
||||
)::ComplexF64 where {P <: QEDParticle}
|
||||
inner = QED_inner_edge(data1.p)
|
||||
return data1.v * inner * data2.v
|
||||
end
|
||||
@ -65,7 +63,8 @@ Compute inner edge (1 input particle, 1 output particle).
|
||||
"""
|
||||
function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P})::QEDParticleValue where {P <: QEDParticle}
|
||||
# TODO invert P for result (incoming becomes outgoing, outgoing becomes incoming)
|
||||
return QEDParticleValue{P}(data.p, data.v * QED_inner_edge(data.p))
|
||||
newP = propagation_result(P)
|
||||
return QEDParticleValue{newP}(newP(data.p), data.v * QED_inner_edge(data.p))
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -57,6 +57,12 @@ struct PhotonStateful{Direction <: ParticleDirection} <: QEDParticle{Direction}
|
||||
polarization::AbstractPolarization
|
||||
end
|
||||
|
||||
PhotonStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
|
||||
PhotonStateful{Direction}(mom, AllPolarization())
|
||||
|
||||
PhotonStateful{Dir1}(ph::PhotonStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} =
|
||||
PhotonStateful{Dir1}(ph.momentum, ph.polarization)
|
||||
|
||||
"""
|
||||
FermionStateful <: QEDParticle
|
||||
|
||||
@ -67,6 +73,12 @@ struct FermionStateful{Direction <: ParticleDirection} <: QEDParticle{Direction}
|
||||
spin::AbstractSpin
|
||||
end
|
||||
|
||||
FermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
|
||||
FermionStateful{Direction}(mom, AllSpin())
|
||||
|
||||
FermionStateful{Dir1}(f::FermionStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} =
|
||||
FermionStateful{Dir1}(f.momentum, f.spin)
|
||||
|
||||
"""
|
||||
AntiFermionStateful <: QEDParticle
|
||||
|
||||
@ -77,10 +89,16 @@ struct AntiFermionStateful{Direction <: ParticleDirection} <: QEDParticle{Direct
|
||||
spin::AbstractSpin
|
||||
end
|
||||
|
||||
AntiFermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
|
||||
AntiFermionStateful{Direction}(mom, AllSpin())
|
||||
|
||||
AntiFermionStateful{Dir1}(f::AntiFermionStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} =
|
||||
AntiFermionStateful{Dir1}(f.momentum, f.spin)
|
||||
|
||||
"""
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
|
||||
|
||||
For 2 given (non-equal) particle types, return the third.
|
||||
For two given particle types that can interact, return the third.
|
||||
"""
|
||||
function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
@ -108,6 +126,18 @@ function interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:PhotonStatefu
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
end
|
||||
|
||||
"""
|
||||
propagation_result(t1::Type{T}) where {T <: QEDParticle}
|
||||
|
||||
Return the type of the inverted direction. E.g.
|
||||
"""
|
||||
propagation_result(::Type{FermionStateful{Incoming}}) = FermionStateful{Outgoing}
|
||||
propagation_result(::Type{FermionStateful{Outgoing}}) = FermionStateful{Incoming}
|
||||
propagation_result(::Type{AntiFermionStateful{Incoming}}) = AntiFermionStateful{Outgoing}
|
||||
propagation_result(::Type{AntiFermionStateful{Outgoing}}) = AntiFermionStateful{Incoming}
|
||||
propagation_result(::Type{PhotonStateful{Incoming}}) = PhotonStateful{Outgoing}
|
||||
propagation_result(::Type{PhotonStateful{Outgoing}}) = PhotonStateful{Incoming}
|
||||
|
||||
"""
|
||||
types(::QEDModel)
|
||||
|
||||
@ -143,9 +173,9 @@ end
|
||||
@inline spin_or_pol(p::FermionStateful)::AbstractSpin = p.spin
|
||||
@inline spin_or_pol(p::AntiFermionStateful)::AbstractSpin = p.spin
|
||||
|
||||
@inline direction(::PhotonStateful{Dir})::ParticleDirection = Dir()
|
||||
@inline direction(::FermionStateful{Dir})::ParticleDirection = Dir()
|
||||
@inline direction(::AntiFermionStateful{Dir})::ParticleDirection = Dir()
|
||||
@inline direction(::PhotonStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
|
||||
@inline direction(::FermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
|
||||
@inline direction(::AntiFermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
|
||||
|
||||
"""
|
||||
QED_vertex()
|
||||
|
@ -122,95 +122,101 @@ end
|
||||
end
|
||||
end
|
||||
|
||||
@testset "AB->AB large sum fusion" for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
@testset "AB->AB large sum fusion" begin
|
||||
for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@testset "AB->AB large sum fusion" for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
@testset "AB->AB large sum fusion" begin
|
||||
for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "AB->AB fusion edge case" for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
@testset "AB->AB fusion edge case" begin
|
||||
for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push two fusions with ComputeTaskABC_V
|
||||
for _ in 1:2
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, ComputeTaskABC_V)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
# push two fusions with ComputeTaskABC_V
|
||||
for _ in 1:2
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, ComputeTaskABC_V)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# push fusions until the end
|
||||
cont = true
|
||||
while cont
|
||||
cont = false
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
cont = true
|
||||
break
|
||||
# push fusions until the end
|
||||
cont = true
|
||||
while cont
|
||||
cont = false
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
cont = true
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
end
|
||||
end
|
||||
|
@ -2,6 +2,8 @@ using MetagraphOptimization
|
||||
using QEDbase
|
||||
|
||||
import MetagraphOptimization.interaction_result
|
||||
import MetagraphOptimization.propagation_result
|
||||
import MetagraphOptimization.direction
|
||||
|
||||
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
|
||||
|
||||
@ -13,7 +15,15 @@ testparticleTypes = [
|
||||
AntiFermionStateful{Incoming},
|
||||
AntiFermionStateful{Outgoing},
|
||||
]
|
||||
#testparticles = [ParticleA(def_momentum), ParticleB(def_momentum), ParticleC(def_momentum)]
|
||||
|
||||
testparticleTypesPropagated = [
|
||||
PhotonStateful{Outgoing},
|
||||
PhotonStateful{Incoming},
|
||||
FermionStateful{Outgoing},
|
||||
FermionStateful{Incoming},
|
||||
AntiFermionStateful{Outgoing},
|
||||
AntiFermionStateful{Incoming},
|
||||
]
|
||||
|
||||
function caninteract(t1::Type{<:QEDParticle}, t2::Type{<:QEDParticle})
|
||||
if (t1 == t2)
|
||||
@ -49,3 +59,10 @@ end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Propagation Result" begin
|
||||
for (p, propResult) in zip(testparticleTypes, testparticleTypesPropagated)
|
||||
@test issame(propagation_result(p), propResult)
|
||||
@test direction(propagation_result(p)(def_momentum)) != direction(p(def_momentum))
|
||||
end
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user