heterogeneity (#27)
Prepare things to work with heterogeneity, make things work on GPU Reviewed-on: Rubydragon/MetagraphOptimization.jl#27 Co-authored-by: Anton Reinhard <anton.reinhard@proton.me> Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
@ -4,5 +4,7 @@ QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
|
||||
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
|
||||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
@ -18,12 +18,12 @@ end
|
||||
@safetestset "ABC-Model Unit Tests " begin
|
||||
include("unit_tests_abcmodel.jl")
|
||||
end
|
||||
@safetestset "QED Feynman Diagram Generation Tests" begin
|
||||
include("unit_tests_qed_diagrams.jl")
|
||||
end
|
||||
@safetestset "QED-Model Unit Tests " begin
|
||||
include("unit_tests_qedmodel.jl")
|
||||
end
|
||||
@safetestset "QED Feynman Diagram Generation Tests" begin
|
||||
include("unit_tests_qed_diagrams.jl")
|
||||
end
|
||||
@safetestset "Node Reduction Unit Tests " begin
|
||||
include("node_reduction.jl")
|
||||
end
|
||||
|
@ -2,6 +2,8 @@ using MetagraphOptimization
|
||||
using QEDbase
|
||||
using AccurateArithmetic
|
||||
using Random
|
||||
using UUIDs
|
||||
using StaticArrays
|
||||
|
||||
import MetagraphOptimization.ABCParticle
|
||||
import MetagraphOptimization.interaction_result
|
||||
@ -27,11 +29,11 @@ function ground_truth_graph_result(input::ABCProcessInput)
|
||||
constant = (1 / 137.0)^2
|
||||
|
||||
# calculate particle C in diagram 1
|
||||
diagram1_C = ParticleC(input.inParticles[1].momentum + input.inParticles[2].momentum)
|
||||
diagram2_C = ParticleC(input.inParticles[1].momentum + input.outParticles[2].momentum)
|
||||
diagram1_C = ParticleC(input.inA[1].momentum + input.inB[1].momentum)
|
||||
diagram2_C = ParticleC(input.inA[1].momentum + input.outB[1].momentum)
|
||||
|
||||
diagram1_Cp = ParticleC(input.outParticles[1].momentum + input.outParticles[2].momentum)
|
||||
diagram2_Cp = ParticleC(input.outParticles[1].momentum + input.inParticles[2].momentum)
|
||||
diagram1_Cp = ParticleC(input.outA[1].momentum + input.outB[1].momentum)
|
||||
diagram2_Cp = ParticleC(input.outA[1].momentum + input.inB[1].momentum)
|
||||
|
||||
check_particle_reverse_moment(diagram1_Cp.momentum, diagram1_C.momentum)
|
||||
check_particle_reverse_moment(diagram2_Cp.momentum, diagram2_C.momentum)
|
||||
@ -47,7 +49,18 @@ function ground_truth_graph_result(input::ABCProcessInput)
|
||||
return sum_kbn([diagram1_result, diagram2_result])
|
||||
end
|
||||
|
||||
machine = get_machine_info()
|
||||
machine = Machine(
|
||||
[
|
||||
MetagraphOptimization.NumaNode(
|
||||
0,
|
||||
1,
|
||||
MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),
|
||||
-1.0,
|
||||
UUIDs.uuid1(),
|
||||
),
|
||||
],
|
||||
[-1.0;;],
|
||||
)
|
||||
|
||||
process_2_2 = ABCProcessDescription(
|
||||
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
|
||||
@ -56,14 +69,12 @@ process_2_2 = ABCProcessDescription(
|
||||
|
||||
particles_2_2 = ABCProcessInput(
|
||||
process_2_2,
|
||||
ABCParticle[
|
||||
ParticleA(SFourMomentum(0.823648, 0.0, 0.0, 0.823648)),
|
||||
ParticleB(SFourMomentum(0.823648, 0.0, 0.0, -0.823648)),
|
||||
],
|
||||
ABCParticle[
|
||||
ParticleA(SFourMomentum(0.823648, -0.835061, -0.474802, 0.277915)),
|
||||
ParticleB(SFourMomentum(0.823648, 0.835061, 0.474802, -0.277915)),
|
||||
],
|
||||
SVector{1}(ParticleA(SFourMomentum(0.823648, 0.0, 0.0, 0.823648))),
|
||||
SVector{1}(ParticleB(SFourMomentum(0.823648, 0.0, 0.0, -0.823648))),
|
||||
SVector{0, ParticleC}(),
|
||||
SVector{1}(ParticleA(SFourMomentum(0.823648, -0.835061, -0.474802, 0.277915))),
|
||||
SVector{1}(ParticleB(SFourMomentum(0.823648, 0.835061, 0.474802, -0.277915))),
|
||||
SVector{0, ParticleC}(),
|
||||
)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
|
||||
|
@ -3,6 +3,7 @@ using QEDbase
|
||||
using QEDprocesses
|
||||
using StatsBase # for countmap
|
||||
using Random
|
||||
using UUIDs
|
||||
|
||||
import MetagraphOptimization.caninteract
|
||||
import MetagraphOptimization.issame
|
||||
@ -17,32 +18,32 @@ def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
|
||||
RNG = Random.default_rng()
|
||||
|
||||
testparticleTypes = [
|
||||
PhotonStateful{Incoming},
|
||||
PhotonStateful{Outgoing},
|
||||
FermionStateful{Incoming},
|
||||
FermionStateful{Outgoing},
|
||||
AntiFermionStateful{Incoming},
|
||||
AntiFermionStateful{Outgoing},
|
||||
PhotonStateful{Incoming, PolX},
|
||||
PhotonStateful{Outgoing, PolX},
|
||||
FermionStateful{Incoming, SpinUp},
|
||||
FermionStateful{Outgoing, SpinUp},
|
||||
AntiFermionStateful{Incoming, SpinUp},
|
||||
AntiFermionStateful{Outgoing, SpinUp},
|
||||
]
|
||||
|
||||
testparticleTypesPropagated = [
|
||||
PhotonStateful{Outgoing},
|
||||
PhotonStateful{Incoming},
|
||||
FermionStateful{Outgoing},
|
||||
FermionStateful{Incoming},
|
||||
AntiFermionStateful{Outgoing},
|
||||
AntiFermionStateful{Incoming},
|
||||
PhotonStateful{Outgoing, PolX},
|
||||
PhotonStateful{Incoming, PolX},
|
||||
FermionStateful{Outgoing, SpinUp},
|
||||
FermionStateful{Incoming, SpinUp},
|
||||
AntiFermionStateful{Outgoing, SpinUp},
|
||||
AntiFermionStateful{Incoming, SpinUp},
|
||||
]
|
||||
|
||||
function compton_groundtruth(input::QEDProcessInput)
|
||||
# p1k1 -> p2k2
|
||||
# formula: −(ie)^2 (u(p2) slashed(ε1) S(p2 − k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
|
||||
|
||||
p1 = input.inParticles[findfirst(x -> typeof(x) <: FermionStateful, input.inParticles)]
|
||||
p2 = input.outParticles[findfirst(x -> typeof(x) <: FermionStateful, input.outParticles)]
|
||||
p1 = input.inFerms[1]
|
||||
p2 = input.outFerms[1]
|
||||
|
||||
k1 = input.inParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.inParticles)]
|
||||
k2 = input.outParticles[findfirst(x -> typeof(x) <: PhotonStateful, input.outParticles)]
|
||||
k1 = input.inPhotons[1]
|
||||
k2 = input.outPhotons[1]
|
||||
|
||||
u_p1 = base_state(Electron(), Incoming(), p1.momentum, spin_or_pol(p1))
|
||||
u_p2 = base_state(Electron(), Outgoing(), p2.momentum, spin_or_pol(p2))
|
||||
@ -117,36 +118,36 @@ end
|
||||
|
||||
@testset "Known processes" begin
|
||||
compton_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(PhotonStateful{Incoming} => 1, FermionStateful{Incoming} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Outgoing} => 1, FermionStateful{Outgoing} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, FermionStateful{Incoming, SpinUp} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 1, FermionStateful{Outgoing, SpinUp} => 1),
|
||||
)
|
||||
|
||||
@test parse_process("ke->ke", QEDModel()) == compton_process
|
||||
|
||||
positron_compton_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(PhotonStateful{Incoming} => 1, AntiFermionStateful{Incoming} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Outgoing} => 1, AntiFermionStateful{Outgoing} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, AntiFermionStateful{Incoming, SpinUp} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 1, AntiFermionStateful{Outgoing, SpinUp} => 1),
|
||||
)
|
||||
|
||||
@test parse_process("kp->kp", QEDModel()) == positron_compton_process
|
||||
|
||||
trident_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(PhotonStateful{Incoming} => 1, FermionStateful{Incoming} => 1),
|
||||
Dict{Type, Int}(FermionStateful{Outgoing} => 2, AntiFermionStateful{Outgoing} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, FermionStateful{Incoming, SpinUp} => 1),
|
||||
Dict{Type, Int}(FermionStateful{Outgoing, SpinUp} => 2, AntiFermionStateful{Outgoing, SpinUp} => 1),
|
||||
)
|
||||
|
||||
@test parse_process("ke->eep", QEDModel()) == trident_process
|
||||
|
||||
pair_production_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(PhotonStateful{Incoming} => 2),
|
||||
Dict{Type, Int}(FermionStateful{Outgoing} => 1, AntiFermionStateful{Outgoing} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 2),
|
||||
Dict{Type, Int}(FermionStateful{Outgoing, SpinUp} => 1, AntiFermionStateful{Outgoing, SpinUp} => 1),
|
||||
)
|
||||
|
||||
@test parse_process("kk->pe", QEDModel()) == pair_production_process
|
||||
|
||||
pair_annihilation_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(FermionStateful{Incoming} => 1, AntiFermionStateful{Incoming} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Outgoing} => 2),
|
||||
Dict{Type, Int}(FermionStateful{Incoming, SpinUp} => 1, AntiFermionStateful{Incoming, SpinUp} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 2),
|
||||
)
|
||||
|
||||
@test parse_process("pe->kk", QEDModel()) == pair_annihilation_process
|
||||
@ -160,12 +161,24 @@ end
|
||||
|
||||
for i in 1:100
|
||||
input = gen_process_input(process)
|
||||
@test countmap(typeof.(input.inParticles)) == process.inParticles
|
||||
@test countmap(typeof.(input.outParticles)) == process.outParticles
|
||||
@test length(input.inFerms) == get(process.inParticles, FermionStateful{Incoming, SpinUp}, 0)
|
||||
@test length(input.inAntiferms) == get(process.inParticles, AntiFermionStateful{Incoming, SpinUp}, 0)
|
||||
@test length(input.inPhotons) == get(process.inParticles, PhotonStateful{Incoming, PolX}, 0)
|
||||
@test length(input.outFerms) == get(process.outParticles, FermionStateful{Outgoing, SpinUp}, 0)
|
||||
@test length(input.outAntiferms) == get(process.outParticles, AntiFermionStateful{Outgoing, SpinUp}, 0)
|
||||
@test length(input.outPhotons) == get(process.outParticles, PhotonStateful{Outgoing, PolX}, 0)
|
||||
|
||||
@test isapprox(
|
||||
sum(getfield.(input.inParticles, :momentum)),
|
||||
sum(getfield.(input.outParticles, :momentum));
|
||||
sum([
|
||||
getfield.(input.inFerms, :momentum)...,
|
||||
getfield.(input.inAntiferms, :momentum)...,
|
||||
getfield.(input.inPhotons, :momentum)...,
|
||||
]),
|
||||
sum([
|
||||
getfield.(input.outFerms, :momentum)...,
|
||||
getfield.(input.outAntiferms, :momentum)...,
|
||||
getfield.(input.outPhotons, :momentum)...,
|
||||
]);
|
||||
atol = sqrt(eps()),
|
||||
)
|
||||
end
|
||||
@ -179,7 +192,18 @@ end
|
||||
|
||||
model = QEDModel()
|
||||
process = parse_process("ke->ke", model)
|
||||
machine = get_machine_info()
|
||||
machine = Machine(
|
||||
[
|
||||
MetagraphOptimization.NumaNode(
|
||||
0,
|
||||
1,
|
||||
MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),
|
||||
-1.0,
|
||||
UUIDs.uuid1(),
|
||||
),
|
||||
],
|
||||
[-1.0;;],
|
||||
)
|
||||
|
||||
graph = MetagraphOptimization.DAG()
|
||||
|
||||
@ -289,3 +313,37 @@ end
|
||||
compton_function = get_compute_function(graph_generated, process, machine)
|
||||
@test isapprox(compton_function.(input), compton_groundtruth.(input))
|
||||
end
|
||||
|
||||
@testset "Equal results after optimization" for optimizer in
|
||||
[ReductionOptimizer(), RandomWalkOptimizer(MersenneTwister(0))]
|
||||
@testset "Process $proc_str" for proc_str in ["ke->ke", "kp->kp", "kk->ep", "ep->kk", "ke->kke", "ke->kkke"]
|
||||
model = QEDModel()
|
||||
process = parse_process(proc_str, model)
|
||||
machine = Machine(
|
||||
[
|
||||
MetagraphOptimization.NumaNode(
|
||||
0,
|
||||
1,
|
||||
MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),
|
||||
-1.0,
|
||||
UUIDs.uuid1(),
|
||||
),
|
||||
],
|
||||
[-1.0;;],
|
||||
)
|
||||
graph = gen_graph(process)
|
||||
|
||||
compute_function = get_compute_function(graph, process, machine)
|
||||
|
||||
if (typeof(optimizer) <: RandomWalkOptimizer)
|
||||
optimize!(optimizer, graph, 100)
|
||||
elseif (typeof(optimizer) <: ReductionOptimizer)
|
||||
optimize_to_fixpoint!(optimizer, graph)
|
||||
end
|
||||
reduced_compute_function = get_compute_function(graph, process, machine)
|
||||
|
||||
input = [gen_process_input(process) for _ in 1:100]
|
||||
|
||||
@test isapprox(compute_function.(input), reduced_compute_function.(input))
|
||||
end
|
||||
end
|
||||
|
Reference in New Issue
Block a user