Add momentum conservation tests; Debug result against groundtruth

This commit is contained in:
2023-11-28 15:50:18 +01:00
parent aa18430d29
commit ba0c75c8dc
8 changed files with 149 additions and 17 deletions

View File

@@ -1,6 +1,7 @@
[deps]
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

View File

@@ -1,5 +1,5 @@
using SafeTestsets
#=
@safetestset "Utility Unit Tests" begin
include("unit_tests_utility.jl")
end
@@ -17,10 +17,10 @@ end
end
@safetestset "ABC-Model Unit Tests" begin
include("unit_tests_abcmodel.jl")
end
end=#
@safetestset "QED-Model Unit Tests" begin
include("unit_tests_qedmodel.jl")
end
end#=
@safetestset "Node Reduction Unit Tests" begin
include("node_reduction.jl")
end
@@ -36,3 +36,4 @@ end
@safetestset "Known Graph Tests" begin
include("known_graphs.jl")
end
=#

View File

@@ -1,15 +1,21 @@
using MetagraphOptimization
using QEDbase
using QEDprocesses
using StatsBase # for countmap
using Random
import MetagraphOptimization.caninteract
import MetagraphOptimization.issame
import MetagraphOptimization.interaction_result
import MetagraphOptimization.propagation_result
import MetagraphOptimization.direction
import MetagraphOptimization.spin_or_pol
import MetagraphOptimization.QED_vertex
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
RNG = Random.default_rng()
testparticleTypes = [
PhotonStateful{Incoming},
PhotonStateful{Outgoing},
@@ -28,14 +34,69 @@ testparticleTypesPropagated = [
AntiFermionStateful{Incoming},
]
function compton_groundtruth(input::QEDProcessInput)
# p1k1 -> p2k2
# formula: (ie)^2 (u(p2) slashed(ε1) S(p2 k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
p1 = input.inParticles[findfirst(x -> typeof(x) <: FermionStateful, input.inParticles)]
p2 = input.outParticles[findfirst(x -> typeof(x) <: FermionStateful, input.outParticles)]
k1 = input.inParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.inParticles)]
k2 = input.outParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.outParticles)]
u_p1 = base_state(Electron(), Incoming(), p1.momentum, spin_or_pol(p1))
u_p2 = base_state(Electron(), Outgoing(), p2.momentum, spin_or_pol(p2))
eps_slashed_1 = base_state(Photon(), Incoming(), k1.momentum, spin_or_pol(k1))
eps_slashed_2 = base_state(Photon(), Outgoing(), k2.momentum, spin_or_pol(k2))
virt1_mom = p2.momentum - k1.momentum
virt2_mom = p1.momentum + k1.momentum
println("Groundtruth virtual particle (p2 - k1): $(virt1_mom)")
println("Groundtruth virtual particle (p1 + k1): $(virt2_mom)")
s_p2_k1 = propagator(Electron(), virt1_mom)
s_p1_k1 = propagator(Electron(), virt2_mom)
diagram1 = u_p2 * eps_slashed_1 * QED_vertex() * s_p2_k1 * QED_vertex() * eps_slashed_2 * u_p1
diagram2 = u_p2 * eps_slashed_2 * QED_vertex() * s_p1_k1 * QED_vertex() * eps_slashed_1 * u_p1
return (diagram1 + diagram2)
end
@testset "Interaction Result" begin
import MetagraphOptimization.QED_conserve_momentum
for p1 in testparticleTypes, p2 in testparticleTypes
if !caninteract(p1, p2)
@test_throws AssertionError interaction_result(p1, p2)
else
@test interaction_result(p1, p2) in setdiff(testparticleTypes, [p1, p2])
@test issame(interaction_result(p1, p2), interaction_result(p2, p1))
continue
end
@test interaction_result(p1, p2) in setdiff(testparticleTypes, [p1, p2])
@test issame(interaction_result(p1, p2), interaction_result(p2, p1))
testParticle1 = p1(rand(RNG, SFourMomentum))
testParticle2 = p2(rand(RNG, SFourMomentum))
p3 = interaction_result(p1, p2)
resultParticle = QED_conserve_momentum(testParticle1, testParticle2)
@test issame(typeof(resultParticle), interaction_result(p1, p2))
totalMom = zero(SFourMomentum)
for (p, mom) in [(p1, testParticle1.momentum), (p2, testParticle2.momentum), (p3, resultParticle.momentum)]
if (typeof(direction(p)) <: Incoming)
totalMom += mom
else
totalMom -= mom
end
end
@test isapprox(totalMom, zero(SFourMomentum); atol = sqrt(eps()))
end
end
@@ -129,23 +190,39 @@ end
# s to output (exit node)
d_exit = insert_node!(graph, make_node(DataTask(16)), track = false)
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(2)), track = false)
d_s0_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
d_s1_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
# final s compute
s0 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
s1 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
# data from v0 and v1 to s0
d_v0_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v1_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v2_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v3_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
# v0 and v1 compute
v0 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v1 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v2 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v3 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
# data from uPhIn, uPhOut, uElIn, uElOut
# data from uPhIn, uPhOut, uElIn, uElOut to v0 and v1
d_uPhIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
# data from uPhIn, uPhOut, uElIn, uElOut to v2 and v3
d_uPhOut_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElIn_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhIn_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElOut_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
# uPhIn, uPhOut, uElIn and uElOut computes
uPhIn = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uPhOut = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
@@ -169,18 +246,45 @@ end
insert_edge!(graph, uElIn, d_uElIn_v0, track = false)
insert_edge!(graph, uElOut, d_uElOut_v1, track = false)
insert_edge!(graph, uPhIn, d_uPhIn_v3, track = false)
insert_edge!(graph, uPhOut, d_uPhOut_v2, track = false)
insert_edge!(graph, uElIn, d_uElIn_v2, track = false)
insert_edge!(graph, uElOut, d_uElOut_v3, track = false)
insert_edge!(graph, d_uPhIn_v0, v0, track = false)
insert_edge!(graph, d_uPhOut_v1, v1, track = false)
insert_edge!(graph, d_uElIn_v0, v0, track = false)
insert_edge!(graph, d_uElOut_v1, v1, track = false)
insert_edge!(graph, d_uPhIn_v3, v3, track = false)
insert_edge!(graph, d_uPhOut_v2, v2, track = false)
insert_edge!(graph, d_uElIn_v2, v2, track = false)
insert_edge!(graph, d_uElOut_v3, v3, track = false)
insert_edge!(graph, v0, d_v0_s0, track = false)
insert_edge!(graph, v1, d_v1_s0, track = false)
insert_edge!(graph, v2, d_v2_s1, track = false)
insert_edge!(graph, v3, d_v3_s1, track = false)
insert_edge!(graph, d_v0_s0, s0, track = false)
insert_edge!(graph, d_v1_s0, s0, track = false)
insert_edge!(graph, s0, d_exit, track = false)
insert_edge!(graph, d_v2_s1, s1, track = false)
insert_edge!(graph, d_v3_s1, s1, track = false)
insert_edge!(graph, s0, d_s0_sum, track = false)
insert_edge!(graph, s1, d_s1_sum, track = false)
insert_edge!(graph, d_s0_sum, sum_node, track = false)
insert_edge!(graph, d_s1_sum, sum_node, track = false)
insert_edge!(graph, sum_node, d_exit, track = false)
compton_function = get_compute_function(graph, process, machine)
input = gen_process_input(process)
println("Function result: $(compton_function(input))")
println("Groundtruth: $(compton_groundtruth(input))")
end