From e8bc26b0c054041e67d9d28269a1db7715825f60 Mon Sep 17 00:00:00 2001
From: AntonReinhard <anton.reinhard@proton.me>
Date: Tue, 9 Jul 2024 16:47:47 +0200
Subject: [PATCH] Add virtual_particles implementation

---
 Project.toml                   |   2 +
 src/FeynmanDiagramGenerator.jl |   8 +-
 src/QEDprocesses_patch.jl      |  23 +++++
 src/diagrams/diagrams.jl       | 184 +++++++++++++++++++++++++++++----
 src/flat_matrix.jl             |   5 +
 5 files changed, 201 insertions(+), 21 deletions(-)

diff --git a/Project.toml b/Project.toml
index 1c94728..f4aee7b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,8 @@ version = "0.1.0"
 
 [deps]
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
 QEDcore = "35dc0263-cb5f-4c33-a114-1d7f54ab753e"
 QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
+Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
diff --git a/src/FeynmanDiagramGenerator.jl b/src/FeynmanDiagramGenerator.jl
index 42affeb..1d76795 100644
--- a/src/FeynmanDiagramGenerator.jl
+++ b/src/FeynmanDiagramGenerator.jl
@@ -1,8 +1,9 @@
 module FeynmanDiagramGenerator
 
-using QEDbase
-using QEDcore
-using QEDprocesses
+using Reexport
+@reexport using QEDbase
+@reexport using QEDcore
+@reexport using QEDprocesses
 
 include("QEDprocesses_patch.jl")
 
@@ -14,6 +15,7 @@ export GenericQEDProcess, isphysical
 
 export AbstractTreeLevelFeynmanDiagram, FeynmanVertex, FeynmanDiagram
 export external_particles, virtual_particles, process, vertices
+export VirtualParticle
 
 export Forest
 
diff --git a/src/QEDprocesses_patch.jl b/src/QEDprocesses_patch.jl
index 7163f47..f6e1378 100644
--- a/src/QEDprocesses_patch.jl
+++ b/src/QEDprocesses_patch.jl
@@ -5,3 +5,26 @@
 ) where {DIR<:QEDbase.ParticleDirection,PT<:QEDbase.AbstractParticleType}
     return count(x -> x isa PT, particles(proc_def, dir))
 end
+
+
+"""
+    number_particles(proc_def::AbstractProcessDefinition, dir::ParticleDirection, species::AbstractParticleType)
+
+Return the number of particles of the given direction and species in the given process definition.
+"""
+@inline function QEDbase.number_particles(
+    proc_def::AbstractProcessDefinition, dir::DIR, species::PT
+) where {DIR<:ParticleDirection,PT<:AbstractParticleType}
+    return count(x -> x isa PT, particles(proc_def, dir))
+end
+
+"""
+    number_particles(proc_def::AbstractProcessDefinition, particle::AbstractParticleStateful)
+
+Return the number of particles of the given particle's direction and species in the given process definition.
+"""
+@inline function QEDbase.number_particles(
+    proc_def::AbstractProcessDefinition, ps::AbstractParticleStateful
+)
+    return number_particles(proc_def, particle_direction(ps), particle_species(ps))
+end
diff --git a/src/diagrams/diagrams.jl b/src/diagrams/diagrams.jl
index 211a6fc..8b124fa 100644
--- a/src/diagrams/diagrams.jl
+++ b/src/diagrams/diagrams.jl
@@ -1,3 +1,4 @@
+using DataStructures
 using Combinatorics
 using QEDprocesses
 using QEDbase
@@ -80,22 +81,115 @@ end
 import Base: +
 
 # "addition" of the bool tuples
-# realistically, there should never be "colliding" 1s. if there are there is probably an error and this should be asserted
 function +(a::Tuple{NTuple{I,Bool},NTuple{O,Bool}}, b::Tuple{NTuple{I,Bool},NTuple{O,Bool}}) where {I,O}
+    # realistically, there should never be "colliding" 1s. if there are there is probably an error and this should be asserted
+    #= for (i, j) in zip(a[1], b[1]) @assert !(i && j) end
+    for (i, j) in zip(a[2], b[2]) @assert !(i && j) end =#
+
     return (ntuple(i -> a[1][i] || b[1][i], I), ntuple(i -> a[2][i] || b[2][i], O))
 end
 
+# normalize the representation
+function normalize(virtual_particle::VirtualParticle{P,S,IN_T,OUT_T}) where {P,S,IN_T,OUT_T}
+    I = length(virtual_particle.in_particle_contributions)
+    O = length(virtual_particle.out_particle_contributions)
+    data = (virtual_particle.in_particle_contributions, virtual_particle.out_particle_contributions)
+    s = sum(data[1]) + sum(data[2])
+    if s > (I + O) / 2
+        return VirtualParticle(virtual_particle.proc, virtual_particle.species, ntuple(x -> !data[1][x], I), ntuple(x -> !data[2][x], O))
+    elseif s == (I + O) / 2 && data[1][1] == false
+        return VirtualParticle(virtual_particle.proc, virtual_particle.species, ntuple(x -> !data[1][x], I), ntuple(x -> !data[2][x], O))
+    else
+        return virtual_particle
+    end
+end
+
+function _momentum_contribution(proc::AbstractProcessDefinition, dir::ParticleDirection, species::AbstractParticleType, index::Int)
+    # get index of n-th "dir species" particle in proc
+    particles_seen = 0
+    c = 0
+    for p in particles(proc, dir)
+        c += 1
+        if p == species
+            particles_seen += 1
+        end
+        if particles_seen == index
+            return (ntuple(x -> is_incoming(dir) && x == c, number_incoming_particles(proc)), ntuple(x -> is_outgoing(dir) && x == c, number_outgoing_particles(proc)))
+        end
+    end
+end
+
+function _momentum_contribution(proc::AbstractProcessDefinition, diagram::FeynmanDiagram{N,E,U,T,M,FM}, n::Int) where {N,E,U,T,M,FM}
+    if (n > 0 && n <= E)
+        # left electron n
+        electron_n = n
+        if electron_n > number_particles(proc, Incoming(), Electron())
+            # outgoing positron
+            return _momentum_contribution(proc, Outgoing(), Positron(), electron_n - number_particles(proc, Incoming(), Electron()))
+        else
+            # incoming electron
+            return _momentum_contribution(proc, Incoming(), Electron(), electron_n)
+        end
+    elseif (n > E && n <= E + U)
+        # left muon n - E
+        muon_n = n - E
+        throw(InvalidInputError("unimplemented for muons"))
+    elseif (n > E + U && n <= E + U + T)
+        # left tauon n - E - U
+        tauon_n = n - E - U
+        throw(InvalidInputError("unimplemented for tauons"))
+    elseif (n > N && n <= N + M)
+        # photon
+        photon_n = n - N
+        if photon_n > number_particles(proc, Incoming(), Photon())
+            # outgoing photon
+            return _momentum_contribution(proc, Outgoing(), Photon(), photon_n - number_particles(proc, Incoming(), Photon()))
+        else
+            # incoming photon
+            return _momentum_contribution(proc, Incoming(), Photon(), photon_n)
+        end
+    elseif (n > N + M && n <= N + M + E)
+        # right electron
+        electron_n = n - N - M
+        if electron_n > number_particles(proc, Outgoing(), Electron())
+            # incoming positron
+            return _momentum_contribution(proc, Incoming(), Positron(), electron_n - number_particles(proc, Outgoing(), Electron()))
+        else
+            # outgoing electron
+            return _momentum_contribution(proc, Outgoing(), Electron(), electron_n)
+        end
+    elseif (n > N + M + E && n <= N + M + E + U)
+        # right muon
+        muon_n = n - N - M - E
+        throw(InvalidInputError("unimplemented for muons"))
+    elseif (n > N + M + E + U && n <= N + M + E + U + T)
+        # right tauon
+        tauon_n = n - N - M - E - U
+        throw(InvalidInputError("unimplemented for tauons"))
+    else
+        # error
+        throw(InvalidInputError("invalid index given for _momentum_contribution()"))
+    end
+end
+
 function virtual_particles(proc::QEDbase.AbstractProcessDefinition, diagram::FeynmanDiagram{N,E,U,T,M,FM}) where {N,E,U,T,M,FM}
-    I = number_incoming_particles(proc)
-    O = number_outgoing_particles(proc)
+    fermion_lines = PriorityQueue{Int64,Int64}()
 
-    # map of all known particles' momentum composition
-    known_particles = Dict{Int64,Tuple{NTuple{I,Bool},NTuple{O,Bool}}}()
+    # count number of internal photons in each fermion line and make a priority queue for fermion line => number of internal photons
+    for i in 1:N
+        count = 0
+        for p in 1:length(diagram.diagram_structure, i)
+            if diagram.diagram_structure[i, p] <= N
+                # internal photon
+                count += 1
+            end
+        end
+        enqueue!(fermion_lines, i => count)
+    end
 
+    result = Vector()
 
-    # 1: insert all the external ones (won't be returned), they all have exactly one 1 in their composition
-    # TODO
-
+    internal_photon_contributions = Dict()
 
     # 2: Loop: 
     # while there are incomplete fermion lines:
@@ -103,16 +197,66 @@ function virtual_particles(proc::QEDbase.AbstractProcessDefinition, diagram::Fey
     #   walk the fermion line, assign each virtual particle the momentum composition of the previous (or initial fermion if start) "+" the connected particle
     #   when/if the unknown particle is encountered, start walking from the other side
     #   when they meet at the unknown particle, assign the unknown particle Photon and left side - right side momentum contribution
-    # TODO
+    while !isempty(fermion_lines)
+        current_line = dequeue!(fermion_lines)
 
-    # 3: minimalize the contributions, i.e., if the number of contributing particles > half of all particles, invert both vectors
-    #    if it's exactly half of all particles, think of some consistent way to break the symmetry, e.g. swap if the first particle is not contributing
-    # TODO
+        local unknown_photon_momentum = nothing
+        # walk line from the *left* (everything looks like an electron/muon/tauon)
+        species = current_line <= E ? Electron() : throw(InvalidInputError("muon/taun not implemented yet"))
+        cumulative_mom = _momentum_contribution(proc, diagram, current_line)
 
-    # 4: convert the known_particles Dict to an NTuple and remove the external particles (those with only 1 contributing momentum)
-    # TODO
+        for i in 1:length(diagram.diagram_structure, current_line)
+            binding_particle = diagram.diagram_structure[current_line, i]
+            if (binding_particle <= N) # binding_particle is an internal photon
+                if haskey(internal_photon_contributions, binding_particle)   # if the binding particle is known
+                    cumulative_mom += internal_photon_contributions[binding_particle]
+                else # if the binding particle is unknown
+                    # save so far momentum and break, add the right side momentum later
+                    unknown_photon_momentum = cumulative_mom
+                    break
+                end
+            else # binding_particle is an external photon
+                cumulative_mom += _momentum_contribution(proc, diagram, binding_particle)
+            end
+            push!(result, VirtualParticle(proc, species, cumulative_mom...))
+        end
 
-    return NTuple{?,VirtualParticle}()
+        if isnothing(unknown_photon_momentum)
+            # case where we're done (only one fermion line or last fermion line)
+            # fermion_lines always has to be empty at this point, otherwise the tree wouldn't be connected
+            @assert isempty(fermion_lines)
+            continue
+        end
+
+        # walk line from the *right* (everything looks like a positron/antimuon/antitauon)
+        species = current_line <= E ? Positron() : throw(InvalidInputError("muon/taun not implemented yet"))
+        # find right side of the line
+        right_line = diagram.electron_permutation[current_line]
+
+        cumulative_mom = _momentum_contribution(proc, diagram, right_line)
+        for i in length(diagram.diagram_structure, current_line):-1:1   # iterate from the right
+            binding_particle = diagram.diagram_structure[current_line, i]
+            if (binding_particle <= N) # binding_particle is an internal photon
+                if haskey(internal_photon_contributions, binding_particle)   # if the binding particle is known, proceed as above
+                    cumulative_mom += internal_photon_contributions[binding_particle]
+                else # if the binding particle is unknown
+                    # we have arrived at the "middle" of the line
+                    # this line will be the unknown particle for the other lines
+                    internal_photon_contributions[current_line] = cumulative_mom + unknown_photon_momentum
+                    # now we know that the fermion line that binding_particle binds to on the other end has one fewer unknown internal photons
+                    fermion_lines[binding_particle] -= 1
+                    # add the internal photon virtual particle
+                    push!(result, VirtualParticle(proc, Photon(), (cumulative_mom + unknown_photon_momentum)...))
+                    break
+                end
+            else # binding_particle is an external photon
+                cumulative_mom += _momentum_contribution(proc, diagram, binding_particle)
+            end
+            push!(result, VirtualParticle(proc, species, cumulative_mom...))
+        end
+    end
+
+    return ntuple(x -> normalize(result[x]), length(result) - 1)
 end
 
 function vertices(diagram::FeynmanDiagram{N,E,U,T,M,FM}) where {N,E,U,T,M,FM}
@@ -271,11 +415,15 @@ function feynman_diagrams(in_particles::Tuple, out_particles::Tuple)
     # TODO: do this the same way as for e when muons and tauons are a part of QED.jl
     u = 0
     t = 0
+    n = e + u + t
 
+    # the numbers in the feynman diagram go as follows:
+    # left electrons -> left muons -> left tauons -> left photons -> right photons -> right electrons -> right muons -> right tauons
+    # a "left" fermion is simply an incoming fermion or outgoing antifermion of the type, while a "left" photon is an incoming photon, and the reverse for the right ones
     f_iter = _feynman_structures(e + u + t, m)
-    e_perms = collect(permutations(Int[1:e;]))
-    u_perms = collect(permutations(Int[e+1:e+u;]))
-    t_perms = collect(permutations(Int[e+u+1:e+u+t;]))
+    e_perms = collect(permutations(Int[n+m+1:n+m+e;]))
+    u_perms = collect(permutations(Int[n+m+e+1:n+m+e+u;]))
+    t_perms = collect(permutations(Int[n+m+e+u+1:n+m+e+u+t;]))
     first_photon_structure, _ = iterate(f_iter)
 
     return FeynmanDiagramIterator(Val(e), e_perms, 1, Val(u), u_perms, 1, Val(t), t_perms, 1, Val(m), f_iter, first_photon_structure)
diff --git a/src/flat_matrix.jl b/src/flat_matrix.jl
index 5fafb10..7065ff7 100644
--- a/src/flat_matrix.jl
+++ b/src/flat_matrix.jl
@@ -21,3 +21,8 @@ function Base.getindex(m::FlatMatrix{T,N,M}, x::Int, y::Int) where {T,N,M}
     x == M ? m.indices[x] + y <= N : m.indices[x] + y <= m.indices[x+1] || throw(InvalidInputError("invalid indices ($x, $y) for flat matrix $m"))
     return m.values[m.indices[x]+y]
 end
+
+function Base.length(m::FlatMatrix{T,N,M}, x::Int) where {T,N,M}
+    (x <= M && x > 0) || throw(InvalidInputError("invalid index $x for flat matrix $m"))
+    return x == M ? N - m.indices[x] : m.indices[x+1] - m.indices[x]
+end