Add momentum conservation tests; Debug result against groundtruth

This commit is contained in:
Anton Reinhard 2023-11-28 15:50:18 +01:00
parent aa18430d29
commit ba0c75c8dc
8 changed files with 149 additions and 17 deletions

View File

@ -74,7 +74,6 @@ function gen_input_assignment_code(
end
for symbol in symbols
# TODO: how to get the "default" cpu device?
device = entry_device(machine)
evalExpr = eval(gen_access_expr(device, symbol))
push!(assignInputs, Meta.parse("$(evalExpr)::ParticleValue{$type} = ParticleValue($p, one(ComplexF64))"))

View File

@ -123,7 +123,6 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
pre_length1 = length(node1.parents)
pre_length2 = length(node2.children)
#TODO: filter is very slow
for i in eachindex(node1.parents)
if (node1.parents[i] == node2)
splice!(node1.parents, i)
@ -252,7 +251,6 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
# TODO: filter is very slow
for n in [1, 3]
for i in eachindex(operation.input[n].nodeFusions)
if operation == operation.input[n].nodeFusions[i]

View File

@ -35,6 +35,8 @@ function compute(
p3 = QED_conserve_momentum(data1.p, data2.p)
P3 = interaction_result(P1, P2)
println("Virtual particle: $(P3(p3))")
state = QED_vertex()
if (typeof(data1.v) <: AdjointBiSpinor)
state = data1.v * state
@ -68,8 +70,15 @@ function compute(
P1 <: Union{AntiFermionStateful, FermionStateful},
P2 <: Union{AntiFermionStateful, FermionStateful},
}
println("S2: $(direction(data1.p)) $(data1.p) - $(direction(data2.p)) $(data2.p)")
# TODO: assert that data1 and data2 are opposites
inner = QED_inner_edge(data1.p)
#=@assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
@assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001, atol = sqrt(eps())) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)"
@assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001, atol = sqrt(eps())) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)"
=#
inner = QED_inner_edge(data2.p)
# inner edge is just a scalar, data1 and data2 are bispinor/adjointbispinnor, need to keep correct order
if typeof(data1.v) <: BiSpinor
return data2.v * inner * data1.v
@ -111,7 +120,8 @@ Compute a sum over the vector. Use an algorithm that accounts for accumulated er
Linearly many FLOP with growing data.
"""
function compute(::ComputeTaskQED_Sum, data::Vector{ComplexF64})::ComplexF64
return sum_kbn(data)
# TODO: want to use sum_kbn here but it doesn't seem to support ComplexF64, do it element-wise?
return sum(data)
end
"""

View File

@ -1,4 +1,6 @@
ComputeTaskQED_Sum() = ComputeTaskQED_Sum(0)
"""
gen_process_input(processDescription::QEDProcessDescription)

View File

@ -1,3 +1,4 @@
using QEDprocesses
import QEDbase.mass
"""
@ -191,6 +192,10 @@ end
@inline direction(::FermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::AntiFermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::Type{PhotonStateful{Dir}}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::Type{FermionStateful{Dir}}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::Type{AntiFermionStateful{Dir}}) where {Dir <: ParticleDirection} = Dir()
@inline isincoming(::QEDParticle{Incoming}) = true
@inline isincoming(::QEDParticle{Outgoing}) = false
@inline isoutgoing(::QEDParticle{Incoming}) = false
@ -271,8 +276,7 @@ Return the factor of a vertex in a QED feynman diagram.
end
@inline function QED_inner_edge(p::QEDParticle)
# TODO: doesn't exist yet in QEDprocesses
return one(ComplexF64)
return propagator(particle(p), p.momentum)
end
"""
@ -281,10 +285,23 @@ end
Calculate and return a new particle from two given interacting ones at a vertex.
"""
function QED_conserve_momentum(p1::QEDParticle, p2::QEDParticle)
#println("Conserving momentum of \n$(direction(p1)) $(p1)\n and \n$(direction(p2)) $(p2)")
T3 = interaction_result(typeof(p1), typeof(p2))
# TODO: probably also need to do something about the spin/pol
p3 = T3(p1.momentum + p2.momentum)
return p3
p1_mom = p1.momentum
if (typeof(direction(p1)) <: Outgoing)
p1_mom *= -1
end
p2_mom = p2.momentum
if (typeof(direction(p2)) <: Outgoing)
p2_mom *= -1
end
p3_mom = p1_mom + p2_mom
if (typeof(direction(T3)) <: Incoming)
return T3(-p3_mom)
end
return T3(p3_mom)
end
"""

View File

@ -1,6 +1,7 @@
[deps]
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

View File

@ -1,5 +1,5 @@
using SafeTestsets
#=
@safetestset "Utility Unit Tests" begin
include("unit_tests_utility.jl")
end
@ -17,10 +17,10 @@ end
end
@safetestset "ABC-Model Unit Tests" begin
include("unit_tests_abcmodel.jl")
end
end=#
@safetestset "QED-Model Unit Tests" begin
include("unit_tests_qedmodel.jl")
end
end#=
@safetestset "Node Reduction Unit Tests" begin
include("node_reduction.jl")
end
@ -36,3 +36,4 @@ end
@safetestset "Known Graph Tests" begin
include("known_graphs.jl")
end
=#

View File

@ -1,15 +1,21 @@
using MetagraphOptimization
using QEDbase
using QEDprocesses
using StatsBase # for countmap
using Random
import MetagraphOptimization.caninteract
import MetagraphOptimization.issame
import MetagraphOptimization.interaction_result
import MetagraphOptimization.propagation_result
import MetagraphOptimization.direction
import MetagraphOptimization.spin_or_pol
import MetagraphOptimization.QED_vertex
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
RNG = Random.default_rng()
testparticleTypes = [
PhotonStateful{Incoming},
PhotonStateful{Outgoing},
@ -28,14 +34,69 @@ testparticleTypesPropagated = [
AntiFermionStateful{Incoming},
]
function compton_groundtruth(input::QEDProcessInput)
# p1k1 -> p2k2
# formula: (ie)^2 (u(p2) slashed(ε1) S(p2 k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
p1 = input.inParticles[findfirst(x -> typeof(x) <: FermionStateful, input.inParticles)]
p2 = input.outParticles[findfirst(x -> typeof(x) <: FermionStateful, input.outParticles)]
k1 = input.inParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.inParticles)]
k2 = input.outParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.outParticles)]
u_p1 = base_state(Electron(), Incoming(), p1.momentum, spin_or_pol(p1))
u_p2 = base_state(Electron(), Outgoing(), p2.momentum, spin_or_pol(p2))
eps_slashed_1 = base_state(Photon(), Incoming(), k1.momentum, spin_or_pol(k1))
eps_slashed_2 = base_state(Photon(), Outgoing(), k2.momentum, spin_or_pol(k2))
virt1_mom = p2.momentum - k1.momentum
virt2_mom = p1.momentum + k1.momentum
println("Groundtruth virtual particle (p2 - k1): $(virt1_mom)")
println("Groundtruth virtual particle (p1 + k1): $(virt2_mom)")
s_p2_k1 = propagator(Electron(), virt1_mom)
s_p1_k1 = propagator(Electron(), virt2_mom)
diagram1 = u_p2 * eps_slashed_1 * QED_vertex() * s_p2_k1 * QED_vertex() * eps_slashed_2 * u_p1
diagram2 = u_p2 * eps_slashed_2 * QED_vertex() * s_p1_k1 * QED_vertex() * eps_slashed_1 * u_p1
return (diagram1 + diagram2)
end
@testset "Interaction Result" begin
import MetagraphOptimization.QED_conserve_momentum
for p1 in testparticleTypes, p2 in testparticleTypes
if !caninteract(p1, p2)
@test_throws AssertionError interaction_result(p1, p2)
else
@test interaction_result(p1, p2) in setdiff(testparticleTypes, [p1, p2])
@test issame(interaction_result(p1, p2), interaction_result(p2, p1))
continue
end
@test interaction_result(p1, p2) in setdiff(testparticleTypes, [p1, p2])
@test issame(interaction_result(p1, p2), interaction_result(p2, p1))
testParticle1 = p1(rand(RNG, SFourMomentum))
testParticle2 = p2(rand(RNG, SFourMomentum))
p3 = interaction_result(p1, p2)
resultParticle = QED_conserve_momentum(testParticle1, testParticle2)
@test issame(typeof(resultParticle), interaction_result(p1, p2))
totalMom = zero(SFourMomentum)
for (p, mom) in [(p1, testParticle1.momentum), (p2, testParticle2.momentum), (p3, resultParticle.momentum)]
if (typeof(direction(p)) <: Incoming)
totalMom += mom
else
totalMom -= mom
end
end
@test isapprox(totalMom, zero(SFourMomentum); atol = sqrt(eps()))
end
end
@ -129,23 +190,39 @@ end
# s to output (exit node)
d_exit = insert_node!(graph, make_node(DataTask(16)), track = false)
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(2)), track = false)
d_s0_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
d_s1_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
# final s compute
s0 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
s1 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
# data from v0 and v1 to s0
d_v0_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v1_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v2_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v3_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
# v0 and v1 compute
v0 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v1 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v2 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v3 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
# data from uPhIn, uPhOut, uElIn, uElOut
# data from uPhIn, uPhOut, uElIn, uElOut to v0 and v1
d_uPhIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
# data from uPhIn, uPhOut, uElIn, uElOut to v2 and v3
d_uPhOut_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElIn_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhIn_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElOut_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
# uPhIn, uPhOut, uElIn and uElOut computes
uPhIn = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uPhOut = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
@ -169,18 +246,45 @@ end
insert_edge!(graph, uElIn, d_uElIn_v0, track = false)
insert_edge!(graph, uElOut, d_uElOut_v1, track = false)
insert_edge!(graph, uPhIn, d_uPhIn_v3, track = false)
insert_edge!(graph, uPhOut, d_uPhOut_v2, track = false)
insert_edge!(graph, uElIn, d_uElIn_v2, track = false)
insert_edge!(graph, uElOut, d_uElOut_v3, track = false)
insert_edge!(graph, d_uPhIn_v0, v0, track = false)
insert_edge!(graph, d_uPhOut_v1, v1, track = false)
insert_edge!(graph, d_uElIn_v0, v0, track = false)
insert_edge!(graph, d_uElOut_v1, v1, track = false)
insert_edge!(graph, d_uPhIn_v3, v3, track = false)
insert_edge!(graph, d_uPhOut_v2, v2, track = false)
insert_edge!(graph, d_uElIn_v2, v2, track = false)
insert_edge!(graph, d_uElOut_v3, v3, track = false)
insert_edge!(graph, v0, d_v0_s0, track = false)
insert_edge!(graph, v1, d_v1_s0, track = false)
insert_edge!(graph, v2, d_v2_s1, track = false)
insert_edge!(graph, v3, d_v3_s1, track = false)
insert_edge!(graph, d_v0_s0, s0, track = false)
insert_edge!(graph, d_v1_s0, s0, track = false)
insert_edge!(graph, s0, d_exit, track = false)
insert_edge!(graph, d_v2_s1, s1, track = false)
insert_edge!(graph, d_v3_s1, s1, track = false)
insert_edge!(graph, s0, d_s0_sum, track = false)
insert_edge!(graph, s1, d_s1_sum, track = false)
insert_edge!(graph, d_s0_sum, sum_node, track = false)
insert_edge!(graph, d_s1_sum, sum_node, track = false)
insert_edge!(graph, sum_node, d_exit, track = false)
compton_function = get_compute_function(graph, process, machine)
input = gen_process_input(process)
println("Function result: $(compton_function(input))")
println("Groundtruth: $(compton_groundtruth(input))")
end