heterogeneity (#27)

Prepare things to work with heterogeneity, make things work on GPU

Reviewed-on: Rubydragon/MetagraphOptimization.jl#27
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
2023-12-18 14:31:52 +01:00
committed by Anton Reinhard
parent c90346e948
commit 92e0eeaaef
42 changed files with 1631 additions and 238 deletions

View File

@ -62,21 +62,12 @@ function gen_input_assignment_code(
assignInputs = Vector{Expr}()
for (name, symbols) in inputSymbols
(type, index) = type_index_from_name(model(processDescription), name)
p = nothing
if (index > get(in_particles(processDescription), type, 0))
index -= get(in_particles(processDescription), type, 0)
@assert index <= out_particles(processDescription)[type] "Too few particles of type $type in input particles for this process"
p = "filter(x -> typeof(x) <: $type, out_particles($(processInputSymbol)))[$(index)]"
else
p = "filter(x -> typeof(x) <: $type, in_particles($(processInputSymbol)))[$(index)]"
end
p = "get_particle($(processInputSymbol), $(type), $(index))"
for symbol in symbols
device = entry_device(machine)
evalExpr = eval(gen_access_expr(device, symbol))
push!(assignInputs, Meta.parse("$(evalExpr)::ParticleValue{$type} = ParticleValue($p, one(ComplexF64))"))
push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue{$type, ComplexF64}($p, one(ComplexF64))"))
end
end
@ -111,10 +102,12 @@ end
Execute the code of the given `graph` on the given input particles.
This is essentially shorthand for
```julia
compute_graph = get_compute_function(graph, process)
result = compute_graph(particles)
```
```julia
compute_graph = get_compute_function(graph, process)
result = compute_graph(particles)
```
If an exception occurs during the execution of the generated code, it will be printed for investigation.
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
@ -135,6 +128,8 @@ function execute(graph::DAG, process::AbstractProcessDescription, machine::Machi
result = 0
try
result = @eval $func($input)
#functionStr = string(expr)
#println("Function:\n$functionStr")
catch e
println("Error while evaluating: $e")

View File

@ -75,3 +75,7 @@ function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operatio
ce::Float64 = s * compute_effort(task(operation.input))
return (data = d, computeEffort = ce, computeIntensity = ce / d)::CDCost
end
function String(::GlobalMetricEstimator)
return "global_metric"
end

View File

@ -1,4 +1,5 @@
using AccurateArithmetic
using StaticArrays
"""
compute(::ComputeTaskABC_P, data::ABCParticleValue)
@ -75,14 +76,14 @@ function compute(::ComputeTaskABC_S1, data::ABCParticleValue{P})::ABCParticleVal
end
"""
compute(::ComputeTaskABC_Sum, data::Vector{Float64})
compute(::ComputeTaskABC_Sum, data::StaticVector)
Compute a sum over the vector. Use an algorithm that accounts for accumulated errors in long sums with potentially large differences in magnitude of the summands.
Linearly many FLOP with growing data.
"""
function compute(::ComputeTaskABC_Sum, data::Vector{Float64})::Float64
return sum_kbn(data)
function compute(::ComputeTaskABC_Sum, data::StaticVector)::Float64
return sum(data)
end
"""
@ -159,5 +160,7 @@ function get_expression(::ComputeTaskABC_Sum, device::AbstractDevice, inExprs::V
in = eval.(inExprs)
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskABC_Sum(), [$(unroll_symbol_vector(in))])")
return Meta.parse(
"$out = compute(ComputeTaskABC_Sum(), SVector{$(length(inExprs)), Float64}($(unroll_symbol_vector(in))))",
)
end

View File

@ -5,6 +5,20 @@ using ForwardDiff
ComputeTaskABC_Sum() = ComputeTaskABC_Sum(0)
function _svector_from_type_in(processDescription::ABCProcessDescription, type, particles)
if haskey(in_particles(processDescription), type)
return SVector{in_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
end
return SVector{0, type}()
end
function _svector_from_type_out(processDescription::ABCProcessDescription, type, particles)
if haskey(out_particles(processDescription), type)
return SVector{out_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
end
return SVector{0, type}()
end
"""
gen_process_input(processDescription::ABCProcessDescription)
@ -58,7 +72,15 @@ function gen_process_input(processDescription::ABCProcessDescription)
end
end
processInput = ABCProcessInput(processDescription, inputParticles, outputParticles)
inA = _svector_from_type_in(processDescription, ParticleA, inputParticles)
inB = _svector_from_type_in(processDescription, ParticleB, inputParticles)
inC = _svector_from_type_in(processDescription, ParticleC, inputParticles)
outA = _svector_from_type_out(processDescription, ParticleA, outputParticles)
outB = _svector_from_type_out(processDescription, ParticleB, outputParticles)
outC = _svector_from_type_out(processDescription, ParticleC, outputParticles)
processInput = ABCProcessInput(processDescription, inA, inB, inC, outA, outB, outC)
return return processInput
end

View File

@ -1,3 +1,5 @@
using StaticArrays
import QEDbase.mass
"""
@ -60,27 +62,30 @@ Input for a ABC Process. Contains the [`ABCProcessDescription`](@ref) of the pro
See also: [`gen_process_input`](@ref)
"""
struct ABCProcessInput <: AbstractProcessInput
struct ABCProcessInput{N1, N2, N3, N4, N5, N6} <: AbstractProcessInput
process::ABCProcessDescription
inParticles::Vector{ABCParticle}
outParticles::Vector{ABCParticle}
inA::SVector{N1, ParticleA}
inB::SVector{N2, ParticleB}
inC::SVector{N3, ParticleC}
outA::SVector{N4, ParticleA}
outB::SVector{N5, ParticleB}
outC::SVector{N6, ParticleC}
end
ABCParticleValue{ParticleType <: ABCParticle} = ParticleValue{ParticleType, ComplexF64}
"""
PARTICLE_MASSES
A constant dictionary containing the masses of the different [`ABCParticle`](@ref)s.
"""
const PARTICLE_MASSES = Dict{Type, Float64}(ParticleA => 1.0, ParticleB => 1.0, ParticleC => 0.0)
"""
mass(t::Type{T}) where {T <: ABCParticle}
Return the mass (at rest) of the given particle type.
"""
mass(t::Type{T}) where {T <: ABCParticle} = PARTICLE_MASSES[t]
mass(::ParticleA) = 1.0
mass(::ParticleB) = 1.0
mass(::ParticleC) = 0.0
mass(::Type{ParticleA}) = 1.0
mass(::Type{ParticleB}) = 1.0
mass(::Type{ParticleC}) = 0.0
"""
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ABCParticle, T2 <: ABCParticle}
@ -126,7 +131,7 @@ Return the factor of the inner edge with the given (virtual) particle.
Takes 10 effective FLOP. (3 here + 7 in square(p))
"""
function ABC_inner_edge(p::ABCParticle)
return 1.0 / (square(p) - mass(typeof(p)) * mass(typeof(p)))
return 1.0 / (square(p) - mass(p)^2)
end
"""
@ -166,6 +171,10 @@ function ABC_conserve_momentum(p1::ABCParticle, p2::ABCParticle)
return p3
end
function copy(process::ABCProcessDescription)
return ABCProcessDescription(copy(process.inParticles), copy(process.outParticles))
end
model(::ABCProcessDescription) = ABCModel()
model(::ABCProcessInput) = ABCModel()
@ -195,14 +204,29 @@ function in_particles(process::ABCProcessDescription)
return process.inParticles
end
function in_particles(input::ABCProcessInput)
return input.inParticles
end
function out_particles(process::ABCProcessDescription)
return process.outParticles
end
function out_particles(input::ABCProcessInput)
return input.outParticles
function get_particle(input::ABCProcessInput, t::Type{Particle}, n::Int)::Particle where {Particle}
if (t <: ParticleA)
if (n > length(input.inA))
return input.outA[n - length(input.inA)]
else
return input.inA[n]
end
elseif (t <: ParticleB)
if (n > length(input.inB))
return input.outB[n - length(input.inB)]
else
return input.inB[n]
end
elseif (t <: ParticleC)
if (n > length(input.inC))
return input.outC[n - length(input.inC)]
else
return input.inC[n]
end
end
@assert false "Invalid type given"
end

View File

@ -36,15 +36,26 @@ Pretty print an [`ABCProcessInput`](@ref) (with newlines).
"""
function show(io::IO, processInput::ABCProcessInput)
println(io, "Input for $(processInput.process):")
println(io, " $(length(processInput.inParticles)) Incoming particles:")
for particle in processInput.inParticles
println(io, " $particle")
println(io, "Incoming particles:")
if !isempty(processInput.inA)
println(io, " $(processInput.inA)")
end
println(io, " $(length(processInput.outParticles)) Outgoing Particles:")
for particle in processInput.outParticles
println(io, " $particle")
if !isempty(processInput.inB)
println(io, " $(processInput.inB)")
end
if !isempty(processInput.inC)
println(io, " $(processInput.inC)")
end
println(io, "Outgoing particles:")
if !isempty(processInput.outA)
println(io, " $(processInput.outA)")
end
if !isempty(processInput.outB)
println(io, " $(processInput.outB)")
end
if !isempty(processInput.outC)
println(io, " $(processInput.outC)")
end
return nothing
end
"""

View File

@ -80,6 +80,14 @@ Returns a `<: Vector{AbstractParticle}` object with the values of all outgoing p
"""
function out_particles end
"""
get_particle(::AbstractProcessInput, t::Type, n::Int)
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
Returns the `n`th particle of type `t`.
"""
function get_particle end
"""
parse_process(::AbstractString, ::AbstractPhysicsModel)

View File

@ -1,12 +1,13 @@
using StaticArrays
"""
compute(::ComputeTaskQED_P, data::QEDParticleValue)
Return the particle as is and initialize the Value.
"""
function compute(::ComputeTaskQED_P, data::QEDParticleValue{P})::QEDParticleValue{P} where {P <: QEDParticle}
function compute(::ComputeTaskQED_P, data::QEDParticleValue{P}) where {P <: QEDParticle}
# TODO do we actually need this for anything?
return QEDParticleValue{P}(data.p, one(DiracMatrix))
return ParticleValue{P, DiracMatrix}(data.p, one(DiracMatrix))
end
"""
@ -15,7 +16,8 @@ end
Compute an outer edge. Return the particle value with the same particle and the value multiplied by an outer_edge factor.
"""
function compute(::ComputeTaskQED_U, data::PV) where {P <: QEDParticle, PV <: QEDParticleValue{P}}
state = base_state(particle(data.p), direction(data.p), momentum(data.p), spin_or_pol(data.p))
part::P = data.p
state = base_state(particle(part), direction(part), momentum(part), spin_or_pol(part))
return ParticleValue{P, typeof(state)}(
data.p,
state, # will return a SLorentzVector{ComplexF64}, BiSpinor or AdjointBiSpinor
@ -34,7 +36,6 @@ function compute(
) where {P1 <: QEDParticle, P2 <: QEDParticle, PV1 <: QEDParticleValue{P1}, PV2 <: QEDParticleValue{P2}}
p3 = QED_conserve_momentum(data1.p, data2.p)
P3 = interaction_result(P1, P2)
state = QED_vertex()
if (typeof(data1.v) <: AdjointBiSpinor)
state = data1.v * state
@ -47,7 +48,7 @@ function compute(
state = state * data2.v
end
dataOut = ParticleValue{P3, typeof(state)}(P3(p3), state)
dataOut = ParticleValue{P3, typeof(state)}(P3(momentum(p3)), state)
return dataOut
end
@ -64,13 +65,10 @@ function compute(
::ComputeTaskQED_S2,
data1::ParticleValue{P1},
data2::ParticleValue{P2},
)::ComplexF64 where {
P1 <: Union{AntiFermionStateful, FermionStateful},
P2 <: Union{AntiFermionStateful, FermionStateful},
}
@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
) where {P1 <: Union{AntiFermionStateful, FermionStateful}, P2 <: Union{AntiFermionStateful, FermionStateful}}
#@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
inner = QED_inner_edge(propagation_result(P1)(data1.p))
inner = QED_inner_edge(propagation_result(P1)(momentum(data1.p)))
# inner edge is just a "scalar", data1 and data2 are bispinor/adjointbispinnor, need to keep correct order
if typeof(data1.v) <: BiSpinor
@ -80,12 +78,11 @@ function compute(
end
end
# TODO: S2 when the particles are photons?
function compute(
::ComputeTaskQED_S2,
data1::ParticleValue{P1},
data2::ParticleValue{P2},
)::ComplexF64 where {P1 <: PhotonStateful, P2 <: PhotonStateful}
) where {P1 <: PhotonStateful, P2 <: PhotonStateful}
# TODO: assert that data1 and data2 are opposites
inner = QED_inner_edge(data1.p)
# inner edge is just a scalar, data1 and data2 are photon states that are just Complex numbers here
@ -97,9 +94,9 @@ end
Compute inner edge (1 input particle, 1 output particle).
"""
function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P})::QEDParticleValue where {P <: QEDParticle}
function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P}) where {P <: QEDParticle}
newP = propagation_result(P)
new_p = newP(data.p)
new_p = newP(momentum(data.p))
# inner edge is just a scalar, can multiply from either side
if typeof(data.v) <: BiSpinor
return ParticleValue(new_p, QED_inner_edge(new_p) * data.v)
@ -109,13 +106,13 @@ function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P})::QEDParticleVal
end
"""
compute(::ComputeTaskQED_Sum, data::Vector{ComplexF64})
compute(::ComputeTaskQED_Sum, data::StaticVector)
Compute a sum over the vector. Use an algorithm that accounts for accumulated errors in long sums with potentially large differences in magnitude of the summands.
Linearly many FLOP with growing data.
"""
function compute(::ComputeTaskQED_Sum, data::Vector{ComplexF64})::ComplexF64
function compute(::ComputeTaskQED_Sum, data::StaticVector)::ComplexF64
# TODO: want to use sum_kbn here but it doesn't seem to support ComplexF64, do it element-wise?
return sum(data)
end
@ -194,5 +191,7 @@ function get_expression(::ComputeTaskQED_Sum, device::AbstractDevice, inExprs::V
in = eval.(inExprs)
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskQED_Sum(), [$(unroll_symbol_vector(in))])")
return Meta.parse(
"$out = compute(ComputeTaskQED_Sum(), SVector{$(length(inExprs)), ComplexF64}($(unroll_symbol_vector(in))))",
)
end

View File

@ -1,6 +1,16 @@
ComputeTaskQED_Sum() = ComputeTaskQED_Sum(0)
function _svector_from_type(processDescription::QEDProcessDescription, type, particles)
if haskey(in_particles(processDescription), type)
return SVector{in_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
end
if haskey(out_particles(processDescription), type)
return SVector{out_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
end
return SVector{0, type}()
end
"""
gen_process_input(processDescription::QEDProcessDescription)
@ -29,30 +39,37 @@ function gen_process_input(processDescription::QEDProcessDescription)
massSum += rand(rng[threadid()]) * (length(inputMasses) + length(outputMasses))
inputParticles = Vector{QEDParticle}()
particles = Vector{QEDParticle}()
initialMomenta = generate_initial_moms(massSum, inputMasses)
index = 1
for (particle, n) in processDescription.inParticles
for _ in 1:n
mom = initialMomenta[index]
push!(inputParticles, particle(mom))
push!(particles, particle(mom))
index += 1
end
end
outputParticles = Vector{QEDParticle}()
final_momenta = generate_physical_massive_moms(rng[threadid()], massSum, outputMasses)
index = 1
for (particle, n) in processDescription.outParticles
for _ in 1:n
push!(outputParticles, particle(final_momenta[index]))
push!(particles, particle(final_momenta[index]))
index += 1
end
end
processInput = QEDProcessInput(processDescription, inputParticles, outputParticles)
inFerms = _svector_from_type(processDescription, FermionStateful{Incoming, SpinUp}, particles)
outFerms = _svector_from_type(processDescription, FermionStateful{Outgoing, SpinUp}, particles)
inAntiferms = _svector_from_type(processDescription, AntiFermionStateful{Incoming, SpinUp}, particles)
outAntiferms = _svector_from_type(processDescription, AntiFermionStateful{Outgoing, SpinUp}, particles)
inPhotons = _svector_from_type(processDescription, PhotonStateful{Incoming, PolX}, particles)
outPhotons = _svector_from_type(processDescription, PhotonStateful{Outgoing, PolX}, particles)
return return processInput
processInput =
QEDProcessInput(processDescription, inFerms, outFerms, inAntiferms, outAntiferms, inPhotons, outPhotons)
return processInput
end
"""

View File

@ -82,7 +82,7 @@ end
function particle_after_tie(p::FeynmanParticle, t::FeynmanTie)
if p == t.in1 || p == t.in2
return FeynmanParticle(FermionStateful{Incoming}, -1) # placeholder particle and id for tied particles
return FeynmanParticle(FermionStateful{Incoming, SpinUp}, -1) # placeholder particle and id for tied particles
end
return p
end

View File

@ -1,4 +1,5 @@
using QEDprocesses
using StaticArrays
import QEDbase.mass
# TODO check
@ -34,19 +35,6 @@ struct QEDProcessDescription <: AbstractProcessDescription
outParticles::Dict{Type{<:QEDParticle{Outgoing}}, Int}
end
"""
QEDProcessInput <: AbstractProcessInput
Input for a QED Process. Contains the [`QEDProcessDescription`](@ref) of the process it is an input for, and the values of the in and out particles.
See also: [`gen_process_input`](@ref)
"""
struct QEDProcessInput <: AbstractProcessInput
process::QEDProcessDescription
inParticles::Vector{QEDParticle}
outParticles::Vector{QEDParticle}
end
QEDParticleValue{ParticleType <: QEDParticle} = Union{
ParticleValue{ParticleType, BiSpinor},
ParticleValue{ParticleType, AdjointBiSpinor},
@ -60,51 +48,44 @@ QEDParticleValue{ParticleType <: QEDParticle} = Union{
A photon of the [`QEDModel`](@ref) with its state.
"""
struct PhotonStateful{Direction <: ParticleDirection} <: QEDParticle{Direction}
struct PhotonStateful{Direction <: ParticleDirection, Pol <: AbstractDefinitePolarization} <: QEDParticle{Direction}
momentum::SFourMomentum
# this will maybe change to the full polarization vector? or do i need both
polarization::AbstractDefinitePolarization
end
PhotonStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
PhotonStateful{Direction}(mom, PolX()) # TODO: make allpol possible
PhotonStateful{Direction, PolX}(mom)
PhotonStateful{Dir1}(ph::PhotonStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} =
PhotonStateful{Dir1}(ph.momentum, ph.polarization)
PhotonStateful{Dir, Pol}(ph::PhotonStateful) where {Dir, Pol} = PhotonStateful{Dir, Pol}(ph.momentum)
"""
FermionStateful <: QEDParticle
A fermion of the [`QEDModel`](@ref) with its state.
"""
struct FermionStateful{Direction <: ParticleDirection} <: QEDParticle{Direction}
struct FermionStateful{Direction <: ParticleDirection, Spin <: AbstractDefiniteSpin} <: QEDParticle{Direction}
momentum::SFourMomentum
spin::AbstractDefiniteSpin
# TODO: mass for electron/muon/tauon representation?
end
FermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
FermionStateful{Direction}(mom, SpinUp()) # TODO: make allspin possible
FermionStateful{Direction, SpinUp}(mom)
FermionStateful{Dir1}(f::FermionStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} =
FermionStateful{Dir1}(f.momentum, f.spin)
FermionStateful{Dir, Spin}(f::FermionStateful) where {Dir, Spin} = FermionStateful{Dir, Spin}(f.momentum)
"""
AntiFermionStateful <: QEDParticle
An anti-fermion of the [`QEDModel`](@ref) with its state.
"""
struct AntiFermionStateful{Direction <: ParticleDirection} <: QEDParticle{Direction}
struct AntiFermionStateful{Direction <: ParticleDirection, Spin <: AbstractDefiniteSpin} <: QEDParticle{Direction}
momentum::SFourMomentum
spin::AbstractDefiniteSpin
# TODO: mass for electron/muon/tauon representation?
end
AntiFermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
AntiFermionStateful{Direction}(mom, SpinUp()) # TODO: make allspin possible
AntiFermionStateful{Direction, SpinUp}(mom)
AntiFermionStateful{Dir1}(f::AntiFermionStateful{Dir2}) where {Dir1 <: ParticleDirection, Dir2 <: ParticleDirection} =
AntiFermionStateful{Dir1}(f.momentum, f.spin)
AntiFermionStateful{Dir, Spin}(f::AntiFermionStateful) where {Dir, Spin} = AntiFermionStateful{Dir, Spin}(f.momentum)
"""
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
@ -115,19 +96,33 @@ function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle
@assert false "Invalid interaction between particles of types $t1 and $t2"
end
interaction_result(::Type{FermionStateful{Incoming}}, ::Type{FermionStateful{Outgoing}}) = PhotonStateful{Incoming}
interaction_result(::Type{FermionStateful{Incoming}}, ::Type{AntiFermionStateful{Incoming}}) = PhotonStateful{Incoming}
interaction_result(::Type{FermionStateful{Incoming}}, ::Type{<:PhotonStateful}) = FermionStateful{Outgoing}
interaction_result(
::Type{FermionStateful{Incoming, Spin1}},
::Type{FermionStateful{Outgoing, Spin2}},
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
interaction_result(
::Type{FermionStateful{Incoming, Spin1}},
::Type{AntiFermionStateful{Incoming, Spin2}},
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
interaction_result(::Type{FermionStateful{Incoming, Spin1}}, ::Type{<:PhotonStateful}) where {Spin1} =
FermionStateful{Outgoing, SpinUp}
interaction_result(::Type{FermionStateful{Outgoing}}, ::Type{FermionStateful{Incoming}}) = PhotonStateful{Incoming}
interaction_result(::Type{FermionStateful{Outgoing}}, ::Type{AntiFermionStateful{Outgoing}}) = PhotonStateful{Incoming}
interaction_result(::Type{FermionStateful{Outgoing}}, ::Type{<:PhotonStateful}) = FermionStateful{Incoming}
interaction_result(
::Type{FermionStateful{Outgoing, Spin1}},
::Type{FermionStateful{Incoming, Spin2}},
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
interaction_result(
::Type{FermionStateful{Outgoing, Spin1}},
::Type{AntiFermionStateful{Outgoing, Spin2}},
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
interaction_result(::Type{FermionStateful{Outgoing, Spin1}}, ::Type{<:PhotonStateful}) where {Spin1} =
FermionStateful{Incoming, SpinUp}
# antifermion mirror
interaction_result(::Type{AntiFermionStateful{Incoming}}, t2::Type{<:QEDParticle}) =
interaction_result(FermionStateful{Outgoing}, t2)
interaction_result(::Type{AntiFermionStateful{Outgoing}}, t2::Type{<:QEDParticle}) =
interaction_result(FermionStateful{Incoming}, t2)
interaction_result(::Type{AntiFermionStateful{Incoming, Spin}}, t2::Type{<:QEDParticle}) where {Spin} =
interaction_result(FermionStateful{Outgoing, Spin}, t2)
interaction_result(::Type{AntiFermionStateful{Outgoing, Spin}}, t2::Type{<:QEDParticle}) where {Spin} =
interaction_result(FermionStateful{Incoming, Spin}, t2)
# photon commutativity
interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:QEDParticle}) = interaction_result(t2, t1)
@ -142,12 +137,18 @@ end
Return the type of the inverted direction. E.g.
"""
propagation_result(::Type{FermionStateful{Incoming}}) = FermionStateful{Outgoing}
propagation_result(::Type{FermionStateful{Outgoing}}) = FermionStateful{Incoming}
propagation_result(::Type{AntiFermionStateful{Incoming}}) = AntiFermionStateful{Outgoing}
propagation_result(::Type{AntiFermionStateful{Outgoing}}) = AntiFermionStateful{Incoming}
propagation_result(::Type{PhotonStateful{Incoming}}) = PhotonStateful{Outgoing}
propagation_result(::Type{PhotonStateful{Outgoing}}) = PhotonStateful{Incoming}
propagation_result(::Type{FermionStateful{Incoming, Spin}}) where {Spin <: AbstractDefiniteSpin} =
FermionStateful{Outgoing, Spin}
propagation_result(::Type{FermionStateful{Outgoing, Spin}}) where {Spin <: AbstractDefiniteSpin} =
FermionStateful{Incoming, Spin}
propagation_result(::Type{AntiFermionStateful{Incoming, Spin}}) where {Spin <: AbstractDefiniteSpin} =
AntiFermionStateful{Outgoing, Spin}
propagation_result(::Type{AntiFermionStateful{Outgoing, Spin}}) where {Spin <: AbstractDefiniteSpin} =
AntiFermionStateful{Incoming, Spin}
propagation_result(::Type{PhotonStateful{Incoming, Pol}}) where {Pol <: AbstractDefinitePolarization} =
PhotonStateful{Outgoing, Pol}
propagation_result(::Type{PhotonStateful{Outgoing, Pol}}) where {Pol <: AbstractDefinitePolarization} =
PhotonStateful{Incoming, Pol}
"""
types(::QEDModel)
@ -156,12 +157,12 @@ Return a Vector of the possible types of particle in the [`QEDModel`](@ref).
"""
function types(::QEDModel)
return [
PhotonStateful{Incoming},
PhotonStateful{Outgoing},
FermionStateful{Incoming},
FermionStateful{Outgoing},
AntiFermionStateful{Incoming},
AntiFermionStateful{Outgoing},
PhotonStateful{Incoming, PolX},
PhotonStateful{Outgoing, PolX},
FermionStateful{Incoming, SpinUp},
FermionStateful{Outgoing, SpinUp},
AntiFermionStateful{Incoming, SpinUp},
AntiFermionStateful{Outgoing, SpinUp},
]
end
@ -190,17 +191,23 @@ end
@inline momentum(p::FermionStateful)::SFourMomentum = p.momentum
@inline momentum(p::AntiFermionStateful)::SFourMomentum = p.momentum
@inline spin_or_pol(p::PhotonStateful)::AbstractPolarization = p.polarization
@inline spin_or_pol(p::FermionStateful)::AbstractSpin = p.spin
@inline spin_or_pol(p::AntiFermionStateful)::AbstractSpin = p.spin
@inline spin_or_pol(p::PhotonStateful{Dir, Pol}) where {Dir, Pol <: AbstractDefinitePolarization} = Pol()
@inline spin_or_pol(p::FermionStateful{Dir, Spin}) where {Dir, Spin <: AbstractDefiniteSpin} = Spin()
@inline spin_or_pol(p::AntiFermionStateful{Dir, Spin}) where {Dir, Spin <: AbstractDefiniteSpin} = Spin()
@inline direction(::PhotonStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::FermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::AntiFermionStateful{Dir}) where {Dir <: ParticleDirection} = Dir()
@inline direction(
::Type{P},
) where {P <: Union{FermionStateful{Incoming}, AntiFermionStateful{Incoming}, PhotonStateful{Incoming}}} = Incoming()
@inline direction(
::Type{P},
) where {P <: Union{FermionStateful{Outgoing}, AntiFermionStateful{Outgoing}, PhotonStateful{Outgoing}}} = Outgoing()
@inline direction(::Type{PhotonStateful{Dir}}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::Type{FermionStateful{Dir}}) where {Dir <: ParticleDirection} = Dir()
@inline direction(::Type{AntiFermionStateful{Dir}}) where {Dir <: ParticleDirection} = Dir()
@inline direction(
::P,
) where {P <: Union{FermionStateful{Incoming}, AntiFermionStateful{Incoming}, PhotonStateful{Incoming}}} = Incoming()
@inline direction(
::P,
) where {P <: Union{FermionStateful{Outgoing}, AntiFermionStateful{Outgoing}, PhotonStateful{Outgoing}}} = Outgoing()
@inline isincoming(::QEDParticle{Incoming}) = true
@inline isincoming(::QEDParticle{Outgoing}) = false
@ -216,12 +223,12 @@ end
@inline mass(::Type{<:AntiFermionStateful}) = 1.0
@inline mass(::Type{<:PhotonStateful}) = 0.0
@inline invert_momentum(p::FermionStateful{Dir}) where {Dir <: ParticleDirection} =
FermionStateful{Dir}(-p.momentum, p.spin)
@inline invert_momentum(p::AntiFermionStateful{Dir}) where {Dir <: ParticleDirection} =
AntiFermionStateful{Dir}(-p.momentum, p.spin)
@inline invert_momentum(k::PhotonStateful{Dir}) where {Dir <: ParticleDirection} =
PhotonStateful{Dir}(-k.momentum, k.polarization)
@inline invert_momentum(p::FermionStateful{Dir, Spin}) where {Dir, Spin} =
FermionStateful{Dir, Spin}(-p.momentum, p.spin)
@inline invert_momentum(p::AntiFermionStateful{Dir, Spin}) where {Dir, Spin} =
AntiFermionStateful{Dir, Spin}(-p.momentum, p.spin)
@inline invert_momentum(k::PhotonStateful{Dir, Spin}) where {Dir, Spin} =
PhotonStateful{Dir, Spin}(-k.momentum, k.polarization)
"""
@ -240,10 +247,10 @@ function caninteract(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
end
for (P1, P2) in [(T1, T2), (T2, T1)]
if (P1 == FermionStateful{Incoming} && P2 == AntiFermionStateful{Outgoing})
if (P1 <: FermionStateful{Incoming} && P2 <: AntiFermionStateful{Outgoing})
return false
end
if (P1 == FermionStateful{Outgoing} && P2 == AntiFermionStateful{Incoming})
if (P1 <: FermionStateful{Outgoing} && P2 <: AntiFermionStateful{Incoming})
return false
end
end
@ -253,17 +260,17 @@ end
function type_index_from_name(::QEDModel, name::String)
if startswith(name, "ki")
return (PhotonStateful{Incoming}, parse(Int, name[3:end]))
return (PhotonStateful{Incoming, PolX}, parse(Int, name[3:end]))
elseif startswith(name, "ko")
return (PhotonStateful{Outgoing}, parse(Int, name[3:end]))
return (PhotonStateful{Outgoing, PolX}, parse(Int, name[3:end]))
elseif startswith(name, "ei")
return (FermionStateful{Incoming}, parse(Int, name[3:end]))
return (FermionStateful{Incoming, SpinUp}, parse(Int, name[3:end]))
elseif startswith(name, "eo")
return (FermionStateful{Outgoing}, parse(Int, name[3:end]))
return (FermionStateful{Outgoing, SpinUp}, parse(Int, name[3:end]))
elseif startswith(name, "pi")
return (AntiFermionStateful{Incoming}, parse(Int, name[3:end]))
return (AntiFermionStateful{Incoming, SpinUp}, parse(Int, name[3:end]))
elseif startswith(name, "po")
return (AntiFermionStateful{Outgoing}, parse(Int, name[3:end]))
return (AntiFermionStateful{Outgoing, SpinUp}, parse(Int, name[3:end]))
else
throw("Invalid name for a particle in the QED model")
end
@ -291,8 +298,7 @@ Return the factor of a vertex in a QED feynman diagram.
end
@inline function QED_inner_edge(p::QEDParticle)
pos_mom = p.momentum
return propagator(particle(p), pos_mom)
return propagator(particle(p), p.momentum)
end
"""
@ -300,24 +306,49 @@ end
Calculate and return a new particle from two given interacting ones at a vertex.
"""
function QED_conserve_momentum(p1::QEDParticle, p2::QEDParticle)
#println("Conserving momentum of \n$(direction(p1)) $(p1)\n and \n$(direction(p2)) $(p2)")
T3 = interaction_result(typeof(p1), typeof(p2))
# TODO: probably also need to do something about the spin/pol
function QED_conserve_momentum(
p1::P1,
p2::P2,
) where {
Dir1 <: ParticleDirection,
Dir2 <: ParticleDirection,
SpinPol1 <: AbstractSpinOrPolarization,
SpinPol2 <: AbstractSpinOrPolarization,
P1 <: Union{FermionStateful{Dir1, SpinPol1}, AntiFermionStateful{Dir1, SpinPol1}, PhotonStateful{Dir1, SpinPol1}},
P2 <: Union{FermionStateful{Dir2, SpinPol2}, AntiFermionStateful{Dir2, SpinPol2}, PhotonStateful{Dir2, SpinPol2}},
}
P3 = interaction_result(P1, P2)
p1_mom = p1.momentum
if (typeof(direction(p1)) <: Outgoing)
if (Dir1 <: Outgoing)
p1_mom *= -1
end
p2_mom = p2.momentum
if (typeof(direction(p2)) <: Outgoing)
if (Dir2 <: Outgoing)
p2_mom *= -1
end
p3_mom = p1_mom + p2_mom
if (typeof(direction(T3)) <: Incoming)
return T3(-p3_mom)
if (typeof(direction(P3)) <: Incoming)
return P3(-p3_mom)
end
return T3(p3_mom)
return P3(p3_mom)
end
"""
QEDProcessInput <: AbstractProcessInput
Input for a QED Process. Contains the [`QEDProcessDescription`](@ref) of the process it is an input for, and the values of the in and out particles.
See also: [`gen_process_input`](@ref)
"""
struct QEDProcessInput{N1, N2, N3, N4, N5, N6} <: AbstractProcessInput
process::QEDProcessDescription
inFerms::SVector{N1, FermionStateful{Incoming, SpinUp}}
outFerms::SVector{N2, FermionStateful{Outgoing, SpinUp}}
inAntiferms::SVector{N3, AntiFermionStateful{Incoming, SpinUp}}
outAntiferms::SVector{N4, AntiFermionStateful{Outgoing, SpinUp}}
inPhotons::SVector{N5, PhotonStateful{Incoming, PolX}}
outPhotons::SVector{N6, PhotonStateful{Outgoing, PolX}}
end
"""
@ -328,6 +359,10 @@ Return the model of this process description.
model(::QEDProcessDescription) = QEDModel()
model(::QEDProcessInput) = QEDModel()
function copy(process::QEDProcessDescription)
return QEDProcessDescription(copy(process.inParticles), copy(process.outParticles))
end
==(p1::QEDProcessDescription, p2::QEDProcessDescription) =
p1.inParticles == p2.inParticles && p1.outParticles == p2.outParticles
@ -335,14 +370,23 @@ function in_particles(process::QEDProcessDescription)
return process.inParticles
end
function in_particles(input::QEDProcessInput)
return input.inParticles
end
function out_particles(process::QEDProcessDescription)
return process.outParticles
end
function out_particles(input::QEDProcessInput)
return input.outParticles
function get_particle(input::QEDProcessInput, t::Type{Particle}, n::Int)::Particle where {Particle}
if (t <: FermionStateful{Incoming})
return input.inFerms[n]
elseif (t <: FermionStateful{Outgoing})
return input.outFerms[n]
elseif (t <: AntiFermionStateful{Incoming})
return input.inAntiferms[n]
elseif (t <: AntiFermionStateful{Outgoing})
return input.outAntiferms[n]
elseif (t <: PhotonStateful{Incoming})
return input.inPhotons[n]
elseif (t <: PhotonStateful{Outgoing})
return input.outPhotons[n]
end
@assert false "Invalid type given"
end

View File

@ -32,20 +32,63 @@ function show(io::IO, process::QEDProcessDescription)
return nothing
end
"""
String(process::QEDProcessDescription)
Create a short string suitable as a filename or similar, describing the given process.
```jldoctest
julia> using MetagraphOptimization
julia> String(parse_process("ke->ke", QEDModel()))
qed_ke-ke
julia> print(parse_process("kk->ep", QEDModel()))
qed_kk-ep
```
"""
function String(process::QEDProcessDescription)
# types() gives the types in order (QED) instead of random like keys() would
str = "qed_"
for type in types(QEDModel())
for _ in 1:get(process.inParticles, type, 0)
str = str * String(type)
end
end
str = str * "-"
for type in types(QEDModel())
for _ in 1:get(process.outParticles, type, 0)
str = str * String(type)
end
end
return str
end
"""
show(io::IO, processInput::QEDProcessInput)
Pretty print an [`QEDProcessInput`](@ref) (with newlines).
Pretty print a [`QEDProcessInput`](@ref) (with newlines).
"""
function show(io::IO, processInput::QEDProcessInput)
println(io, "Input for $(processInput.process):")
println(io, " $(length(processInput.inParticles)) Incoming particles:")
for particle in processInput.inParticles
println(io, " $particle")
if !isempty(processInput.inFerms)
println(io, " $(processInput.inFerms)")
end
println(io, " $(length(processInput.outParticles)) Outgoing Particles:")
for particle in processInput.outParticles
println(io, " $particle")
if !isempty(processInput.outFerms)
println(io, " $(processInput.outFerms)")
end
if !isempty(processInput.inAntiferms)
println(io, " $(processInput.inAntiferms)")
end
if !isempty(processInput.outAntiferms)
println(io, " $(processInput.outAntiferms)")
end
if !isempty(processInput.inPhotons)
println(io, " $(processInput.inPhotons)")
end
if !isempty(processInput.outPhotons)
println(io, " $(processInput.outPhotons)")
end
return nothing
end
@ -53,7 +96,7 @@ end
"""
show(io::IO, particle::T) where {T <: QEDParticle}
Pretty print an [`QEDParticle`](@ref) (no newlines).
Pretty print a [`QEDParticle`](@ref) (no newlines).
"""
function show(io::IO, particle::T) where {T <: QEDParticle}
print(io, "$(String(typeof(particle))): $(particle.momentum)")

View File

@ -3,7 +3,7 @@ using UUIDs
using Base.Threads
# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/)
rng = [Random.MersenneTwister(0) for _ in 1:32]
rng = [Random.MersenneTwister(0) for _ in 1:64]
"""
Node

View File

@ -71,3 +71,7 @@ function optimize_to_fixpoint!(optimizer::GreedyOptimizer, graph::DAG)
end
return nothing
end
function String(optimizer::GreedyOptimizer)
return "greedy_optimizer_$(optimizer.estimator)"
end

View File

@ -47,3 +47,7 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG)
end
end
end
function String(::RandomWalkOptimizer)
return "random_walker"
end

View File

@ -28,3 +28,7 @@ function optimize_to_fixpoint!(optimizer::ReductionOptimizer, graph::DAG)
end
return nothing
end
function String(::ReductionOptimizer)
return "reduction_optimizer"
end