Add compute functions

This commit is contained in:
Rubydragon 2024-07-16 12:42:50 +02:00
parent 17e180dd26
commit 9859c3b928

View File

@ -1,17 +1,32 @@
struct ComputeTask_BaseState <: AbstractComputeTask end # calculate the base state of an external particle struct ComputeTask_BaseState <: AbstractComputeTask end # calculate the base state of an external particle
struct ComputeTask_Propagator <: AbstractComputeTask end # calculate the propagator term of a virtual particle struct ComputeTask_Propagator <: AbstractComputeTask end # calculate the propagator term of a virtual particle
struct ComputeTask_Pair <: AbstractComputeTask end # from a pair of virtual particle currents, calculate the product struct ComputeTask_Pair <: AbstractComputeTask end # from a pair of virtual particle currents, calculate the product
struct ComputeTask_CollectPairs <: AbstractComputeTask end # for a list of virtual particle current pair products and a propagator, sum and propagate struct ComputeTask_CollectPairs <: AbstractComputeTask end # for a list of virtual particle current pair products, sum
struct ComputeTask_PropagatePairs <: AbstractComputeTask end # for the result of a CollectPairs compute task and a propagator, propagate the sum
struct ComputeTask_Triple <: AbstractComputeTask end # from a triple of virtual particle currents, calculate the diagram result struct ComputeTask_Triple <: AbstractComputeTask end # from a triple of virtual particle currents, calculate the diagram result
struct ComputeTask_CollectTriples <: AbstractComputeTask end # sum over triples results and struct ComputeTask_CollectTriples <: AbstractComputeTask end # sum over triples results and
# import compute so we don't have to repeat it all the time
import MetagraphOptimization.compute
struct BaseStateInput{PS_T<:AbstractParticleStateful,SPIN_POL_T<:AbstractSpinOrPolarization} struct BaseStateInput{PS_T<:AbstractParticleStateful,SPIN_POL_T<:AbstractSpinOrPolarization}
particle::PS_T particle::PS_T
spin_pol::SPIN_POL_T spin_pol::SPIN_POL_T
end end
function MetagraphOptimization.compute(::ComputeTask_BaseState, input::BaseStateInput{PS,SPIN_POL}) where {PS,SPIN_POL} @inline function compute(
return QEDbase.base_state(particle_species(input.particle), particle_direction(input.particle), momentum(input.particle), input.spin_pol) ::ComputeTask_BaseState, input::BaseStateInput{PS,SPIN_POL}
) where {PS,SPIN_POL}
return Propagated( # "propagated" because it goes directly into the next pair
input.particle,
QEDbase.base_state(
particle_species(input.particle),
particle_direction(input.particle),
momentum(input.particle),
input.spin_pol,
),
# bispinor, adjointbispinor, or lorentzvector
)
end end
struct PropagatorInput{VP_T<:VirtualParticle,PSP_T<:AbstractPhaseSpacePoint} struct PropagatorInput{VP_T<:VirtualParticle,PSP_T<:AbstractPhaseSpacePoint}
@ -19,7 +34,9 @@ struct PropagatorInput{VP_T<:VirtualParticle,PSP_T<:AbstractPhaseSpacePoint}
psp::PSP_T psp::PSP_T
end end
function MetagraphOptimization.compute(::ComputeTask_Propagator, input::PropagatorInput{VP_T,PSP_T}) where {VP_T,PSP_T} @inline function compute(
::ComputeTask_Propagator, input::PropagatorInput{VP_T,PSP_T}
) where {VP_T,PSP_T}
vp_mom = zero(typeof(momentum(input.psp, Incoming(), 1))) vp_mom = zero(typeof(momentum(input.psp, Incoming(), 1)))
for i in eachindex(in_contributions(input.vp)) for i in eachindex(in_contributions(input.vp))
if in_contributions(input.vp)[i] if in_contributions(input.vp)[i]
@ -33,6 +50,87 @@ function MetagraphOptimization.compute(::ComputeTask_Propagator, input::Propagat
end end
vp_species = particle_species(input.vp) vp_species = particle_species(input.vp)
vp_mass = mass(vp_species)
return QEDbase.propagator(vp_species, vp_mom) return QEDbase.propagator(vp_species, vp_mom)
# diracmatrix or scalar number
end end
struct Unpropagated{PARTICLE_T<:AbstractParticleType,VALUE_T}
particle::PARTICLE_T
value::VALUE_T
end
struct Propagated{PARTICLE_T<:AbstractParticleType,VALUE_T}
particle::PARTICLE_T
value::VALUE_T
end
# maybe add the γ matrix term here too?
@inline function compute(
::ComputeTask_Pair, electron::Propagated{Electron,V1}, positron::Propagated{Positron,V2}
) where {V1,V2}
return Unpropagated(Photon(), positron.value * electron.value) # fermion - antifermion -> photon
end
@inline function compute(
::ComputeTask_Pair, positron::Propagated{Positron,V1}, electron::Propagated{Electron,V2}
) where {V1,V2}
return Unpropagated(Photon(), positron.value * electron.value) # antifermion - fermion -> photon
end
@inline function compute(
::ComputeTask_Pair, photon::Propagated{Photon,V1}, fermion::Propagated{F,V2}
) where {F<:FermionLike,V1,V2}
return Unpropagated(invert(fermion.particle), fermion.value * photon.value) # (anti-)fermion - photon -> (anti-)fermion
end
@inline function compute(
::ComputeTask_Pair, fermion::Propagated{F,V2}, photon::Propagated{Photon,V1}
) where {F<:FermionLike,V1,V2}
return Unpropagated(invert(fermion.particle), fermion.value * photon.value) # photon - (anti-)fermion -> (anti-)fermion
end
@inline function compute(
::ComputeTask_PropagatePairs, left::PROP_V, right::Unpropagated{P,VAL}
) where {PROP_V,P<:AbstractParticleType,VAL}
return Propagated(right.particle, right.value * left.value)
end
@inline function compute(
::ComputeTask_PropagatePairs, left::Unpropagated{P,VAL}, right::PROP_V
) where {PROP_V,P<:AbstractParticleType,VAL}
return Propagated(left.particle, left.value * right.value)
end
@inline function compute(
::ComputeTask_Triple,
photon::Propagated{Photon,V1},
electron::Propagated{Electron,V2},
positron::Propagated{Positron,V3},
) where {V1,V2,V3}
return positron.value * photon.value * electron.input
end
@inline function compute(
c::ComputeTask_Triple,
photon::Propagated{Photon,V1},
positron::Propagated{Positron,V2},
electron::Propagated{Electron,V3},
) where {V1,V2,V3}
return compute(c, photon, electron, positron)
end
@inline function compute(
c::ComputeTask_Triple,
f1::Propagated{F1,V1},
f2::Propagated{F2,V2},
photon::Propagated{Photon,V3},
) where {V1,V2,V3,F1<:FermionLike,F2<:FermionLike}
return compute(c, photon, f1, f2)
end
@inline function compute(
c::ComputeTask_Triple,
f1::Propagated{F1,V1},
photon::Propagated{Photon,V2},
f2::Propagated{F2,V3},
) where {V1,V2,V3,F1<:FermionLike,F2<:FermionLike}
return compute(c, photon, f1, f2)
end
# this compiles in a reasonable amount of time for up to about 1e4 parameters
# use a summation algorithm with more accuracy and/or parallelization
@inline compute(::ComputeTask_CollectPairs, args::Vararg{N,T}) where {N,T} = sum(args)
@inline compute(::ComputeTask_CollectTriples, args::Vararg{N,T}) where {N,T} = sum(args)