Fix congruent ph script
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 1m35s
MetagraphOptimization_CI / docs (push) Failing after 1m39s

This commit is contained in:
Anton Reinhard 2024-07-04 17:03:18 +02:00
parent 55501c15c8
commit 92f534f6bf

View File

@ -1,8 +1,10 @@
using MetagraphOptimization
using QEDbase
using QEDcore
using QEDprocesses
using Random
using UUIDs
using CSV
RNG = Random.MersenneTwister(123)
@ -21,103 +23,79 @@ function mock_machine()
)
end
function congruent_input(processDescription::QEDProcessDescription, omega::Number)
function congruent_input_momenta(processDescription::GenericQEDProcess, omega::Number)
# generate an input sample for given e + nk -> e' + k' process, where the nk are equal
massSum = 0
inputMasses = Vector{Float64}()
for (particle, n) in processDescription.inParticles
for _ in 1:n
massSum += mass(particle)
push!(inputMasses, mass(particle))
end
for particle in incoming_particles(processDescription)
massSum += mass(particle)
push!(inputMasses, mass(particle))
end
outputMasses = Vector{Float64}()
for (particle, n) in processDescription.outParticles
for _ in 1:n
massSum += mass(particle)
push!(outputMasses, mass(particle))
end
for particle in outgoing_particles(processDescription)
massSum += mass(particle)
push!(outputMasses, mass(particle))
end
initialMomenta = [
i == 1 ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
i in 1:length(inputMasses)
]
initial_momenta =
[i == 1 ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for i in 1:length(inputMasses)]
# add some extra random mass to allow for some momentum
ss = sqrt(sum(initialMomenta) * sum(initialMomenta))
ss = sqrt(sum(initial_momenta) * sum(initial_momenta))
final_momenta = MetagraphOptimization.generate_physical_massive_moms(RNG, ss, outputMasses)
result = Vector{QEDProcessInput}()
sizehint!(result, 16)
spin_pol_combinations = Iterators.product(
[SpinUp, SpinDown], [SpinUp, SpinDown], [PolX, PolY], [PolX, PolY]
)
for (in_spin, out_spin, in_pol, out_pol) in spin_pol_combinations
# get the electron first, then the n photons
particles = Vector{QEDParticle}()
for (particle, n) in processDescription.inParticles
if particle <: FermionStateful
mom = initialMomenta[1]
push!(particles, particle(mom, in_spin()))
elseif particle <: PhotonStateful
for i in 1:n
mom = initialMomenta[i + 1]
push!(particles, particle(mom, in_pol()))
end
else
@assert false
end
end
final_momenta = MetagraphOptimization.generate_physical_massive_moms(
RNG, ss, outputMasses
)
index = 1
for (particle, n) in processDescription.outParticles
for _ in 1:n
if particle <: FermionStateful
push!(particles, particle(final_momenta[index], out_spin()))
elseif particle <: PhotonStateful
push!(particles, particle(final_momenta[index], out_pol()))
end
index += 1
end
end
inFerms = MetagraphOptimization._svector_from_type(
processDescription, FermionStateful{Incoming,in_spin}, particles
)
outFerms = MetagraphOptimization._svector_from_type(
processDescription, FermionStateful{Outgoing,out_spin}, particles
)
inAntiferms = MetagraphOptimization._svector_from_type(
processDescription, AntiFermionStateful{Incoming,in_spin}, particles
)
outAntiferms = MetagraphOptimization._svector_from_type(
processDescription, AntiFermionStateful{Outgoing,out_spin}, particles
)
inPhotons = MetagraphOptimization._svector_from_type(
processDescription, PhotonStateful{Incoming,in_pol}, particles
)
outPhotons = MetagraphOptimization._svector_from_type(
processDescription, PhotonStateful{Outgoing,out_pol}, particles
)
processInput = QEDProcessInput(
processDescription,
inFerms,
outFerms,
inAntiferms,
outAntiferms,
inPhotons,
outPhotons,
)
push!(result, processInput)
end
return result
return (tuple(initial_momenta...), tuple(final_momenta...))
end
function build_psp(processDescription::GenericQEDProcess, momenta)
return PhaseSpacePoint(
processDescription,
PerturbativeQED(),
PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()),
momenta[1],
momenta[2],
)
end
n = 1000000
photons = 4
omega = 2e-3
# n is the number of incoming photons
# omega is the number
println("Generating $n inputs for $photons photons, omega=$omega...")
# temp process to generate momenta
temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
input_momenta = [congruent_input_momenta(temp_process, omega) for _ in 1:n]
results = [0.0im for _ in 1:n]
for (in_pol, in_spin, out_pol, out_spin) in
Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
print("Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ")
process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
inputs = build_psp.(Ref(process), input_momenta)
print("Preparing graph... ")
# prepare function
graph = gen_graph(process)
optimize_to_fixpoint!(ReductionOptimizer(), graph)
print("Preparing function... ")
func = get_compute_function(graph, process, mock_machine())
print("Calculating... ")
for i in 1:n
results[i] += func(inputs[i])^2
end
println("Done.")
end
println("Writing results")
open("$(photons)_congruent_photons_omega_$(omega).csv", "w") do f
println(f, "out_photon_momentum;out_electron_momentum;result")
for (momentum, result) in Iterators.zip(input_momenta, results)
println(f, "$(momentum[2][1]);$(momentum[2][2]);$(result)")
end
end