heterogeneity (#27)

Prepare things to work with heterogeneity, make things work on GPU

Reviewed-on: Rubydragon/MetagraphOptimization.jl#27
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
2023-12-18 14:31:52 +01:00
committed by Anton Reinhard
parent c90346e948
commit 92e0eeaaef
42 changed files with 1631 additions and 238 deletions

View File

@ -4,5 +4,7 @@ QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

View File

@ -18,12 +18,12 @@ end
@safetestset "ABC-Model Unit Tests " begin
include("unit_tests_abcmodel.jl")
end
@safetestset "QED Feynman Diagram Generation Tests" begin
include("unit_tests_qed_diagrams.jl")
end
@safetestset "QED-Model Unit Tests " begin
include("unit_tests_qedmodel.jl")
end
@safetestset "QED Feynman Diagram Generation Tests" begin
include("unit_tests_qed_diagrams.jl")
end
@safetestset "Node Reduction Unit Tests " begin
include("node_reduction.jl")
end

View File

@ -2,6 +2,8 @@ using MetagraphOptimization
using QEDbase
using AccurateArithmetic
using Random
using UUIDs
using StaticArrays
import MetagraphOptimization.ABCParticle
import MetagraphOptimization.interaction_result
@ -27,11 +29,11 @@ function ground_truth_graph_result(input::ABCProcessInput)
constant = (1 / 137.0)^2
# calculate particle C in diagram 1
diagram1_C = ParticleC(input.inParticles[1].momentum + input.inParticles[2].momentum)
diagram2_C = ParticleC(input.inParticles[1].momentum + input.outParticles[2].momentum)
diagram1_C = ParticleC(input.inA[1].momentum + input.inB[1].momentum)
diagram2_C = ParticleC(input.inA[1].momentum + input.outB[1].momentum)
diagram1_Cp = ParticleC(input.outParticles[1].momentum + input.outParticles[2].momentum)
diagram2_Cp = ParticleC(input.outParticles[1].momentum + input.inParticles[2].momentum)
diagram1_Cp = ParticleC(input.outA[1].momentum + input.outB[1].momentum)
diagram2_Cp = ParticleC(input.outA[1].momentum + input.inB[1].momentum)
check_particle_reverse_moment(diagram1_Cp.momentum, diagram1_C.momentum)
check_particle_reverse_moment(diagram2_Cp.momentum, diagram2_C.momentum)
@ -47,7 +49,18 @@ function ground_truth_graph_result(input::ABCProcessInput)
return sum_kbn([diagram1_result, diagram2_result])
end
machine = get_machine_info()
machine = Machine(
[
MetagraphOptimization.NumaNode(
0,
1,
MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),
-1.0,
UUIDs.uuid1(),
),
],
[-1.0;;],
)
process_2_2 = ABCProcessDescription(
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
@ -56,14 +69,12 @@ process_2_2 = ABCProcessDescription(
particles_2_2 = ABCProcessInput(
process_2_2,
ABCParticle[
ParticleA(SFourMomentum(0.823648, 0.0, 0.0, 0.823648)),
ParticleB(SFourMomentum(0.823648, 0.0, 0.0, -0.823648)),
],
ABCParticle[
ParticleA(SFourMomentum(0.823648, -0.835061, -0.474802, 0.277915)),
ParticleB(SFourMomentum(0.823648, 0.835061, 0.474802, -0.277915)),
],
SVector{1}(ParticleA(SFourMomentum(0.823648, 0.0, 0.0, 0.823648))),
SVector{1}(ParticleB(SFourMomentum(0.823648, 0.0, 0.0, -0.823648))),
SVector{0, ParticleC}(),
SVector{1}(ParticleA(SFourMomentum(0.823648, -0.835061, -0.474802, 0.277915))),
SVector{1}(ParticleB(SFourMomentum(0.823648, 0.835061, 0.474802, -0.277915))),
SVector{0, ParticleC}(),
)
expected_result = ground_truth_graph_result(particles_2_2)

View File

@ -3,6 +3,7 @@ using QEDbase
using QEDprocesses
using StatsBase # for countmap
using Random
using UUIDs
import MetagraphOptimization.caninteract
import MetagraphOptimization.issame
@ -17,32 +18,32 @@ def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
RNG = Random.default_rng()
testparticleTypes = [
PhotonStateful{Incoming},
PhotonStateful{Outgoing},
FermionStateful{Incoming},
FermionStateful{Outgoing},
AntiFermionStateful{Incoming},
AntiFermionStateful{Outgoing},
PhotonStateful{Incoming, PolX},
PhotonStateful{Outgoing, PolX},
FermionStateful{Incoming, SpinUp},
FermionStateful{Outgoing, SpinUp},
AntiFermionStateful{Incoming, SpinUp},
AntiFermionStateful{Outgoing, SpinUp},
]
testparticleTypesPropagated = [
PhotonStateful{Outgoing},
PhotonStateful{Incoming},
FermionStateful{Outgoing},
FermionStateful{Incoming},
AntiFermionStateful{Outgoing},
AntiFermionStateful{Incoming},
PhotonStateful{Outgoing, PolX},
PhotonStateful{Incoming, PolX},
FermionStateful{Outgoing, SpinUp},
FermionStateful{Incoming, SpinUp},
AntiFermionStateful{Outgoing, SpinUp},
AntiFermionStateful{Incoming, SpinUp},
]
function compton_groundtruth(input::QEDProcessInput)
# p1k1 -> p2k2
# formula: (ie)^2 (u(p2) slashed(ε1) S(p2 k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
p1 = input.inParticles[findfirst(x -> typeof(x) <: FermionStateful, input.inParticles)]
p2 = input.outParticles[findfirst(x -> typeof(x) <: FermionStateful, input.outParticles)]
p1 = input.inFerms[1]
p2 = input.outFerms[1]
k1 = input.inParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.inParticles)]
k2 = input.outParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.outParticles)]
k1 = input.inPhotons[1]
k2 = input.outPhotons[1]
u_p1 = base_state(Electron(), Incoming(), p1.momentum, spin_or_pol(p1))
u_p2 = base_state(Electron(), Outgoing(), p2.momentum, spin_or_pol(p2))
@ -117,36 +118,36 @@ end
@testset "Known processes" begin
compton_process = QEDProcessDescription(
Dict{Type, Int}(PhotonStateful{Incoming} => 1, FermionStateful{Incoming} => 1),
Dict{Type, Int}(PhotonStateful{Outgoing} => 1, FermionStateful{Outgoing} => 1),
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, FermionStateful{Incoming, SpinUp} => 1),
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 1, FermionStateful{Outgoing, SpinUp} => 1),
)
@test parse_process("ke->ke", QEDModel()) == compton_process
positron_compton_process = QEDProcessDescription(
Dict{Type, Int}(PhotonStateful{Incoming} => 1, AntiFermionStateful{Incoming} => 1),
Dict{Type, Int}(PhotonStateful{Outgoing} => 1, AntiFermionStateful{Outgoing} => 1),
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, AntiFermionStateful{Incoming, SpinUp} => 1),
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 1, AntiFermionStateful{Outgoing, SpinUp} => 1),
)
@test parse_process("kp->kp", QEDModel()) == positron_compton_process
trident_process = QEDProcessDescription(
Dict{Type, Int}(PhotonStateful{Incoming} => 1, FermionStateful{Incoming} => 1),
Dict{Type, Int}(FermionStateful{Outgoing} => 2, AntiFermionStateful{Outgoing} => 1),
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, FermionStateful{Incoming, SpinUp} => 1),
Dict{Type, Int}(FermionStateful{Outgoing, SpinUp} => 2, AntiFermionStateful{Outgoing, SpinUp} => 1),
)
@test parse_process("ke->eep", QEDModel()) == trident_process
pair_production_process = QEDProcessDescription(
Dict{Type, Int}(PhotonStateful{Incoming} => 2),
Dict{Type, Int}(FermionStateful{Outgoing} => 1, AntiFermionStateful{Outgoing} => 1),
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 2),
Dict{Type, Int}(FermionStateful{Outgoing, SpinUp} => 1, AntiFermionStateful{Outgoing, SpinUp} => 1),
)
@test parse_process("kk->pe", QEDModel()) == pair_production_process
pair_annihilation_process = QEDProcessDescription(
Dict{Type, Int}(FermionStateful{Incoming} => 1, AntiFermionStateful{Incoming} => 1),
Dict{Type, Int}(PhotonStateful{Outgoing} => 2),
Dict{Type, Int}(FermionStateful{Incoming, SpinUp} => 1, AntiFermionStateful{Incoming, SpinUp} => 1),
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 2),
)
@test parse_process("pe->kk", QEDModel()) == pair_annihilation_process
@ -160,12 +161,24 @@ end
for i in 1:100
input = gen_process_input(process)
@test countmap(typeof.(input.inParticles)) == process.inParticles
@test countmap(typeof.(input.outParticles)) == process.outParticles
@test length(input.inFerms) == get(process.inParticles, FermionStateful{Incoming, SpinUp}, 0)
@test length(input.inAntiferms) == get(process.inParticles, AntiFermionStateful{Incoming, SpinUp}, 0)
@test length(input.inPhotons) == get(process.inParticles, PhotonStateful{Incoming, PolX}, 0)
@test length(input.outFerms) == get(process.outParticles, FermionStateful{Outgoing, SpinUp}, 0)
@test length(input.outAntiferms) == get(process.outParticles, AntiFermionStateful{Outgoing, SpinUp}, 0)
@test length(input.outPhotons) == get(process.outParticles, PhotonStateful{Outgoing, PolX}, 0)
@test isapprox(
sum(getfield.(input.inParticles, :momentum)),
sum(getfield.(input.outParticles, :momentum));
sum([
getfield.(input.inFerms, :momentum)...,
getfield.(input.inAntiferms, :momentum)...,
getfield.(input.inPhotons, :momentum)...,
]),
sum([
getfield.(input.outFerms, :momentum)...,
getfield.(input.outAntiferms, :momentum)...,
getfield.(input.outPhotons, :momentum)...,
]);
atol = sqrt(eps()),
)
end
@ -179,7 +192,18 @@ end
model = QEDModel()
process = parse_process("ke->ke", model)
machine = get_machine_info()
machine = Machine(
[
MetagraphOptimization.NumaNode(
0,
1,
MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),
-1.0,
UUIDs.uuid1(),
),
],
[-1.0;;],
)
graph = MetagraphOptimization.DAG()
@ -289,3 +313,37 @@ end
compton_function = get_compute_function(graph_generated, process, machine)
@test isapprox(compton_function.(input), compton_groundtruth.(input))
end
@testset "Equal results after optimization" for optimizer in
[ReductionOptimizer(), RandomWalkOptimizer(MersenneTwister(0))]
@testset "Process $proc_str" for proc_str in ["ke->ke", "kp->kp", "kk->ep", "ep->kk", "ke->kke", "ke->kkke"]
model = QEDModel()
process = parse_process(proc_str, model)
machine = Machine(
[
MetagraphOptimization.NumaNode(
0,
1,
MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),
-1.0,
UUIDs.uuid1(),
),
],
[-1.0;;],
)
graph = gen_graph(process)
compute_function = get_compute_function(graph, process, machine)
if (typeof(optimizer) <: RandomWalkOptimizer)
optimize!(optimizer, graph, 100)
elseif (typeof(optimizer) <: ReductionOptimizer)
optimize_to_fixpoint!(optimizer, graph)
end
reduced_compute_function = get_compute_function(graph, process, machine)
input = [gen_process_input(process) for _ in 1:100]
@test isapprox(compute_function.(input), reduced_compute_function.(input))
end
end