diff --git a/src/metagraph_impl/compute.jl b/src/metagraph_impl/compute.jl index 5f362c9..5cb0590 100644 --- a/src/metagraph_impl/compute.jl +++ b/src/metagraph_impl/compute.jl @@ -1,17 +1,32 @@ struct ComputeTask_BaseState <: AbstractComputeTask end # calculate the base state of an external particle struct ComputeTask_Propagator <: AbstractComputeTask end # calculate the propagator term of a virtual particle struct ComputeTask_Pair <: AbstractComputeTask end # from a pair of virtual particle currents, calculate the product -struct ComputeTask_CollectPairs <: AbstractComputeTask end # for a list of virtual particle current pair products and a propagator, sum and propagate +struct ComputeTask_CollectPairs <: AbstractComputeTask end # for a list of virtual particle current pair products, sum +struct ComputeTask_PropagatePairs <: AbstractComputeTask end # for the result of a CollectPairs compute task and a propagator, propagate the sum struct ComputeTask_Triple <: AbstractComputeTask end # from a triple of virtual particle currents, calculate the diagram result struct ComputeTask_CollectTriples <: AbstractComputeTask end # sum over triples results and +# import compute so we don't have to repeat it all the time +import MetagraphOptimization.compute + struct BaseStateInput{PS_T<:AbstractParticleStateful,SPIN_POL_T<:AbstractSpinOrPolarization} particle::PS_T spin_pol::SPIN_POL_T end -function MetagraphOptimization.compute(::ComputeTask_BaseState, input::BaseStateInput{PS,SPIN_POL}) where {PS,SPIN_POL} - return QEDbase.base_state(particle_species(input.particle), particle_direction(input.particle), momentum(input.particle), input.spin_pol) +@inline function compute( + ::ComputeTask_BaseState, input::BaseStateInput{PS,SPIN_POL} +) where {PS,SPIN_POL} + return Propagated( # "propagated" because it goes directly into the next pair + input.particle, + QEDbase.base_state( + particle_species(input.particle), + particle_direction(input.particle), + momentum(input.particle), + input.spin_pol, + ), + # bispinor, adjointbispinor, or lorentzvector + ) end struct PropagatorInput{VP_T<:VirtualParticle,PSP_T<:AbstractPhaseSpacePoint} @@ -19,7 +34,9 @@ struct PropagatorInput{VP_T<:VirtualParticle,PSP_T<:AbstractPhaseSpacePoint} psp::PSP_T end -function MetagraphOptimization.compute(::ComputeTask_Propagator, input::PropagatorInput{VP_T,PSP_T}) where {VP_T,PSP_T} +@inline function compute( + ::ComputeTask_Propagator, input::PropagatorInput{VP_T,PSP_T} +) where {VP_T,PSP_T} vp_mom = zero(typeof(momentum(input.psp, Incoming(), 1))) for i in eachindex(in_contributions(input.vp)) if in_contributions(input.vp)[i] @@ -33,6 +50,87 @@ function MetagraphOptimization.compute(::ComputeTask_Propagator, input::Propagat end vp_species = particle_species(input.vp) - vp_mass = mass(vp_species) return QEDbase.propagator(vp_species, vp_mom) + # diracmatrix or scalar number end + +struct Unpropagated{PARTICLE_T<:AbstractParticleType,VALUE_T} + particle::PARTICLE_T + value::VALUE_T +end + +struct Propagated{PARTICLE_T<:AbstractParticleType,VALUE_T} + particle::PARTICLE_T + value::VALUE_T +end + +# maybe add the γ matrix term here too? +@inline function compute( + ::ComputeTask_Pair, electron::Propagated{Electron,V1}, positron::Propagated{Positron,V2} +) where {V1,V2} + return Unpropagated(Photon(), positron.value * electron.value) # fermion - antifermion -> photon +end +@inline function compute( + ::ComputeTask_Pair, positron::Propagated{Positron,V1}, electron::Propagated{Electron,V2} +) where {V1,V2} + return Unpropagated(Photon(), positron.value * electron.value) # antifermion - fermion -> photon +end +@inline function compute( + ::ComputeTask_Pair, photon::Propagated{Photon,V1}, fermion::Propagated{F,V2} +) where {F<:FermionLike,V1,V2} + return Unpropagated(invert(fermion.particle), fermion.value * photon.value) # (anti-)fermion - photon -> (anti-)fermion +end +@inline function compute( + ::ComputeTask_Pair, fermion::Propagated{F,V2}, photon::Propagated{Photon,V1} +) where {F<:FermionLike,V1,V2} + return Unpropagated(invert(fermion.particle), fermion.value * photon.value) # photon - (anti-)fermion -> (anti-)fermion +end + +@inline function compute( + ::ComputeTask_PropagatePairs, left::PROP_V, right::Unpropagated{P,VAL} +) where {PROP_V,P<:AbstractParticleType,VAL} + return Propagated(right.particle, right.value * left.value) +end +@inline function compute( + ::ComputeTask_PropagatePairs, left::Unpropagated{P,VAL}, right::PROP_V +) where {PROP_V,P<:AbstractParticleType,VAL} + return Propagated(left.particle, left.value * right.value) +end + +@inline function compute( + ::ComputeTask_Triple, + photon::Propagated{Photon,V1}, + electron::Propagated{Electron,V2}, + positron::Propagated{Positron,V3}, +) where {V1,V2,V3} + return positron.value * photon.value * electron.input +end +@inline function compute( + c::ComputeTask_Triple, + photon::Propagated{Photon,V1}, + positron::Propagated{Positron,V2}, + electron::Propagated{Electron,V3}, +) where {V1,V2,V3} + return compute(c, photon, electron, positron) +end +@inline function compute( + c::ComputeTask_Triple, + f1::Propagated{F1,V1}, + f2::Propagated{F2,V2}, + photon::Propagated{Photon,V3}, +) where {V1,V2,V3,F1<:FermionLike,F2<:FermionLike} + return compute(c, photon, f1, f2) +end +@inline function compute( + c::ComputeTask_Triple, + f1::Propagated{F1,V1}, + photon::Propagated{Photon,V2}, + f2::Propagated{F2,V3}, +) where {V1,V2,V3,F1<:FermionLike,F2<:FermionLike} + return compute(c, photon, f1, f2) +end + +# this compiles in a reasonable amount of time for up to about 1e4 parameters +# use a summation algorithm with more accuracy and/or parallelization +@inline compute(::ComputeTask_CollectPairs, args::Vararg{N,T}) where {N,T} = sum(args) +@inline compute(::ComputeTask_CollectTriples, args::Vararg{N,T}) where {N,T} = sum(args)