From c88898a5025bbe583f3acd492f1845727b527b91 Mon Sep 17 00:00:00 2001
From: Anton Reinhard <anton.reinhard@wandelbots.com>
Date: Mon, 25 Sep 2023 18:49:44 +0200
Subject: [PATCH] WIP

---
 Project.toml                 |   3 +
 src/code_gen/main.jl         |  82 +++++++-----
 src/models/abc/compute.jl    |   1 +
 src/models/abc/create.jl     | 236 ++++++++++++++++++++++++++---------
 src/models/abc/particle.jl   |  28 +++--
 src/task/print.jl            |   3 +-
 test/Project.toml            |   1 +
 test/unit_tests_execution.jl |  33 ++---
 8 files changed, 266 insertions(+), 121 deletions(-)

diff --git a/Project.toml b/Project.toml
index b6b3812..ff33dc5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,10 +7,13 @@ version = "0.1.0"
 AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 NumaAllocators = "21436f30-1b4a-4f08-87af-e26101bb5379"
+QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
 
diff --git a/src/code_gen/main.jl b/src/code_gen/main.jl
index 3bbb5a3..16e6a6f 100644
--- a/src/code_gen/main.jl
+++ b/src/code_gen/main.jl
@@ -50,13 +50,27 @@ function gen_code(graph::DAG)
     )
 end
 
-"""
-    execute(generated_code, input::Dict{ParticleType, Vector{Particle}})
+function gen_input_assignment_code(
+    inputSymbols::Dict{String, Symbol},
+    particles::Tuple{Vector{Particle}, Vector{Particle}},
+)
+    @assert !isempty(particles[1]) "Can't have 0 input particles!"
+    @assert !isempty(particles[2]) "Can't have 0 output particles!"
+    @assert length(inputSymbols) == length(particles[1]) + length(particles[2])
 
-Execute the given `generated_code` (as returned by [`gen_code`](@ref)) on the given input particles.
-"""
-function execute(generated_code, input::Dict{ParticleType, Vector{Particle}})
-    (code, inputSymbols, outputSymbol) = generated_code
+    # TODO none of this is very pretty
+    in_out_count = Dict{ParticleType, Tuple{Int, Int}}()
+    for type in types(particles[1][1])
+        in_out_count[type] = (0, 0)
+    end
+    for p in particles[1]
+        (i, o) = in_out_count[p.type]
+        in_out_count[p.type] = (i + 1, o)
+    end
+    for p in particles[2]
+        (i, o) = in_out_count[p.type]
+        in_out_count[p.type] = (i, o + 1)
+    end
 
     assignInputs = Vector{Expr}()
     for (name, symbol) in inputSymbols
@@ -70,15 +84,42 @@ function execute(generated_code, input::Dict{ParticleType, Vector{Particle}})
         end
         index = parse(Int, name[2:end])
 
+        p = nothing
+
+        condition(x) = x.type == type
+
+        if (index > in_out_count[type][1])
+            index -= in_out_count[type][1]
+            @assert index <= in_out_count[type][2] "Too few particles of type $type in input particles for this process"
+
+            p = particles[2][findall(condition, particles[2])[index]]
+        else
+            p = particles[1][findall(condition, particles[1])[index]]
+        end
+
         push!(
             assignInputs,
             Meta.parse(
-                "$(symbol) = ParticleValue(Particle($(input[type][index]).P0, $(input[type][index]).P1, $(input[type][index]).P2, $(input[type][index]).P3, $(type)), 1.0)",
+                "$(symbol) = ParticleValue(Particle($(p.momentum), $(p.type)), 1.0)",
             ),
         )
     end
 
-    assignInputs = Expr(:block, assignInputs...)
+    return Expr(:block, assignInputs...)
+end
+
+"""
+    execute(generated_code, input::Dict{ParticleType, Vector{Particle}})
+
+Execute the given `generated_code` (as returned by [`gen_code`](@ref)) on the given input particles.
+"""
+function execute(
+    generated_code,
+    input::Tuple{Vector{Particle}, Vector{Particle}},
+)
+    (code, inputSymbols, outputSymbol) = generated_code
+
+    assignInputs = gen_input_assignment_code(inputSymbols, input)
     eval(assignInputs)
     eval(code)
 
@@ -94,33 +135,16 @@ The input particles should be sorted correctly into the dictionary to their acco
 
 See also: [`gen_particles`](@ref)
 """
-function execute(graph::DAG, input::Dict{ParticleType, Vector{Particle}})
+function execute(graph::DAG, input::Tuple{Vector{Particle}, Vector{Particle}})
     (code, inputSymbols, outputSymbol) = gen_code(graph)
 
-    assignInputs = Vector{Expr}()
-    for (name, symbol) in inputSymbols
-        type = nothing
-        if startswith(name, "A")
-            type = A
-        elseif startswith(name, "B")
-            type = B
-        else
-            type = C
-        end
-        index = parse(Int, name[2:end])
+    assignInputs = gen_input_assignment_code(inputSymbols, input)
 
-        push!(
-            assignInputs,
-            Meta.parse(
-                "$(symbol) = ParticleValue(Particle($(input[type][index]).P0, $(input[type][index]).P1, $(input[type][index]).P2, $(input[type][index]).P3, $(type)), 1.0)",
-            ),
-        )
-    end
-
-    assignInputs = Expr(:block, assignInputs...)
+    println(code)
     eval(assignInputs)
     eval(code)
 
     eval(Meta.parse("result = $outputSymbol"))
+
     return result
 end
diff --git a/src/models/abc/compute.jl b/src/models/abc/compute.jl
index 399b3f0..7a7ae92 100644
--- a/src/models/abc/compute.jl
+++ b/src/models/abc/compute.jl
@@ -184,6 +184,7 @@ function get_expression(
     expr1 = nothing
     expr2 = nothing
 
+    expr0 = Meta.parse("# fused compute task $(t.first_task), $(t.second_task)")
     expr1 = get_expression(t.first_task, t.t1_inputs, t.t1_output)
     expr2 =
         get_expression(t.second_task, [t.t2_inputs..., t.t1_output], outExpr)
diff --git a/src/models/abc/create.jl b/src/models/abc/create.jl
index 865b18c..4f053c5 100644
--- a/src/models/abc/create.jl
+++ b/src/models/abc/create.jl
@@ -1,74 +1,188 @@
+using QEDbase
+using Random
+using Roots
+using ForwardDiff
 
 """
-    Particle(rng)
+    gen_particles(in::Vector{ParticleType}, out::Vector{ParticleType})
 
-Return a randomly generated particle.
-"""
-function Particle(rng, type::ParticleType)
-
-    p1 = rand(rng, Float64)
-    p2 = rand(rng, Float64)
-    p3 = rand(rng, Float64)
-    m = mass(type)
-
-    # keep the momenta of the particles on-shell
-    p4 = sqrt(p1^2 + p2^2 + p3^2 + m^2)
-
-    return Particle(p1, p2, p3, p4, type)
-end
-
-"""
-    gen_particles(n::Int)
-
-Return a Vector of `n` randomly generated [`Particle`](@ref)s.
+Return a Vector of randomly generated [`Particle`](@ref)s. `in` is the list of particles that enter the process, `out` the list of particles that exit it. Their added momenta will be equal.
 
 Note: This does not take into account the preservation of momenta required for an actual valid process!
 """
-function gen_particles(ns::Dict{ParticleType, Int})
+function gen_particles(
+    in_particles::Vector{ParticleType},
+    out_particles::Vector{ParticleType},
+)
     particles = Dict{ParticleType, Vector{Particle}}()
     rng = MersenneTwister(0)
 
-
-    if ns == Dict((A => 2), (B => 2))
-        rho = 1.0
-
-        omega = rand(rng, Float64)
-        theta = rand(rng, Float64) * π
-        phi = rand(rng, Float64) * π
-
-        particles[A] = Vector{Particle}()
-        particles[B] = Vector{Particle}()
-
-        push!(particles[A], Particle(omega, 0, 0, omega, A))
-        push!(particles[B], Particle(omega, 0, 0, -omega, B))
-        push!(
-            particles[A],
-            Particle(
-                omega,
-                rho * cos(theta) * cos(phi),
-                rho * cos(theta) * sin(phi),
-                rho * sin(theta),
-                A,
-            ),
-        )
-        push!(
-            particles[B],
-            Particle(
-                omega,
-                -rho * cos(theta) * cos(phi),
-                -rho * cos(theta) * sin(phi),
-                -rho * sin(theta),
-                B,
-            ),
-        )
-        return particles
+    mass_sum = 0
+    input_masses = Vector{Float64}()
+    for particle in in_particles
+        mass_sum += mass(particle)
+        push!(input_masses, mass(particle))
+    end
+    output_masses = Vector{Float64}()
+    for particle in out_particles
+        mass_sum += mass(particle)
+        push!(output_masses, mass(particle))
     end
 
-    for (type, n) in ns
-        particles[type] = Vector{Particle}()
-        for i in 1:n
-            push!(particles[type], Particle(rng, type))
-        end
+    # add some extra random mass to allow for some momentum
+    mass_sum += rand(rng) * (length(in_particles) + length(out_particles))
+
+
+    input_particles = Vector{Particle}()
+    initial_momenta = generate_initial_moms(mass_sum, input_masses)
+    for (mom, type) in zip(initial_momenta, in_particles)
+        push!(input_particles, Particle(mom, type))
     end
-    return particles
+
+    output_particles = Vector{Particle}()
+    final_momenta = generate_physical_massive_moms(rng, mass_sum, output_masses)
+    for (mom, type) in zip(final_momenta, out_particles)
+        push!(output_particles, Particle(mom, type))
+    end
+
+    return (input_particles, output_particles)
+end
+
+####################
+# CODE FROM HERE BORROWED FROM SOURCE: https://codebase.helmholtz.cloud/qedsandbox/QEDphasespaces.jl/
+# use qedphasespaces directly once released
+#
+# quick and dirty implementation of the RAMBO algorithm
+#
+# reference: 
+# * https://cds.cern.ch/record/164736/files/198601282.pdf
+# * https://www.sciencedirect.com/science/article/pii/0010465586901190
+####################
+
+function generate_initial_moms(ss, masses)
+    E1 = (ss^2 + masses[1]^2 - masses[2]^2) / (2 * ss)
+    E2 = (ss^2 + masses[2]^2 - masses[1]^2) / (2 * ss)
+
+    rho1 = sqrt(E1^2 - masses[1]^2)
+    rho2 = sqrt(E2^2 - masses[2]^2)
+
+    return [SFourMomentum(E1, 0, 0, rho1), SFourMomentum(E2, 0, 0, -rho2)]
+end
+
+
+Random.rand(rng::AbstractRNG, ::Random.SamplerType{SFourMomentum}) =
+    SFourMomentum(rand(rng, 4))
+Random.rand(
+    rng::AbstractRNG,
+    ::Random.SamplerType{NTuple{N, Float64}},
+) where {N} = Tuple(rand(rng, N))
+
+
+function _transform_uni_to_mom(u1, u2, u3, u4)
+    cth = 2 * u1 - 1
+    sth = sqrt(1 - cth^2)
+    phi = 2 * pi * u2
+    q0 = -log(u3 * u4)
+    qx = q0 * sth * cos(phi)
+    qy = q0 * sth * sin(phi)
+    qz = q0 * cth
+
+    return SFourMomentum(q0, qx, qy, qz)
+end
+
+function _transform_uni_to_mom!(uni_mom, dest)
+    u1, u2, u3, u4 = Tuple(uni_mom)
+    cth = 2 * u1 - 1
+    sth = sqrt(1 - cth^2)
+    phi = 2 * pi * u2
+    q0 = -log(u3 * u4)
+    qx = q0 * sth * cos(phi)
+    qy = q0 * sth * sin(phi)
+    qz = q0 * cth
+
+    return dest = SFourMomentum(q0, qx, qy, qz)
+end
+
+_transform_uni_to_mom(u1234::Tuple) = _transform_uni_to_mom(u1234...)
+_transform_uni_to_mom(u1234::SFourMomentum) =
+    _transform_uni_to_mom(Tuple(u1234))
+
+function generate_massless_moms(rng, n::Int)
+    a = Vector{SFourMomentum}(undef, n)
+    rand!(rng, a)
+    return map(_transform_uni_to_mom, a)
+end
+
+function generate_physical_massless_moms(rng, ss, n)
+    r_moms = generate_massless_moms(rng, n)
+    Q = sum(r_moms)
+    M = sqrt(Q * Q)
+    fac = -1 / M
+    Qx = getX(Q)
+    Qy = getY(Q)
+    Qz = getZ(Q)
+    bx = fac * Qx
+    by = fac * Qy
+    bz = fac * Qz
+    gamma = getT(Q) / M
+    a = 1 / (1 + gamma)
+    x = ss / M
+
+    i = 1
+    while i <= n
+        mom = r_moms[i]
+        mom0 = getT(mom)
+        mom1 = getX(mom)
+        mom2 = getY(mom)
+        mom3 = getZ(mom)
+
+        bq = bx * mom1 + by * mom2 + bz * mom3
+
+        p0 = x * (gamma * mom0 + bq)
+        px = x * (mom1 + bx * mom0 + a * bq * bx)
+        py = x * (mom2 + by * mom0 + a * bq * by)
+        pz = x * (mom3 + bz * mom0 + a * bq * bz)
+
+        r_moms[i] = SFourMomentum(p0, px, py, pz)
+        i += 1
+    end
+    return r_moms
+end
+
+function _to_be_solved(xi, masses, p0s, ss)
+    sum = 0.0
+    for (i, E) in enumerate(p0s)
+        sum += sqrt(masses[i]^2 + xi^2 * E^2)
+    end
+    return sum - ss
+end
+
+function _build_massive_momenta(xi, masses, massless_moms)
+    vec = SFourMomentum[]
+    i = 1
+    while i <= length(massless_moms)
+        massless_mom = massless_moms[i]
+        k0 = sqrt(getT(massless_mom)^2 * xi^2 + masses[i]^2)
+
+        kx = xi * getX(massless_mom)
+        ky = xi * getY(massless_mom)
+        kz = xi * getZ(massless_mom)
+
+        push!(vec, SFourMomentum(k0, kx, ky, kz))
+
+        i += 1
+    end
+    return vec
+end
+
+first_derivative(func) = x -> ForwardDiff.derivative(func, float(x))
+
+
+function generate_physical_massive_moms(rng, ss, masses; x0 = 0.1)
+    n = length(masses)
+    massless_moms = generate_physical_massless_moms(rng, ss, n)
+    energies = getT.(massless_moms)
+    f = x -> _to_be_solved(x, masses, energies, ss)
+    xi = find_zero((f, first_derivative(f)), x0, Roots.Newton())
+    return _build_massive_momenta(xi, masses, massless_moms)
 end
diff --git a/src/models/abc/particle.jl b/src/models/abc/particle.jl
index 626474e..c166836 100644
--- a/src/models/abc/particle.jl
+++ b/src/models/abc/particle.jl
@@ -1,3 +1,5 @@
+using QEDbase
+
 """
     ParticleType
 
@@ -16,16 +18,13 @@ const PARTICLE_MASSES =
 """
     Particle
 
-A struct describing a particle of the ABC-Model. It has the 4 momentum parts P0...P3 and a [`ParticleType`](@ref).
+A struct describing a particle of the ABC-Model. It has the 4 momentum of the particle and a [`ParticleType`](@ref).
 
 `sizeof(Particle())` = 40 Byte
 """
 struct Particle
     # SFourMomentum
-    P0::Float64
-    P1::Float64
-    P2::Float64
-    P3::Float64
+    momentum::SFourMomentum
 
     type::ParticleType
 end
@@ -65,6 +64,15 @@ function remaining_type(t1::ParticleType, t2::ParticleType)
     end
 end
 
+"""
+    types(::Particle)
+
+Return a Vector of the possible [`ParticleType`](@ref)s of this [`Particle`](@ref).
+"""
+function types(::Particle)
+    return [A, B, C]
+end
+
 """
     square(p::Particle)
 
@@ -73,7 +81,7 @@ Return the square of the particle's momentum as a `Float` value.
 Takes 7 effective FLOP.
 """
 function square(p::Particle)
-    return p.P0 * p.P0 - p.P1 * p.P1 - p.P2 * p.P2 - p.P3 * p.P3
+    return getMass(p.momentum)
 end
 
 """
@@ -119,13 +127,7 @@ Calculate and return a new particle from two given interacting ones at a vertex.
 Takes 4 effective FLOP.
 """
 function preserve_momentum(p1::Particle, p2::Particle)
-    p3 = Particle(
-        p1.P0 + p2.P0,
-        p1.P1 + p2.P1,
-        p1.P2 + p2.P2,
-        p1.P3 + p2.P3,
-        remaining_type(p1.type, p2.type),
-    )
+    p3 = Particle(p1.momentum + p2.momentum, remaining_type(p1.type, p2.type))
 
     return p3
 end
diff --git a/src/task/print.jl b/src/task/print.jl
index 5909b5f..da6797d 100644
--- a/src/task/print.jl
+++ b/src/task/print.jl
@@ -4,6 +4,5 @@
 Print a string representation of the fused compute task to io.
 """
 function show(io::IO, t::FusedComputeTask)
-    (T1, T2) = get_types(t)
-    return print(io, "ComputeFuse(", T1(), ", ", T2(), ")")
+    return print(io, "ComputeFuse(", t.first_task, ", ", t.second_task, ")")
 end
diff --git a/test/Project.toml b/test/Project.toml
index 7a21f89..fbcc5de 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,3 +1,4 @@
 [deps]
+QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/unit_tests_execution.jl b/test/unit_tests_execution.jl
index 1fb1a94..869cb81 100644
--- a/test/unit_tests_execution.jl
+++ b/test/unit_tests_execution.jl
@@ -2,25 +2,26 @@ import MetagraphOptimization.A
 import MetagraphOptimization.B
 import MetagraphOptimization.ParticleType
 
+using QEDbase
+
 include("../examples/profiling_utilities.jl")
 
 @testset "Unit Tests Execution" begin
-    particles = Dict{ParticleType, Vector{Particle}}(
-        (
-            A => [
-                Particle(0.823648, 0.0, 0.0, 0.823648, A),
-                Particle(0.823648, -0.835061, -0.474802, 0.277915, A),
-            ]
-        ),
-        (
-            B => [
-                Particle(0.823648, 0.0, 0.0, -0.823648, B),
-                Particle(0.823648, 0.835061, 0.474802, -0.277915, B),
-            ]
-        ),
-    )
+    particles = Tuple{Vector{Particle}, Vector{Particle}}((
+        [
+            Particle(SFourMomentum(0.823648, 0.0, 0.0, 0.823648), A),
+            Particle(SFourMomentum(0.823648, 0.0, 0.0, -0.823648), B),
+        ],
+        [
+            Particle(
+                SFourMomentum(0.823648, -0.835061, -0.474802, 0.277915),
+                A,
+            ),
+            Particle(SFourMomentum(0.823648, 0.835061, 0.474802, -0.277915), B),
+        ],
+    ))
 
-    expected_result = 5.5320567694746876e-5
+    expected_result = 7.594784103424603e-5
 
     @testset "AB->AB no optimization" begin
         for _ in 1:10   # test in a loop because graph layout should not change the result
@@ -41,7 +42,7 @@ include("../examples/profiling_utilities.jl")
     end
 
     @testset "AB->AB after random walk" begin
-        for _ in 1:20
+        for _ in 1:50
             graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
             random_walk!(graph, 40)