Add momentum conservation tests; Debug result against groundtruth
This commit is contained in:
parent
aa18430d29
commit
ba0c75c8dc
@ -74,7 +74,6 @@ function gen_input_assignment_code(
|
||||
end
|
||||
|
||||
for symbol in symbols
|
||||
# TODO: how to get the "default" cpu device?
|
||||
device = entry_device(machine)
|
||||
evalExpr = eval(gen_access_expr(device, symbol))
|
||||
push!(assignInputs, Meta.parse("$(evalExpr)::ParticleValue{$type} = ParticleValue($p, one(ComplexF64))"))
|
||||
|
@ -123,7 +123,6 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
|
||||
pre_length1 = length(node1.parents)
|
||||
pre_length2 = length(node2.children)
|
||||
|
||||
#TODO: filter is very slow
|
||||
for i in eachindex(node1.parents)
|
||||
if (node1.parents[i] == node2)
|
||||
splice!(node1.parents, i)
|
||||
@ -252,7 +251,6 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# TODO: filter is very slow
|
||||
for n in [1, 3]
|
||||
for i in eachindex(operation.input[n].nodeFusions)
|
||||
if operation == operation.input[n].nodeFusions[i]
|
||||
|
@ -35,6 +35,8 @@ function compute(
|
||||
p3 = QED_conserve_momentum(data1.p, data2.p)
|
||||
P3 = interaction_result(P1, P2)
|
||||
|
||||
println("Virtual particle: $(P3(p3))")
|
||||
|
||||
state = QED_vertex()
|
||||
if (typeof(data1.v) <: AdjointBiSpinor)
|
||||
state = data1.v * state
|
||||
@ -68,8 +70,15 @@ function compute(
|
||||
P1 <: Union{AntiFermionStateful, FermionStateful},
|
||||
P2 <: Union{AntiFermionStateful, FermionStateful},
|
||||
}
|
||||
println("S2: $(direction(data1.p)) $(data1.p) - $(direction(data2.p)) $(data2.p)")
|
||||
# TODO: assert that data1 and data2 are opposites
|
||||
inner = QED_inner_edge(data1.p)
|
||||
#=@assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
|
||||
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
|
||||
@assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001, atol = sqrt(eps())) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)"
|
||||
@assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001, atol = sqrt(eps())) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)"
|
||||
=#
|
||||
|
||||
inner = QED_inner_edge(data2.p)
|
||||
# inner edge is just a scalar, data1 and data2 are bispinor/adjointbispinnor, need to keep correct order
|
||||
if typeof(data1.v) <: BiSpinor
|
||||
return data2.v * inner * data1.v
|
||||
@ -111,7 +120,8 @@ Compute a sum over the vector. Use an algorithm that accounts for accumulated er
|
||||
Linearly many FLOP with growing data.
|
||||
"""
|
||||
function compute(::ComputeTaskQED_Sum, data::Vector{ComplexF64})::ComplexF64
|
||||
return sum_kbn(data)
|
||||
# TODO: want to use sum_kbn here but it doesn't seem to support ComplexF64, do it element-wise?
|
||||
return sum(data)
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -1,4 +1,6 @@
|
||||
|
||||
ComputeTaskQED_Sum() = ComputeTaskQED_Sum(0)
|
||||
|
||||
"""
|
||||
gen_process_input(processDescription::QEDProcessDescription)
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
using QEDprocesses
|
||||
import QEDbase.mass
|
||||
|
||||
"""
|
||||
@ -191,6 +192,10 @@ end
|
||||
@inline direction(::FermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
|
||||
@inline direction(::AntiFermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
|
||||
|
||||
@inline direction(::Type{PhotonStateful{Dir}}) where {Dir <: ParticleDirection} = Dir()
|
||||
@inline direction(::Type{FermionStateful{Dir}}) where {Dir <: ParticleDirection} = Dir()
|
||||
@inline direction(::Type{AntiFermionStateful{Dir}}) where {Dir <: ParticleDirection} = Dir()
|
||||
|
||||
@inline isincoming(::QEDParticle{Incoming}) = true
|
||||
@inline isincoming(::QEDParticle{Outgoing}) = false
|
||||
@inline isoutgoing(::QEDParticle{Incoming}) = false
|
||||
@ -271,8 +276,7 @@ Return the factor of a vertex in a QED feynman diagram.
|
||||
end
|
||||
|
||||
@inline function QED_inner_edge(p::QEDParticle)
|
||||
# TODO: doesn't exist yet in QEDprocesses
|
||||
return one(ComplexF64)
|
||||
return propagator(particle(p), p.momentum)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -281,10 +285,23 @@ end
|
||||
Calculate and return a new particle from two given interacting ones at a vertex.
|
||||
"""
|
||||
function QED_conserve_momentum(p1::QEDParticle, p2::QEDParticle)
|
||||
#println("Conserving momentum of \n$(direction(p1)) $(p1)\n and \n$(direction(p2)) $(p2)")
|
||||
T3 = interaction_result(typeof(p1), typeof(p2))
|
||||
# TODO: probably also need to do something about the spin/pol
|
||||
p3 = T3(p1.momentum + p2.momentum)
|
||||
return p3
|
||||
p1_mom = p1.momentum
|
||||
if (typeof(direction(p1)) <: Outgoing)
|
||||
p1_mom *= -1
|
||||
end
|
||||
p2_mom = p2.momentum
|
||||
if (typeof(direction(p2)) <: Outgoing)
|
||||
p2_mom *= -1
|
||||
end
|
||||
|
||||
p3_mom = p1_mom + p2_mom
|
||||
if (typeof(direction(T3)) <: Incoming)
|
||||
return T3(-p3_mom)
|
||||
end
|
||||
return T3(p3_mom)
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -1,6 +1,7 @@
|
||||
[deps]
|
||||
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
|
||||
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
|
||||
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
|
@ -1,5 +1,5 @@
|
||||
using SafeTestsets
|
||||
|
||||
#=
|
||||
@safetestset "Utility Unit Tests" begin
|
||||
include("unit_tests_utility.jl")
|
||||
end
|
||||
@ -17,10 +17,10 @@ end
|
||||
end
|
||||
@safetestset "ABC-Model Unit Tests" begin
|
||||
include("unit_tests_abcmodel.jl")
|
||||
end
|
||||
end=#
|
||||
@safetestset "QED-Model Unit Tests" begin
|
||||
include("unit_tests_qedmodel.jl")
|
||||
end
|
||||
end#=
|
||||
@safetestset "Node Reduction Unit Tests" begin
|
||||
include("node_reduction.jl")
|
||||
end
|
||||
@ -36,3 +36,4 @@ end
|
||||
@safetestset "Known Graph Tests" begin
|
||||
include("known_graphs.jl")
|
||||
end
|
||||
=#
|
@ -1,15 +1,21 @@
|
||||
using MetagraphOptimization
|
||||
using QEDbase
|
||||
using QEDprocesses
|
||||
using StatsBase # for countmap
|
||||
using Random
|
||||
|
||||
import MetagraphOptimization.caninteract
|
||||
import MetagraphOptimization.issame
|
||||
import MetagraphOptimization.interaction_result
|
||||
import MetagraphOptimization.propagation_result
|
||||
import MetagraphOptimization.direction
|
||||
import MetagraphOptimization.spin_or_pol
|
||||
import MetagraphOptimization.QED_vertex
|
||||
|
||||
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
|
||||
|
||||
RNG = Random.default_rng()
|
||||
|
||||
testparticleTypes = [
|
||||
PhotonStateful{Incoming},
|
||||
PhotonStateful{Outgoing},
|
||||
@ -28,14 +34,69 @@ testparticleTypesPropagated = [
|
||||
AntiFermionStateful{Incoming},
|
||||
]
|
||||
|
||||
function compton_groundtruth(input::QEDProcessInput)
|
||||
# p1k1 -> p2k2
|
||||
# formula: −(ie)^2 (u(p2) slashed(ε1) S(p2 − k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
|
||||
|
||||
p1 = input.inParticles[findfirst(x -> typeof(x) <: FermionStateful, input.inParticles)]
|
||||
p2 = input.outParticles[findfirst(x -> typeof(x) <: FermionStateful, input.outParticles)]
|
||||
|
||||
k1 = input.inParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.inParticles)]
|
||||
k2 = input.outParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.outParticles)]
|
||||
|
||||
u_p1 = base_state(Electron(), Incoming(), p1.momentum, spin_or_pol(p1))
|
||||
u_p2 = base_state(Electron(), Outgoing(), p2.momentum, spin_or_pol(p2))
|
||||
|
||||
eps_slashed_1 = base_state(Photon(), Incoming(), k1.momentum, spin_or_pol(k1))
|
||||
eps_slashed_2 = base_state(Photon(), Outgoing(), k2.momentum, spin_or_pol(k2))
|
||||
|
||||
virt1_mom = p2.momentum - k1.momentum
|
||||
virt2_mom = p1.momentum + k1.momentum
|
||||
|
||||
println("Groundtruth virtual particle (p2 - k1): $(virt1_mom)")
|
||||
println("Groundtruth virtual particle (p1 + k1): $(virt2_mom)")
|
||||
|
||||
|
||||
s_p2_k1 = propagator(Electron(), virt1_mom)
|
||||
s_p1_k1 = propagator(Electron(), virt2_mom)
|
||||
|
||||
diagram1 = u_p2 * eps_slashed_1 * QED_vertex() * s_p2_k1 * QED_vertex() * eps_slashed_2 * u_p1
|
||||
diagram2 = u_p2 * eps_slashed_2 * QED_vertex() * s_p1_k1 * QED_vertex() * eps_slashed_1 * u_p1
|
||||
|
||||
return (diagram1 + diagram2)
|
||||
end
|
||||
|
||||
|
||||
@testset "Interaction Result" begin
|
||||
import MetagraphOptimization.QED_conserve_momentum
|
||||
|
||||
for p1 in testparticleTypes, p2 in testparticleTypes
|
||||
if !caninteract(p1, p2)
|
||||
@test_throws AssertionError interaction_result(p1, p2)
|
||||
else
|
||||
@test interaction_result(p1, p2) in setdiff(testparticleTypes, [p1, p2])
|
||||
@test issame(interaction_result(p1, p2), interaction_result(p2, p1))
|
||||
continue
|
||||
end
|
||||
|
||||
@test interaction_result(p1, p2) in setdiff(testparticleTypes, [p1, p2])
|
||||
@test issame(interaction_result(p1, p2), interaction_result(p2, p1))
|
||||
|
||||
testParticle1 = p1(rand(RNG, SFourMomentum))
|
||||
testParticle2 = p2(rand(RNG, SFourMomentum))
|
||||
p3 = interaction_result(p1, p2)
|
||||
|
||||
resultParticle = QED_conserve_momentum(testParticle1, testParticle2)
|
||||
|
||||
@test issame(typeof(resultParticle), interaction_result(p1, p2))
|
||||
|
||||
totalMom = zero(SFourMomentum)
|
||||
for (p, mom) in [(p1, testParticle1.momentum), (p2, testParticle2.momentum), (p3, resultParticle.momentum)]
|
||||
if (typeof(direction(p)) <: Incoming)
|
||||
totalMom += mom
|
||||
else
|
||||
totalMom -= mom
|
||||
end
|
||||
end
|
||||
|
||||
@test isapprox(totalMom, zero(SFourMomentum); atol = sqrt(eps()))
|
||||
end
|
||||
end
|
||||
|
||||
@ -129,23 +190,39 @@ end
|
||||
# s to output (exit node)
|
||||
d_exit = insert_node!(graph, make_node(DataTask(16)), track = false)
|
||||
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(2)), track = false)
|
||||
|
||||
d_s0_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
|
||||
d_s1_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
|
||||
|
||||
# final s compute
|
||||
s0 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
|
||||
s1 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
|
||||
|
||||
# data from v0 and v1 to s0
|
||||
d_v0_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_v1_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_v2_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_v3_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
|
||||
# v0 and v1 compute
|
||||
v0 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
|
||||
v1 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
|
||||
v2 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
|
||||
v3 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
|
||||
|
||||
# data from uPhIn, uPhOut, uElIn, uElOut
|
||||
# data from uPhIn, uPhOut, uElIn, uElOut to v0 and v1
|
||||
d_uPhIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uElIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uPhOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uElOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
|
||||
# data from uPhIn, uPhOut, uElIn, uElOut to v2 and v3
|
||||
d_uPhOut_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uElIn_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uPhIn_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uElOut_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
|
||||
# uPhIn, uPhOut, uElIn and uElOut computes
|
||||
uPhIn = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
|
||||
uPhOut = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
|
||||
@ -169,18 +246,45 @@ end
|
||||
insert_edge!(graph, uElIn, d_uElIn_v0, track = false)
|
||||
insert_edge!(graph, uElOut, d_uElOut_v1, track = false)
|
||||
|
||||
insert_edge!(graph, uPhIn, d_uPhIn_v3, track = false)
|
||||
insert_edge!(graph, uPhOut, d_uPhOut_v2, track = false)
|
||||
insert_edge!(graph, uElIn, d_uElIn_v2, track = false)
|
||||
insert_edge!(graph, uElOut, d_uElOut_v3, track = false)
|
||||
|
||||
insert_edge!(graph, d_uPhIn_v0, v0, track = false)
|
||||
insert_edge!(graph, d_uPhOut_v1, v1, track = false)
|
||||
insert_edge!(graph, d_uElIn_v0, v0, track = false)
|
||||
insert_edge!(graph, d_uElOut_v1, v1, track = false)
|
||||
|
||||
insert_edge!(graph, d_uPhIn_v3, v3, track = false)
|
||||
insert_edge!(graph, d_uPhOut_v2, v2, track = false)
|
||||
insert_edge!(graph, d_uElIn_v2, v2, track = false)
|
||||
insert_edge!(graph, d_uElOut_v3, v3, track = false)
|
||||
|
||||
insert_edge!(graph, v0, d_v0_s0, track = false)
|
||||
insert_edge!(graph, v1, d_v1_s0, track = false)
|
||||
insert_edge!(graph, v2, d_v2_s1, track = false)
|
||||
insert_edge!(graph, v3, d_v3_s1, track = false)
|
||||
|
||||
insert_edge!(graph, d_v0_s0, s0, track = false)
|
||||
insert_edge!(graph, d_v1_s0, s0, track = false)
|
||||
|
||||
insert_edge!(graph, s0, d_exit, track = false)
|
||||
insert_edge!(graph, d_v2_s1, s1, track = false)
|
||||
insert_edge!(graph, d_v3_s1, s1, track = false)
|
||||
|
||||
insert_edge!(graph, s0, d_s0_sum, track = false)
|
||||
insert_edge!(graph, s1, d_s1_sum, track = false)
|
||||
|
||||
insert_edge!(graph, d_s0_sum, sum_node, track = false)
|
||||
insert_edge!(graph, d_s1_sum, sum_node, track = false)
|
||||
|
||||
insert_edge!(graph, sum_node, d_exit, track = false)
|
||||
|
||||
compton_function = get_compute_function(graph, process, machine)
|
||||
|
||||
input = gen_process_input(process)
|
||||
|
||||
println("Function result: $(compton_function(input))")
|
||||
|
||||
println("Groundtruth: $(compton_groundtruth(input))")
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user