From 2b7c02c223cfc29b593f1b926af02a80b9a59118 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Thu, 20 Jun 2024 15:58:03 +0200 Subject: [PATCH] Add flat matrix --- src/FeynmanDiagramGenerator.jl | 4 ++++ src/diagrams/diagrams.jl | 18 ++++++++++-------- src/flat_matrix.jl | 23 +++++++++++++++++++++++ 3 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 src/flat_matrix.jl diff --git a/src/FeynmanDiagramGenerator.jl b/src/FeynmanDiagramGenerator.jl index a460378..7596639 100644 --- a/src/FeynmanDiagramGenerator.jl +++ b/src/FeynmanDiagramGenerator.jl @@ -7,6 +7,8 @@ include("QEDprocesses_patch.jl") import Base.== +export FlatMatrix + export GenericQEDProcess, isphysical export AbstractTreeLevelFeynmanDiagram, FeynmanVertex, FeynmanDiagram @@ -19,6 +21,8 @@ export plane_trees, labelled_plane_trees, feynman_diagrams export can_interact export QED +include("flat_matrix.jl") + include("trees/labelled_plane_trees.jl") include("trees/trees.jl") include("trees/iterator.jl") diff --git a/src/diagrams/diagrams.jl b/src/diagrams/diagrams.jl index 159ce44..a4d653c 100644 --- a/src/diagrams/diagrams.jl +++ b/src/diagrams/diagrams.jl @@ -43,9 +43,8 @@ end # Feynman Diagram, tree-level, QED # -struct FeynmanDiagram{N,E,U,T,M} <: AbstractTreeLevelFeynmanDiagram where {N,E,U,T,M} - # TODO: flatten into one list - diagram_structure::NTuple{N,Vector{Int}} +struct FeynmanDiagram{N,E,U,T,M,FM} <: AbstractTreeLevelFeynmanDiagram where {N,E,U,T,M,FM<:FlatMatrix} + diagram_structure::FM electron_permutation::NTuple{E,Int} muon_permutation::NTuple{U,Int} @@ -66,13 +65,14 @@ struct FeynmanDiagram{N,E,U,T,M} <: AbstractTreeLevelFeynmanDiagram where {N,E,U @assert T == length(tauon_perm) N = E + U + T - return new{N,E,U,T,M}(NTuple{N,Vector{Int}}(structure), NTuple{E,Int}(elec_perm), NTuple{U,Int}(muon_perm), NTuple{T,Int}(tauon_perm)) + fm = FlatMatrix(structure) + return new{N,E,U,T,M,typeof(fm)}(fm, NTuple{E,Int}(elec_perm), NTuple{U,Int}(muon_perm), NTuple{T,Int}(tauon_perm)) end end -function virtual_particles(diagram::FeynmanDiagram) +function virtual_particles(diagram::FeynmanDiagram{N,E,U,T,M,FM}) - return NTuple{N,Tuple{QEDbase.AbstractParticleType,BitArray}}() + return NTuple{N,Tuple{QEDbase.AbstractParticleType,BitArray,BitArray}}() end function vertices(::AbstractTreeLevelFeynmanDiagram) @@ -175,8 +175,9 @@ function Base.length(it::FeynmanDiagramIterator{E,U,T,M}) where {E,U,T,M} end function Base.iterate(iter::FeynmanDiagramIterator) + f = FeynmanDiagram(iter.photon_structure, iter.e_perms[iter.e_index], iter.u_perms[iter.u_index], iter.t_perms[iter.t_index], iter.e, iter.u, iter.t, iter.m) return ( - FeynmanDiagram(iter.photon_structure, iter.e_perms[iter.e_index], iter.u_perms[iter.u_index], iter.t_perms[iter.t_index], iter.e, iter.u, iter.t, iter.m), + f, nothing ) end @@ -203,8 +204,9 @@ function Base.iterate(iter::FeynmanDiagramIterator, ::Nothing) (iter.photon_structure, _) = photon_iter_result end + f = FeynmanDiagram(iter.photon_structure, iter.e_perms[iter.e_index], iter.u_perms[iter.u_index], iter.t_perms[iter.t_index], iter.e, iter.u, iter.t, iter.m) return ( - FeynmanDiagram(iter.photon_structure, iter.e_perms[iter.e_index], iter.u_perms[iter.u_index], iter.t_perms[iter.t_index], iter.e, iter.u, iter.t, iter.m), + f, nothing ) end diff --git a/src/flat_matrix.jl b/src/flat_matrix.jl new file mode 100644 index 0000000..5fafb10 --- /dev/null +++ b/src/flat_matrix.jl @@ -0,0 +1,23 @@ +# array of arrays but with a given number of arrays (M) and given total length (N) + +struct FlatMatrix{T,N,M} + values::NTuple{N,T} + indices::NTuple{M,Int} + + function FlatMatrix(v::Vector{Vector{T}}) where {T} + M = length(v) + N = sum(length.(v)) + + values = NTuple{N,T}(vcat(v...)) + indices = ntuple(i -> sum(length.(v[1:i-1])), M) + + return new{Int,N,M}(values, indices) + end +end + +function Base.getindex(m::FlatMatrix{T,N,M}, x::Int, y::Int) where {T,N,M} + x <= M || throw(InvalidInputError("invalid indices ($x, $y) for flat matrix $m")) + (x > 0 && y > 0) || throw(InvalidInputError("invalid indices ($x, $y) for flat matrix $m")) + x == M ? m.indices[x] + y <= N : m.indices[x] + y <= m.indices[x+1] || throw(InvalidInputError("invalid indices ($x, $y) for flat matrix $m")) + return m.values[m.indices[x]+y] +end