Compare commits
1 Commits
remove_fus
...
901944bd8b
Author | SHA1 | Date | |
---|---|---|---|
901944bd8b |
4
.gitattributes
vendored
4
.gitattributes
vendored
@@ -1,5 +1,3 @@
|
||||
input/AB->ABBBBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
|
||||
input/AB->ABBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs
|
||||
*.gif filter=lfs diff=lfs merge=lfs
|
||||
*.jld2 filter=lfs diff=lfs merge=lfs
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
|
@@ -19,17 +19,12 @@ QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
|
||||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
|
||||
|
||||
[extras]
|
||||
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
|
||||
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
|
||||
QEDcore = "35dc0263-cb5f-4c33-a114-1d7f54ab753e"
|
||||
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
|
||||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[targets]
|
||||
test = ["SafeTestsets", "Test", "QEDbase", "QEDcore", "QEDprocesses"]
|
||||
test = ["Test"]
|
||||
|
@@ -5,7 +5,7 @@
|
||||
## Package Features
|
||||
- Read a DAG from a file
|
||||
- Analyze its properties
|
||||
- Mute the graph using the operations NodeReduction and NodeSplit
|
||||
- Mute the graph using the operations NodeFusion, NodeReduction and NodeSplit
|
||||
|
||||
## Coming Soon:
|
||||
- Add Code Generation from finished DAG
|
||||
|
@@ -1,19 +1,9 @@
|
||||
ENV["UCX_ERROR_SIGNALS"] = "SIGILL,SIGBUS,SIGFPE"
|
||||
|
||||
using MetagraphOptimization
|
||||
using QEDbase
|
||||
using QEDcore
|
||||
using QEDprocesses
|
||||
using Random
|
||||
using UUIDs
|
||||
|
||||
using CUDA
|
||||
|
||||
using NamedDims
|
||||
using CSV
|
||||
using JLD2
|
||||
using FlexiMaps
|
||||
|
||||
RNG = Random.MersenneTwister(123)
|
||||
|
||||
function mock_machine()
|
||||
@@ -31,164 +21,103 @@ function mock_machine()
|
||||
)
|
||||
end
|
||||
|
||||
function congruent_input_momenta(processDescription::GenericQEDProcess, omega::Number)
|
||||
function congruent_input(processDescription::QEDProcessDescription, omega::Number)
|
||||
# generate an input sample for given e + nk -> e' + k' process, where the nk are equal
|
||||
massSum = 0
|
||||
inputMasses = Vector{Float64}()
|
||||
for particle in incoming_particles(processDescription)
|
||||
push!(inputMasses, mass(particle))
|
||||
for (particle, n) in processDescription.inParticles
|
||||
for _ in 1:n
|
||||
massSum += mass(particle)
|
||||
push!(inputMasses, mass(particle))
|
||||
end
|
||||
end
|
||||
outputMasses = Vector{Float64}()
|
||||
for particle in outgoing_particles(processDescription)
|
||||
push!(outputMasses, mass(particle))
|
||||
for (particle, n) in processDescription.outParticles
|
||||
for _ in 1:n
|
||||
massSum += mass(particle)
|
||||
push!(outputMasses, mass(particle))
|
||||
end
|
||||
end
|
||||
|
||||
initial_momenta = [
|
||||
i == length(inputMasses) ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
|
||||
initialMomenta = [
|
||||
i == 1 ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
|
||||
i in 1:length(inputMasses)
|
||||
]
|
||||
|
||||
ss = sqrt(sum(initial_momenta) * sum(initial_momenta))
|
||||
final_momenta = MetagraphOptimization.generate_physical_massive_moms(RNG, ss, outputMasses)
|
||||
# add some extra random mass to allow for some momentum
|
||||
ss = sqrt(sum(initialMomenta) * sum(initialMomenta))
|
||||
|
||||
return (tuple(initial_momenta...), tuple(final_momenta...))
|
||||
end
|
||||
result = Vector{QEDProcessInput}()
|
||||
sizehint!(result, 16)
|
||||
|
||||
# theta ∈ [0, 2π] and phi ∈ [0, 2π]
|
||||
function congruent_input_momenta_scenario_2(
|
||||
processDescription::GenericQEDProcess,
|
||||
omega::Number,
|
||||
theta::Number,
|
||||
phi::Number,
|
||||
)
|
||||
# -------------
|
||||
# same as above
|
||||
|
||||
# generate an input sample for given e + nk -> e' + k' process, where the nk are equal
|
||||
inputMasses = Vector{Float64}()
|
||||
for particle in incoming_particles(processDescription)
|
||||
push!(inputMasses, mass(particle))
|
||||
end
|
||||
outputMasses = Vector{Float64}()
|
||||
for particle in outgoing_particles(processDescription)
|
||||
push!(outputMasses, mass(particle))
|
||||
end
|
||||
|
||||
initial_momenta = [
|
||||
i == length(inputMasses) ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
|
||||
i in 1:length(inputMasses)
|
||||
]
|
||||
|
||||
ss = sqrt(sum(initial_momenta) * sum(initial_momenta))
|
||||
|
||||
# up to here
|
||||
# ----------
|
||||
|
||||
# now calculate the final_momenta from omega, cos_theta and phi
|
||||
n = number_particles(processDescription, Incoming(), Photon())
|
||||
|
||||
cos_theta = cos(theta)
|
||||
omega_prime = (n * omega) / (1 + n * omega * (1 - cos_theta))
|
||||
|
||||
k_prime =
|
||||
omega_prime * SFourMomentum(1, sqrt(1 - cos_theta^2) * cos(phi), sqrt(1 - cos_theta^2) * sin(phi), cos_theta)
|
||||
p_prime = sum(initial_momenta) - k_prime
|
||||
|
||||
final_momenta = (k_prime, p_prime)
|
||||
|
||||
return (tuple(initial_momenta...), tuple(final_momenta...))
|
||||
|
||||
end
|
||||
|
||||
function build_psp(processDescription::GenericQEDProcess, momenta)
|
||||
return PhaseSpacePoint(
|
||||
processDescription,
|
||||
PerturbativeQED(),
|
||||
PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()),
|
||||
momenta[1],
|
||||
momenta[2],
|
||||
spin_pol_combinations = Iterators.product(
|
||||
[SpinUp, SpinDown], [SpinUp, SpinDown], [PolX, PolY], [PolX, PolY]
|
||||
)
|
||||
end
|
||||
for (in_spin, out_spin, in_pol, out_pol) in spin_pol_combinations
|
||||
|
||||
# hack to fix stacksize for threading
|
||||
with_stacksize(f, n) = fetch(schedule(Task(f, n)))
|
||||
# get the electron first, then the n photons
|
||||
particles = Vector{QEDParticle}()
|
||||
|
||||
# scenario 2
|
||||
N = 1024 # thetas
|
||||
M = 1024 # phis
|
||||
K = 64 # omegas
|
||||
|
||||
thetas = collect(LinRange(0, 2π, N))
|
||||
phis = collect(LinRange(0, 2π, M))
|
||||
omegas = collect(maprange(log, 2e-2, 2e-7, K))
|
||||
|
||||
for photons in 1:5
|
||||
# temp process to generate momenta
|
||||
println("Generating $(K*N*M) inputs for $photons photons (Scenario 2 grid walk)...")
|
||||
temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
|
||||
|
||||
input_momenta =
|
||||
Array{typeof(congruent_input_momenta_scenario_2(temp_process, omegas[1], thetas[1], phis[1]))}(undef, (K, N, M))
|
||||
|
||||
Threads.@threads for k in 1:K
|
||||
Threads.@threads for i in 1:N
|
||||
Threads.@threads for j in 1:M
|
||||
input_momenta[k, i, j] = congruent_input_momenta_scenario_2(temp_process, omegas[k], thetas[i], phis[j])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
cu_results = CuArray{Float64}(undef, size(input_momenta))
|
||||
fill!(cu_results, 0.0)
|
||||
|
||||
i = 1
|
||||
for (in_pol, in_spin, out_pol, out_spin) in
|
||||
Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
|
||||
|
||||
print(
|
||||
"[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
|
||||
)
|
||||
process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
|
||||
|
||||
inputs = Array{typeof(build_psp(process, input_momenta[1, 1, 1]))}(undef, (K, N, M))
|
||||
#println("input_momenta: $input_momenta")
|
||||
Threads.@threads for k in 1:K
|
||||
Threads.@threads for i in 1:N
|
||||
Threads.@threads for j in 1:M
|
||||
inputs[k, i, j] = build_psp(process, input_momenta[k, i, j])
|
||||
for (particle, n) in processDescription.inParticles
|
||||
if particle <: FermionStateful
|
||||
mom = initialMomenta[1]
|
||||
push!(particles, particle(mom, in_spin()))
|
||||
elseif particle <: PhotonStateful
|
||||
for i in 1:n
|
||||
mom = initialMomenta[i + 1]
|
||||
push!(particles, particle(mom, in_pol()))
|
||||
end
|
||||
else
|
||||
@assert false
|
||||
end
|
||||
end
|
||||
cu_inputs = CuArray(inputs)
|
||||
|
||||
print("Preparing graph... ")
|
||||
graph = gen_graph(process)
|
||||
optimize_to_fixpoint!(ReductionOptimizer(), graph)
|
||||
print("Preparing function... ")
|
||||
kernel! = get_cuda_kernel(graph, process, mock_machine())
|
||||
#func = get_compute_function(graph, process, mock_machine())
|
||||
final_momenta = MetagraphOptimization.generate_physical_massive_moms(
|
||||
RNG, ss, outputMasses
|
||||
)
|
||||
index = 1
|
||||
for (particle, n) in processDescription.outParticles
|
||||
for _ in 1:n
|
||||
if particle <: FermionStateful
|
||||
push!(particles, particle(final_momenta[index], out_spin()))
|
||||
elseif particle <: PhotonStateful
|
||||
push!(particles, particle(final_momenta[index], out_pol()))
|
||||
end
|
||||
index += 1
|
||||
end
|
||||
end
|
||||
|
||||
print("Calculating... ")
|
||||
ts = 32
|
||||
bs = Int64(length(cu_inputs) / 32)
|
||||
inFerms = MetagraphOptimization._svector_from_type(
|
||||
processDescription, FermionStateful{Incoming,in_spin}, particles
|
||||
)
|
||||
outFerms = MetagraphOptimization._svector_from_type(
|
||||
processDescription, FermionStateful{Outgoing,out_spin}, particles
|
||||
)
|
||||
inAntiferms = MetagraphOptimization._svector_from_type(
|
||||
processDescription, AntiFermionStateful{Incoming,in_spin}, particles
|
||||
)
|
||||
outAntiferms = MetagraphOptimization._svector_from_type(
|
||||
processDescription, AntiFermionStateful{Outgoing,out_spin}, particles
|
||||
)
|
||||
inPhotons = MetagraphOptimization._svector_from_type(
|
||||
processDescription, PhotonStateful{Incoming,in_pol}, particles
|
||||
)
|
||||
outPhotons = MetagraphOptimization._svector_from_type(
|
||||
processDescription, PhotonStateful{Outgoing,out_pol}, particles
|
||||
)
|
||||
|
||||
outputs = CuArray{ComplexF64}(undef, size(cu_inputs))
|
||||
processInput = QEDProcessInput(
|
||||
processDescription,
|
||||
inFerms,
|
||||
outFerms,
|
||||
inAntiferms,
|
||||
outAntiferms,
|
||||
inPhotons,
|
||||
outPhotons,
|
||||
)
|
||||
|
||||
@cuda threads = ts blocks = bs always_inline = true kernel!(cu_inputs, outputs, length(cu_inputs))
|
||||
CUDA.device_synchronize()
|
||||
cu_results += abs2.(outputs)
|
||||
|
||||
println("Done.")
|
||||
i += 1
|
||||
push!(result, processInput)
|
||||
end
|
||||
|
||||
println("Writing results")
|
||||
|
||||
out_ph_moms = getindex.(getindex.(input_momenta, 2), 1)
|
||||
out_el_moms = getindex.(getindex.(input_momenta, 2), 2)
|
||||
|
||||
results = NamedDimsArray{(:omegas, :thetas, :phis)}(Array(cu_results))
|
||||
println("Named results array: $(typeof(results))")
|
||||
|
||||
@save "$(photons)_congruent_photons_grid.jld2" omegas thetas phis results
|
||||
return result
|
||||
end
|
||||
|
60
examples/plot_chain.jl
Normal file
60
examples/plot_chain.jl
Normal file
@@ -0,0 +1,60 @@
|
||||
using MetagraphOptimization
|
||||
using Plots
|
||||
using Random
|
||||
|
||||
function gen_plot(filepath)
|
||||
name = basename(filepath)
|
||||
name, _ = splitext(name)
|
||||
|
||||
filepath = joinpath(@__DIR__, "../input/", filepath)
|
||||
if !isfile(filepath)
|
||||
println("File ", filepath, " does not exist, skipping")
|
||||
return
|
||||
end
|
||||
|
||||
g = parse_dag(filepath, ABCModel())
|
||||
|
||||
Random.seed!(1)
|
||||
|
||||
println("Random Walking... ")
|
||||
|
||||
x = Vector{Float64}()
|
||||
y = Vector{Float64}()
|
||||
|
||||
for i in 1:30
|
||||
print("\r", i)
|
||||
# push
|
||||
opt = get_operations(g)
|
||||
|
||||
# choose one of fuse/split/reduce
|
||||
option = rand(1:3)
|
||||
if option == 1 && !isempty(opt.nodeFusions)
|
||||
push_operation!(g, rand(collect(opt.nodeFusions)))
|
||||
println("NF")
|
||||
elseif option == 2 && !isempty(opt.nodeReductions)
|
||||
push_operation!(g, rand(collect(opt.nodeReductions)))
|
||||
println("NR")
|
||||
elseif option == 3 && !isempty(opt.nodeSplits)
|
||||
push_operation!(g, rand(collect(opt.nodeSplits)))
|
||||
println("NS")
|
||||
else
|
||||
i = i - 1
|
||||
end
|
||||
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.computeEffort)
|
||||
end
|
||||
|
||||
println("\rDone.")
|
||||
|
||||
plot([x[1], x[2]], [y[1], y[2]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
|
||||
# Create lines connecting the reference point to each data point
|
||||
for i in 3:length(x)
|
||||
plot!([x[i - 1], x[i]], [y[i - 1], y[i]], linestyle = :solid, linewidth = 1, color = :red)
|
||||
end
|
||||
|
||||
return gui()
|
||||
end
|
||||
|
||||
gen_plot("AB->ABBB.txt")
|
96
examples/plot_star.jl
Normal file
96
examples/plot_star.jl
Normal file
@@ -0,0 +1,96 @@
|
||||
using MetagraphOptimization
|
||||
using Plots
|
||||
using Random
|
||||
|
||||
function gen_plot(filepath)
|
||||
name = basename(filepath)
|
||||
name, _ = splitext(name)
|
||||
|
||||
filepath = joinpath(@__DIR__, "../input/", filepath)
|
||||
if !isfile(filepath)
|
||||
println("File ", filepath, " does not exist, skipping")
|
||||
return
|
||||
end
|
||||
|
||||
g = parse_dag(filepath, ABCModel())
|
||||
|
||||
Random.seed!(1)
|
||||
|
||||
println("Random Walking... ")
|
||||
|
||||
for i in 1:30
|
||||
print("\r", i)
|
||||
# push
|
||||
opt = get_operations(g)
|
||||
|
||||
# choose one of fuse/split/reduce
|
||||
option = rand(1:3)
|
||||
if option == 1 && !isempty(opt.nodeFusions)
|
||||
push_operation!(g, rand(collect(opt.nodeFusions)))
|
||||
println("NF")
|
||||
elseif option == 2 && !isempty(opt.nodeReductions)
|
||||
push_operation!(g, rand(collect(opt.nodeReductions)))
|
||||
println("NR")
|
||||
elseif option == 3 && !isempty(opt.nodeSplits)
|
||||
push_operation!(g, rand(collect(opt.nodeSplits)))
|
||||
println("NS")
|
||||
else
|
||||
i = i - 1
|
||||
end
|
||||
end
|
||||
|
||||
println("\rDone.")
|
||||
|
||||
|
||||
|
||||
|
||||
props = get_properties(g)
|
||||
x0 = props.data
|
||||
y0 = props.computeEffort
|
||||
|
||||
x = Vector{Float64}()
|
||||
y = Vector{Float64}()
|
||||
names = Vector{String}()
|
||||
|
||||
opt = get_operations(g)
|
||||
for op in opt.nodeFusions
|
||||
push_operation!(g, op)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.computeEffort)
|
||||
pop_operation!(g)
|
||||
|
||||
push!(names, "NF: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
|
||||
end
|
||||
for op in opt.nodeReductions
|
||||
push_operation!(g, op)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.computeEffort)
|
||||
pop_operation!(g)
|
||||
|
||||
push!(names, "NR: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
|
||||
end
|
||||
for op in opt.nodeSplits
|
||||
push_operation!(g, op)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.computeEffort)
|
||||
pop_operation!(g)
|
||||
|
||||
push!(names, "NS: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
|
||||
end
|
||||
|
||||
plot([x0, x[1]], [y0, y[1]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
|
||||
# Create lines connecting the reference point to each data point
|
||||
for i in 2:length(x)
|
||||
plot!([x0, x[i]], [y0, y[i]], linestyle = :solid, linewidth = 1, color = :red)
|
||||
end
|
||||
#scatter!(x, y, label=names)
|
||||
|
||||
print(names)
|
||||
|
||||
return gui()
|
||||
end
|
||||
|
||||
gen_plot("AB->ABBB.txt")
|
BIN
images/contour_plot_congruent_in_photons.gif
(Stored with Git LFS)
BIN
images/contour_plot_congruent_in_photons.gif
(Stored with Git LFS)
Binary file not shown.
@@ -1,5 +1,5 @@
|
||||
# Optimizer Plots
|
||||
|
||||
Plots of FusionOptimizer (deprecated), ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.
|
||||
Plots of FusionOptimizer, ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.
|
||||
|
||||
Benchmarked using `notebooks/optimizers.ipynb`.
|
||||
|
@@ -413,7 +413,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Julia 1.10.4",
|
||||
"display_name": "Julia 1.10.2",
|
||||
"language": "julia",
|
||||
"name": "julia-1.10"
|
||||
},
|
||||
@@ -421,7 +421,7 @@
|
||||
"file_extension": ".jl",
|
||||
"mimetype": "application/julia",
|
||||
"name": "julia",
|
||||
"version": "1.10.4"
|
||||
"version": "1.10.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
File diff suppressed because one or more lines are too long
@@ -54,6 +54,8 @@
|
||||
"cu_inputs = CuArray(inputs)\n",
|
||||
"optimizer = RandomWalkOptimizer(MersenneTwister(0))# GreedyOptimizer(GlobalMetricEstimator())\n",
|
||||
"\n",
|
||||
"#done: split, reduce, fuse, greedy\n",
|
||||
"\n",
|
||||
"process_str_short = \"qed_k3\"\n",
|
||||
"optim_str = \"Random Walk Optimization\"\n",
|
||||
"optim_str_short=\"random\"\n",
|
||||
|
BIN
results/1_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/1_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
BIN
results/2_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/2_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
BIN
results/3_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/3_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
BIN
results/4_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/4_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
BIN
results/5_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/5_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
@@ -19,6 +19,7 @@ export AbstractTask
|
||||
export AbstractComputeTask
|
||||
export AbstractDataTask
|
||||
export DataTask
|
||||
export FusedComputeTask
|
||||
export PossibleOperations
|
||||
export GraphProperties
|
||||
|
||||
@@ -43,6 +44,7 @@ export is_valid, is_scheduled
|
||||
# graph operation related
|
||||
export Operation
|
||||
export AppliedOperation
|
||||
export NodeFusion
|
||||
export NodeReduction
|
||||
export NodeSplit
|
||||
export push_operation!
|
||||
@@ -64,7 +66,8 @@ export ComputeTaskABC_Sum
|
||||
|
||||
# QED model
|
||||
export FeynmanDiagram, FeynmanVertex, FeynmanTie, FeynmanParticle
|
||||
export GenericQEDProcess, QEDModel
|
||||
export PhotonStateful, FermionStateful, AntiFermionStateful
|
||||
export QEDParticle, QEDProcessDescription, QEDProcessInput, QEDModel
|
||||
export ComputeTaskQED_P
|
||||
export ComputeTaskQED_S1
|
||||
export ComputeTaskQED_S2
|
||||
@@ -86,7 +89,7 @@ export GlobalMetricEstimator, CDCost
|
||||
|
||||
# optimization
|
||||
export AbstractOptimizer, GreedyOptimizer, RandomWalkOptimizer
|
||||
export ReductionOptimizer, SplitOptimizer
|
||||
export ReductionOptimizer, SplitOptimizer, FusionOptimizer
|
||||
export optimize_step!, optimize!
|
||||
export fixpoint_reached, optimize_to_fixpoint!
|
||||
|
||||
@@ -98,6 +101,9 @@ export ==, in, show, isempty, delete!, length
|
||||
|
||||
export bytes_to_human_readable
|
||||
|
||||
# TODO: this is probably not good
|
||||
import QEDprocesses.compute
|
||||
|
||||
import Base.length
|
||||
import Base.show
|
||||
import Base.==
|
||||
@@ -162,30 +168,28 @@ include("optimization/interface.jl")
|
||||
include("optimization/greedy.jl")
|
||||
include("optimization/random_walk.jl")
|
||||
include("optimization/reduce.jl")
|
||||
include("optimization/fuse.jl")
|
||||
include("optimization/split.jl")
|
||||
|
||||
include("models/interface.jl")
|
||||
include("models/print.jl")
|
||||
|
||||
include("models/physics_models/interface.jl")
|
||||
include("models/abc/types.jl")
|
||||
include("models/abc/particle.jl")
|
||||
include("models/abc/compute.jl")
|
||||
include("models/abc/create.jl")
|
||||
include("models/abc/properties.jl")
|
||||
include("models/abc/parse.jl")
|
||||
include("models/abc/print.jl")
|
||||
|
||||
include("models/physics_models/abc/types.jl")
|
||||
include("models/physics_models/abc/particle.jl")
|
||||
include("models/physics_models/abc/compute.jl")
|
||||
include("models/physics_models/abc/create.jl")
|
||||
include("models/physics_models/abc/properties.jl")
|
||||
include("models/physics_models/abc/parse.jl")
|
||||
include("models/physics_models/abc/print.jl")
|
||||
|
||||
include("models/physics_models/qed/utility.jl")
|
||||
include("models/physics_models/qed/types.jl")
|
||||
include("models/physics_models/qed/particle.jl")
|
||||
include("models/physics_models/qed/diagrams.jl")
|
||||
include("models/physics_models/qed/compute.jl")
|
||||
include("models/physics_models/qed/create.jl")
|
||||
include("models/physics_models/qed/properties.jl")
|
||||
include("models/physics_models/qed/parse.jl")
|
||||
include("models/physics_models/qed/print.jl")
|
||||
include("models/qed/types.jl")
|
||||
include("models/qed/particle.jl")
|
||||
include("models/qed/diagrams.jl")
|
||||
include("models/qed/compute.jl")
|
||||
include("models/qed/create.jl")
|
||||
include("models/qed/properties.jl")
|
||||
include("models/qed/parse.jl")
|
||||
include("models/qed/print.jl")
|
||||
|
||||
include("devices/measure.jl")
|
||||
include("devices/detect.jl")
|
||||
@@ -194,7 +198,7 @@ include("devices/impl.jl")
|
||||
include("devices/numa/impl.jl")
|
||||
include("devices/cuda/impl.jl")
|
||||
include("devices/rocm/impl.jl")
|
||||
#include("devices/oneapi/impl.jl")
|
||||
include("devices/oneapi/impl.jl")
|
||||
|
||||
include("scheduler/interface.jl")
|
||||
include("scheduler/greedy.jl")
|
||||
|
@@ -1,70 +1,72 @@
|
||||
"""
|
||||
get_compute_function(graph::DAG, instance, machine::Machine)
|
||||
get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
|
||||
Return a function of signature `compute_<id>(input::input_type(instance))`, which will return the result of the DAG computation on the given input.
|
||||
Return a function of signature `compute_<id>(input::AbstractProcessInput)`, which will return the result of the DAG computation on the given input.
|
||||
"""
|
||||
function get_compute_function(graph::DAG, instance, machine::Machine)
|
||||
tape = gen_tape(graph, instance, machine)
|
||||
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
tape = gen_tape(graph, process, machine)
|
||||
|
||||
initCaches = Expr(:block, tape.initCachesCode...)
|
||||
assignInputs = Expr(:block, tape.inputAssignCode...)
|
||||
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
|
||||
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
|
||||
|
||||
functionId = to_var_name(UUIDs.uuid1(rng[1]))
|
||||
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
|
||||
expr = Meta.parse(
|
||||
"function compute_$(functionId)(data_input::$(input_type(instance))) $(initCaches); $(assignInputs); $code; return $resSym; end",
|
||||
"function compute_$(functionId)(data_input::AbstractProcessInput) $(initCaches); $(assignInputs); $code; return $resSym; end",
|
||||
)
|
||||
|
||||
return expr
|
||||
func = eval(expr)
|
||||
|
||||
return func
|
||||
end
|
||||
|
||||
"""
|
||||
get_cuda_kernel(graph::DAG, instance, machine::Machine)
|
||||
get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
|
||||
Return a function of signature `compute_<id>(input::CuVector, output::CuVector, n::Int64)`, which will return the result of the DAG computation of the input on the given output variable.
|
||||
"""
|
||||
function get_cuda_kernel(graph::DAG, instance, machine::Machine)
|
||||
tape = gen_tape(graph, instance, machine)
|
||||
function get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
tape = gen_tape(graph, process, machine)
|
||||
|
||||
initCaches = Expr(:block, tape.initCachesCode...)
|
||||
assignInputs = Expr(:block, tape.inputAssignCode...)
|
||||
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
|
||||
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
|
||||
|
||||
functionId = to_var_name(UUIDs.uuid1(rng[1]))
|
||||
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
|
||||
expr = Meta.parse(
|
||||
"function compute_$(functionId)(input_vector, output_vector, n::Int64)
|
||||
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
|
||||
if (id > n)
|
||||
return
|
||||
end
|
||||
@inline data_input = input_vector[id]
|
||||
$(initCaches)
|
||||
$(assignInputs)
|
||||
$code
|
||||
@inline output_vector[id] = $resSym
|
||||
return nothing
|
||||
end"
|
||||
)
|
||||
expr = Meta.parse("function compute_$(functionId)(input_vector, output_vector, n::Int64)
|
||||
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
|
||||
if (id > n)
|
||||
return
|
||||
end
|
||||
@inline data_input = input_vector[id]
|
||||
$(initCaches)
|
||||
$(assignInputs)
|
||||
$code
|
||||
@inline output_vector[id] = $resSym
|
||||
return nothing
|
||||
end")
|
||||
|
||||
return expr
|
||||
func = eval(expr)
|
||||
|
||||
return func
|
||||
end
|
||||
|
||||
"""
|
||||
execute(graph::DAG, instance, machine::Machine, input)
|
||||
execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
|
||||
|
||||
Execute the code of the given `graph` on the given input values.
|
||||
Execute the code of the given `graph` on the given input particles.
|
||||
|
||||
This is essentially shorthand for
|
||||
```julia
|
||||
tape = gen_tape(graph, instance, machine)
|
||||
tape = gen_tape(graph, process, machine)
|
||||
return execute_tape(tape, input)
|
||||
```
|
||||
|
||||
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
|
||||
"""
|
||||
function execute(graph::DAG, instance, machine::Machine, input)
|
||||
tape = gen_tape(graph, instance, machine)
|
||||
function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
|
||||
tape = gen_tape(graph, process, machine)
|
||||
return execute_tape(tape, input)
|
||||
end
|
||||
|
@@ -1,11 +1,10 @@
|
||||
# TODO: do this with macros
|
||||
function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
|
||||
cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]])
|
||||
return nothing
|
||||
end
|
||||
|
||||
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
|
||||
cache[fc.return_symbol] = fc.func(fc.value_arguments[1], cache[fc.arguments[1]])
|
||||
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]])
|
||||
return nothing
|
||||
end
|
||||
|
||||
@@ -15,12 +14,12 @@ function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {
|
||||
end
|
||||
|
||||
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{2}}
|
||||
cache[fc.return_symbol] = fc.func(fc.value_arguments[1], cache[fc.arguments[1]], cache[fc.arguments[2]])
|
||||
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]], cache[fc.arguments[2]])
|
||||
return nothing
|
||||
end
|
||||
|
||||
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT}
|
||||
cache[fc.return_symbol] = fc.func(fc.value_arguments[1], getindex.(Ref(cache), fc.arguments)...)
|
||||
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], getindex.(Ref(cache), fc.arguments)...)
|
||||
return nothing
|
||||
end
|
||||
|
||||
@@ -32,7 +31,7 @@ Execute the given [`FunctionCall`](@ref) on the dictionary.
|
||||
Several more specialized versions of this function exist to reduce vector unrolling work for common cases.
|
||||
"""
|
||||
function call_fc(fc::FunctionCall{VectorT, M}, cache::Dict{Symbol, Any}) where {VectorT, M}
|
||||
cache[fc.return_symbol] = fc.func(fc.value_arguments..., getindex.(Ref(cache), fc.arguments)...)
|
||||
cache[fc.return_symbol] = fc.func(fc.additional_arguments..., getindex.(Ref(cache), fc.arguments)...)
|
||||
return nothing
|
||||
end
|
||||
|
||||
@@ -48,8 +47,12 @@ end
|
||||
For a given function call, return an expression evaluating it.
|
||||
"""
|
||||
function expr_from_fc(fc::FunctionCall{VectorT, M}) where {VectorT, M}
|
||||
func_call =
|
||||
Expr(:call, Symbol(fc.func), fc.value_arguments..., eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...)
|
||||
func_call = Expr(
|
||||
:call,
|
||||
Symbol(fc.func),
|
||||
fc.additional_arguments...,
|
||||
eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...,
|
||||
)
|
||||
|
||||
expr = :($(eval(gen_access_expr(fc.device, fc.return_symbol))) = $func_call)
|
||||
return expr
|
||||
@@ -70,32 +73,51 @@ function gen_cache_init_code(machine::Machine)
|
||||
return initializeCaches
|
||||
end
|
||||
|
||||
"""
|
||||
part_from_x(type::Type, index::Int, x::AbstractProcessInput)
|
||||
|
||||
Return the [`ParticleValue`](@ref) of the given type of particle with the given `index` from the given process input.
|
||||
|
||||
Function is wrapped into a [`FunctionCall`](@ref) in [`gen_input_assignment_code`](@ref).
|
||||
"""
|
||||
part_from_x(type::Type, index::Int, x::AbstractProcessInput) =
|
||||
ParticleValue{type, ComplexF64}(get_particle(x, type, index), one(ComplexF64))
|
||||
|
||||
"""
|
||||
gen_input_assignment_code(
|
||||
inputSymbols::Dict{String, Vector{Symbol}},
|
||||
instance::AbstractProblemInstance,
|
||||
processDescription::AbstractProcessDescription,
|
||||
machine::Machine,
|
||||
problemInputSymbol::Symbol = :data_input,
|
||||
processInputSymbol::Symbol = :input,
|
||||
)
|
||||
|
||||
Return a `Vector{Expr}` doing the input assignments from the given `problemInputSymbol` onto the `inputSymbols`.
|
||||
Return a `Vector{Expr}` doing the input assignments from the given `processInputSymbol` onto the `inputSymbols`.
|
||||
"""
|
||||
function gen_input_assignment_code(
|
||||
inputSymbols::Dict{String, Vector{Symbol}},
|
||||
instance,
|
||||
processDescription::AbstractProcessDescription,
|
||||
machine::Machine,
|
||||
problemInputSymbol::Symbol = :data_input,
|
||||
processInputSymbol::Symbol = :input,
|
||||
)
|
||||
assignInputs = Vector{Expr}()
|
||||
@assert length(inputSymbols) >=
|
||||
sum(values(in_particles(processDescription))) + sum(values(out_particles(processDescription))) "Number of input Symbols is smaller than the number of particles in the process description"
|
||||
|
||||
assignInputs = Vector{FunctionCall}()
|
||||
for (name, symbols) in inputSymbols
|
||||
(type, index) = type_index_from_name(model(processDescription), name)
|
||||
# make a function for this, since we can't use anonymous functions in the FunctionCall
|
||||
|
||||
for symbol in symbols
|
||||
device = entry_device(machine)
|
||||
push!(
|
||||
assignInputs,
|
||||
Meta.parse(
|
||||
"$(eval(gen_access_expr(device, symbol))) = $(input_expr(instance, name, problemInputSymbol))",
|
||||
FunctionCall(
|
||||
# x is the process input
|
||||
part_from_x,
|
||||
SVector{1, Symbol}(processInputSymbol),
|
||||
SVector{2, Any}(type, index),
|
||||
symbol,
|
||||
device,
|
||||
),
|
||||
)
|
||||
end
|
||||
@@ -105,14 +127,14 @@ function gen_input_assignment_code(
|
||||
end
|
||||
|
||||
"""
|
||||
gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine, scheduler::AbstractScheduler = GreedyScheduler())
|
||||
gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
|
||||
Generate the code for a given graph. The return value is a [`Tape`](@ref).
|
||||
|
||||
See also: [`execute`](@ref), [`execute_tape`](@ref)
|
||||
"""
|
||||
function gen_tape(graph::DAG, instance, machine::Machine, scheduler::AbstractScheduler = GreedyScheduler())
|
||||
schedule = schedule_dag(scheduler, graph, machine)
|
||||
function gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
schedule = schedule_dag(GreedyScheduler(), graph, machine)
|
||||
|
||||
# get inSymbols
|
||||
inputSyms = Dict{String, Vector{Symbol}}()
|
||||
@@ -128,24 +150,23 @@ function gen_tape(graph::DAG, instance, machine::Machine, scheduler::AbstractSch
|
||||
outSym = Symbol(to_var_name(get_exit_node(graph).id))
|
||||
|
||||
initCaches = gen_cache_init_code(machine)
|
||||
assignInputs = gen_input_assignment_code(inputSyms, instance, machine, :data_input)
|
||||
assignInputs = gen_input_assignment_code(inputSyms, process, machine, :input)
|
||||
|
||||
return Tape{input_type(instance)}(initCaches, assignInputs, schedule, inputSyms, outSym, Dict(), instance, machine)
|
||||
return Tape(initCaches, assignInputs, schedule, inputSyms, outSym, Dict(), process, machine)
|
||||
end
|
||||
|
||||
"""
|
||||
execute_tape(tape::Tape, input::Input) where {Input}
|
||||
execute_tape(tape::Tape, input::AbstractProcessInput)
|
||||
|
||||
Execute the given tape with the given input.
|
||||
|
||||
For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.
|
||||
"""
|
||||
function execute_tape(tape::Tape, input)
|
||||
function execute_tape(tape::Tape, input::AbstractProcessInput)
|
||||
cache = Dict{Symbol, Any}()
|
||||
cache[:data_input] = input
|
||||
cache[:input] = input
|
||||
# simply execute all the code snippets here
|
||||
@assert typeof(input) == input_type(tape.instance)
|
||||
# TODO: `@assert` that input fits the tape.instance
|
||||
# TODO: `@assert` that process input fits the tape.process
|
||||
for expr in tape.initCachesCode
|
||||
@eval $expr
|
||||
end
|
||||
|
@@ -1,21 +1,19 @@
|
||||
|
||||
"""
|
||||
Tape{INPUT}
|
||||
Tape
|
||||
|
||||
TODO: update docs
|
||||
- `INPUT` the input type of the problem instance
|
||||
|
||||
- `code::Vector{Expr}`: The julia expression containing the code for the whole graph.
|
||||
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
|
||||
- `outputSymbol::Symbol`: The symbol of the final calculated value
|
||||
"""
|
||||
struct Tape{INPUT}
|
||||
struct Tape
|
||||
initCachesCode::Vector{Expr}
|
||||
inputAssignCode::Vector{Expr}
|
||||
inputAssignCode::Vector{FunctionCall}
|
||||
computeCode::Vector{FunctionCall}
|
||||
inputSymbols::Dict{String, Vector{Symbol}}
|
||||
outputSymbol::Symbol
|
||||
cache::Dict{Symbol, Any}
|
||||
instance::Any
|
||||
process::AbstractProcessDescription
|
||||
machine::Machine
|
||||
end
|
||||
|
@@ -10,7 +10,8 @@ Representation of a [`DAG`](@ref)'s cost as estimated by the [`GlobalMetricEstim
|
||||
|
||||
|
||||
!!! note
|
||||
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
|
||||
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
|
||||
For example, for node fusions this will always be 0, since the computeEffort is zero.
|
||||
It will still work as intended when adding/subtracting to/from a `graph_cost` estimate.
|
||||
"""
|
||||
const CDCost = NamedTuple{(:data, :computeEffort, :computeIntensity), Tuple{Float64, Float64, Float64}}
|
||||
@@ -54,6 +55,10 @@ function graph_cost(estimator::GlobalMetricEstimator, graph::DAG)
|
||||
)::CDCost
|
||||
end
|
||||
|
||||
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeFusion)
|
||||
return (data = -data(operation.input[2].task), computeEffort = 0.0, computeIntensity = 0.0)::CDCost
|
||||
end
|
||||
|
||||
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeReduction)
|
||||
s = length(operation.input) - 1
|
||||
return (
|
||||
|
@@ -169,6 +169,66 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::FusedComputeTask, before, after)
|
||||
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
|
||||
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
|
||||
|
||||
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
|
||||
replace!(task.t1_inputs, before => after)
|
||||
replace!(task.t2_inputs, before => after)
|
||||
|
||||
# recursively descend down the tree, but only in the tasks where we're replacing things
|
||||
if replacedIn1 > 0
|
||||
replace_children!(task.first_task, before, after)
|
||||
end
|
||||
if replacedIn2 > 0
|
||||
replace_children!(task.second_task, before, after)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::AbstractTask, before, after)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
|
||||
# only need to update fused compute tasks
|
||||
if !(typeof(task(n)) <: FusedComputeTask)
|
||||
return nothing
|
||||
end
|
||||
|
||||
taskBefore = copy(task(n))
|
||||
|
||||
#=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
|
||||
println("------------------ Nothing to replace!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in children(n)
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end=#
|
||||
|
||||
replace_children!(task(n), child_before, child_after)
|
||||
|
||||
#=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
|
||||
println("------------------ Did not replace anything!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in children(n)
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end=#
|
||||
|
||||
# keep track
|
||||
if (track)
|
||||
push!(graph.diff.updatedChildren, (n, taskBefore))
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
get_snapshot_diff(graph::DAG)
|
||||
|
||||
@@ -180,6 +240,31 @@ function get_snapshot_diff(graph::DAG)
|
||||
return swapfield!(graph, :diff, Diff())
|
||||
end
|
||||
|
||||
"""
|
||||
invalidate_caches!(graph::DAG, operation::NodeFusion)
|
||||
|
||||
Invalidate the operation caches for a given [`NodeFusion`](@ref).
|
||||
|
||||
This deletes the operation from the graph's possible operations and from the involved nodes' own operation caches.
|
||||
"""
|
||||
function invalidate_caches!(graph::DAG, operation::NodeFusion)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
for n in [1, 3]
|
||||
for i in eachindex(operation.input[n].nodeFusions)
|
||||
if operation == operation.input[n].nodeFusions[i]
|
||||
splice!(operation.input[n].nodeFusions, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
operation.input[2].nodeFusion = missing
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
invalidate_caches!(graph::DAG, operation::NodeReduction)
|
||||
|
||||
@@ -226,6 +311,9 @@ function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
|
||||
if !ismissing(node.nodeSplit)
|
||||
invalidate_caches!(graph, node.nodeSplit)
|
||||
end
|
||||
while !isempty(node.nodeFusions)
|
||||
invalidate_caches!(graph, pop!(node.nodeFusions))
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
@@ -241,5 +329,8 @@ function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
|
||||
if !ismissing(node.nodeSplit)
|
||||
invalidate_caches!(graph, node.nodeSplit)
|
||||
end
|
||||
if !ismissing(node.nodeFusion)
|
||||
invalidate_caches!(graph, node.nodeFusion)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
@@ -7,6 +7,7 @@ A struct storing all possible operations on a [`DAG`](@ref).
|
||||
To get the [`PossibleOperations`](@ref) on a [`DAG`](@ref), use [`get_operations`](@ref).
|
||||
"""
|
||||
mutable struct PossibleOperations
|
||||
nodeFusions::Set{NodeFusion}
|
||||
nodeReductions::Set{NodeReduction}
|
||||
nodeSplits::Set{NodeSplit}
|
||||
end
|
||||
@@ -51,7 +52,7 @@ end
|
||||
Construct and return an empty [`PossibleOperations`](@ref) object.
|
||||
"""
|
||||
function PossibleOperations()
|
||||
return PossibleOperations(Set{NodeReduction}(), Set{NodeSplit}())
|
||||
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
|
||||
end
|
||||
|
||||
"""
|
||||
|
@@ -40,11 +40,19 @@ function is_valid(graph::DAG)
|
||||
for ns in graph.possibleOperations.nodeSplits
|
||||
@assert is_valid(graph, ns)
|
||||
end
|
||||
for nf in graph.possibleOperations.nodeFusions
|
||||
@assert is_valid(graph, nf)
|
||||
end
|
||||
|
||||
for node in graph.dirtyNodes
|
||||
@assert node in graph "Dirty Node is not part of the graph!"
|
||||
@assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!"
|
||||
@assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!"
|
||||
if (typeof(node) <: DataTaskNode)
|
||||
@assert ismissing(node.nodeFusion) "Dirty DataTaskNode has a Node Fusion!"
|
||||
elseif (typeof(node) <: ComputeTaskNode)
|
||||
@assert isempty(node.nodeFusions) "Dirty ComputeTaskNode has Node Fusions!"
|
||||
end
|
||||
end
|
||||
|
||||
@assert is_connected(graph) "Graph is not connected!"
|
||||
|
@@ -1,17 +1,6 @@
|
||||
using AccurateArithmetic
|
||||
using StaticArrays
|
||||
|
||||
function input_expr(instance::ABCProcessDescription, name::String, psp_symbol::Symbol)
|
||||
(type, index) = type_index_from_name(ABCModel(), name)
|
||||
|
||||
return Meta.parse(
|
||||
"ABCParticleValue(
|
||||
$type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))),
|
||||
0.0im,
|
||||
)",
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskABC_P, data::ABCParticleValue)
|
||||
|
||||
@@ -19,7 +8,7 @@ Return the particle and value as is.
|
||||
|
||||
0 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskABC_P, data::ABCParticleValue{P})::ABCParticleValue{P} where {P}
|
||||
function compute(::ComputeTaskABC_P, data::ABCParticleValue{P})::ABCParticleValue{P} where {P <: ABCParticle}
|
||||
return data
|
||||
end
|
||||
|
||||
@@ -30,7 +19,7 @@ Compute an outer edge. Return the particle value with the same particle and the
|
||||
|
||||
1 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskABC_U, data::ABCParticleValue{P})::ABCParticleValue{P} where {P}
|
||||
function compute(::ComputeTaskABC_U, data::ABCParticleValue{P})::ABCParticleValue{P} where {P <: ABCParticle}
|
||||
return ABCParticleValue{P}(data.p, data.v * ABC_outer_edge(data.p))
|
||||
end
|
||||
|
||||
@@ -45,7 +34,7 @@ function compute(
|
||||
::ComputeTaskABC_V,
|
||||
data1::ABCParticleValue{P1},
|
||||
data2::ABCParticleValue{P2},
|
||||
)::ABCParticleValue where {P1, P2}
|
||||
)::ABCParticleValue where {P1 <: ABCParticle, P2 <: ABCParticle}
|
||||
p3 = ABC_conserve_momentum(data1.p, data2.p)
|
||||
dataOut = ABCParticleValue{typeof(p3)}(p3, data1.v * ABC_vertex() * data2.v)
|
||||
return dataOut
|
||||
@@ -60,7 +49,11 @@ For valid inputs, both input particles should have the same momenta at this poin
|
||||
|
||||
12 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskABC_S2, data1::ParticleValue{P}, data2::ParticleValue{P})::Float64 where {P}
|
||||
function compute(
|
||||
::ComputeTaskABC_S2,
|
||||
data1::ParticleValue{P},
|
||||
data2::ParticleValue{P},
|
||||
)::Float64 where {P <: ABCParticle}
|
||||
#=
|
||||
@assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
|
||||
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
|
||||
@@ -92,6 +85,12 @@ Linearly many FLOP with growing data.
|
||||
"""
|
||||
function compute(::ComputeTaskABC_Sum, data...)::Float64
|
||||
return sum_kbn([data...])
|
||||
|
||||
#=s = 0.0im
|
||||
for d in data
|
||||
s += d
|
||||
end
|
||||
return s=#
|
||||
end
|
||||
|
||||
function compute(::ComputeTaskABC_Sum, data::AbstractArray)::Float64
|
@@ -14,28 +14,34 @@ struct ABCModel <: AbstractPhysicsModel end
|
||||
|
||||
Base type for all particles in the [`ABCModel`](@ref).
|
||||
"""
|
||||
abstract type ABCParticle <: AbstractParticleType end
|
||||
abstract type ABCParticle <: AbstractParticle end
|
||||
|
||||
"""
|
||||
ParticleA <: ABCParticle
|
||||
|
||||
An 'A' particle in the ABC Model.
|
||||
"""
|
||||
struct ParticleA <: ABCParticle end
|
||||
struct ParticleA <: ABCParticle
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
"""
|
||||
ParticleB <: ABCParticle
|
||||
|
||||
A 'B' particle in the ABC Model.
|
||||
"""
|
||||
struct ParticleB <: ABCParticle end
|
||||
struct ParticleB <: ABCParticle
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
"""
|
||||
ParticleC <: ABCParticle
|
||||
|
||||
A 'C' particle in the ABC Model.
|
||||
"""
|
||||
struct ParticleC <: ABCParticle end
|
||||
struct ParticleC <: ABCParticle
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
"""
|
||||
ABCProcessDescription <: AbstractProcessDescription
|
||||
@@ -66,7 +72,7 @@ struct ABCProcessInput{N1, N2, N3, N4, N5, N6} <: AbstractProcessInput
|
||||
outC::SVector{N6, ParticleC}
|
||||
end
|
||||
|
||||
ABCParticleValue{ParticleType} = ParticleValue{ParticleType, ComplexF64}
|
||||
ABCParticleValue{ParticleType <: ABCParticle} = ParticleValue{ParticleType, ComplexF64}
|
||||
|
||||
"""
|
||||
mass(t::Type{T}) where {T <: ABCParticle}
|
||||
@@ -103,46 +109,39 @@ end
|
||||
Return a Vector of the possible types of particle in the [`ABCModel`](@ref).
|
||||
"""
|
||||
function types(::ABCModel)
|
||||
return [
|
||||
ParticleStateful{Incoming, ParticleA, SFourMomentum},
|
||||
ParticleStateful{Incoming, ParticleB, SFourMomentum},
|
||||
ParticleStateful{Incoming, ParticleC, SFourMomentum},
|
||||
ParticleStateful{Outgoing, ParticleA, SFourMomentum},
|
||||
ParticleStateful{Outgoing, ParticleB, SFourMomentum},
|
||||
ParticleStateful{Outgoing, ParticleC, SFourMomentum},
|
||||
]
|
||||
return [ParticleA, ParticleB, ParticleC]
|
||||
end
|
||||
|
||||
"""
|
||||
square(p::AbstractParticleStateful{Dir, ABCParticle})
|
||||
square(p::ABCParticle)
|
||||
|
||||
Return the square of the particle's momentum as a `Float` value.
|
||||
|
||||
Takes 7 effective FLOP.
|
||||
"""
|
||||
function square(p::AbstractParticleStateful{D, ABCParticle}) where {D}
|
||||
return getMass2(momentum(p))
|
||||
function square(p::ABCParticle)
|
||||
return getMass2(p.momentum)
|
||||
end
|
||||
|
||||
"""
|
||||
ABC_inner_edge(p::AbstractParticleStateful{Dir, ABCParticle})
|
||||
ABC_inner_edge(p::ABCParticle)
|
||||
|
||||
Return the factor of the inner edge with the given (virtual) particle.
|
||||
|
||||
Takes 10 effective FLOP. (3 here + 7 in square(p))
|
||||
"""
|
||||
function ABC_inner_edge(p::AbstractParticleStateful{D, ABCParticle}) where {D}
|
||||
return 1.0 / (square(p) - mass(particle(p))^2)
|
||||
function ABC_inner_edge(p::ABCParticle)
|
||||
return 1.0 / (square(p) - mass(p)^2)
|
||||
end
|
||||
|
||||
"""
|
||||
ABC_outer_edge(p::AbstractParticleStateful{Dir, ABCParticle})
|
||||
ABC_outer_edge(p::ABCParticle)
|
||||
|
||||
Return the factor of the outer edge with the given (real) particle.
|
||||
|
||||
Takes 0 effective FLOP.
|
||||
"""
|
||||
function ABC_outer_edge(::AbstractParticleStateful{D, ABCParticle}) where {D}
|
||||
function ABC_outer_edge(p::ABCParticle)
|
||||
return 1.0
|
||||
end
|
||||
|
||||
@@ -180,26 +179,17 @@ model(::ABCProcessDescription) = ABCModel()
|
||||
model(::ABCProcessInput) = ABCModel()
|
||||
|
||||
function type_index_from_name(::ABCModel, name::String)
|
||||
if startswith(name, "Ai")
|
||||
return (ParticleStateful{Incoming, ParticleA, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Ao")
|
||||
return (ParticleStateful{Outgoing, ParticleA, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Bi")
|
||||
return (ParticleStateful{Incoming, ParticleB, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Bo")
|
||||
return (ParticleStateful{Outgoing, ParticleB, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Ci")
|
||||
return (ParticleStateful{Incoming, ParticleC, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Co")
|
||||
return (ParticleStateful{Outgoing, ParticleC, SFourMomentum}, parse(Int, name[3:end]))
|
||||
if startswith(name, "A")
|
||||
return (ParticleA, parse(Int, name[2:end]))
|
||||
elseif startswith(name, "B")
|
||||
return (ParticleB, parse(Int, name[2:end]))
|
||||
elseif startswith(name, "C")
|
||||
return (ParticleC, parse(Int, name[2:end]))
|
||||
else
|
||||
throw("Invalid name for a particle in the ABC model")
|
||||
end
|
||||
end
|
||||
|
||||
function String(::Type{PS}) where {DIR, P <: ABCParticle, PS <: AbstractParticleStateful{DIR, P}}
|
||||
return String(P)
|
||||
end
|
||||
function String(::Type{ParticleA})
|
||||
return "A"
|
||||
end
|
@@ -1,46 +1,120 @@
|
||||
import QEDbase.mass
|
||||
import QEDbase.AbstractParticle
|
||||
|
||||
"""
|
||||
AbstractModel
|
||||
AbstractPhysicsModel
|
||||
|
||||
Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed.
|
||||
|
||||
See also: [`problem_instance`](@ref)
|
||||
Base type for a model, e.g. ABC-Model or QED. This is used to dispatch many functions.
|
||||
"""
|
||||
abstract type AbstractModel end
|
||||
abstract type AbstractPhysicsModel end
|
||||
|
||||
"""
|
||||
problem_instance(::AbstractModel, ::Vararg)
|
||||
ParticleValue{ParticleType <: AbstractParticle}
|
||||
|
||||
Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters.
|
||||
A struct describing a particle during a calculation of a Feynman Diagram, together with the value that's being calculated. `AbstractParticle` is the type from the QEDbase package.
|
||||
|
||||
`sizeof(ParticleValue())` = 48 Byte
|
||||
"""
|
||||
function problem_instance end
|
||||
struct ParticleValue{ParticleType <: AbstractParticle, ValueType}
|
||||
p::ParticleType
|
||||
v::ValueType
|
||||
end
|
||||
|
||||
"""
|
||||
AbstractProblemInstance
|
||||
AbstractProcessDescription
|
||||
|
||||
Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model.
|
||||
Base type for process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a process in that model.
|
||||
|
||||
See also: [`parse_process`](@ref)
|
||||
"""
|
||||
abstract type AbstractProblemInstance end
|
||||
abstract type AbstractProcessDescription end
|
||||
|
||||
"""
|
||||
input_type(problem::AbstractProblemInstance)
|
||||
AbstractProcessInput
|
||||
|
||||
Return the fully specified input type for a specific [`AbstractProblemInstance`](@ref).
|
||||
Base type for process inputs. An object of this type contains the input values (e.g. momenta) of the particles in a process.
|
||||
|
||||
See also: [`gen_process_input`](@ref)
|
||||
"""
|
||||
function input_type end
|
||||
abstract type AbstractProcessInput end
|
||||
|
||||
"""
|
||||
graph(::AbstractProblemInstance)
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: AbstractParticle, T2 <: AbstractParticle}
|
||||
|
||||
Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names.
|
||||
Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the result particle type when the two given particles interact.
|
||||
"""
|
||||
function graph end
|
||||
function interaction_result end
|
||||
|
||||
"""
|
||||
input_expr(instance::AbstractProblemInstance, name::String, input_symbol::Symbol)
|
||||
types(::AbstractPhysicsModel)
|
||||
|
||||
For the given [`AbstractProblemInstance`](@ref), the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol.
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref), returning a `Vector` of the available particle types in the model.
|
||||
"""
|
||||
function input_expr end
|
||||
function types end
|
||||
|
||||
"""
|
||||
in_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of incoming particles for the process per particle type.
|
||||
|
||||
|
||||
in_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all incoming particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function in_particles end
|
||||
|
||||
"""
|
||||
out_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of outgoing particles for the process per particle type.
|
||||
|
||||
|
||||
out_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all outgoing particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function out_particles end
|
||||
|
||||
"""
|
||||
get_particle(::AbstractProcessInput, t::Type, n::Int)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns the `n`th particle of type `t`.
|
||||
"""
|
||||
function get_particle end
|
||||
|
||||
"""
|
||||
parse_process(::AbstractString, ::AbstractPhysicsModel)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref).
|
||||
Returns a `ProcessDescription` object.
|
||||
"""
|
||||
function parse_process end
|
||||
|
||||
"""
|
||||
gen_process_input(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every specific [`AbstractProcessDescription`](@ref).
|
||||
Returns a randomly generated and valid corresponding `ProcessInput`.
|
||||
"""
|
||||
function gen_process_input end
|
||||
|
||||
"""
|
||||
model(::AbstractProcessDescription)
|
||||
model(::AbstarctProcessInput)
|
||||
|
||||
Return the model of this process description or input.
|
||||
"""
|
||||
function model end
|
||||
|
||||
"""
|
||||
type_from_name(model::AbstractModel, name::String)
|
||||
|
||||
For a name of a particle in the given [`AbstractModel`](@ref), return the particle's [`Type`] and index as a tuple. The input string can be expetced to be of the form \"<name><index>\".
|
||||
"""
|
||||
function type_index_from_name end
|
||||
|
@@ -1,3 +0,0 @@
|
||||
## Deprecation Warning
|
||||
|
||||
These models are deprecated and should not be used anymore. They will be dropped entirely soon.
|
@@ -1,142 +0,0 @@
|
||||
import QEDbase.AbstractParticle
|
||||
|
||||
"""
|
||||
AbstractPhysicsModel
|
||||
|
||||
Base type for a model, e.g. ABC-Model or QED. This is used to dispatch many functions.
|
||||
"""
|
||||
abstract type AbstractPhysicsModel <: AbstractModel end
|
||||
|
||||
"""
|
||||
ParticleValue{ParticleType <: AbstractParticleStateful}
|
||||
|
||||
A struct describing a particle during a calculation of a Feynman Diagram, together with the value that's being calculated. `AbstractParticleStateful` is the type from the QEDbase package.
|
||||
|
||||
`sizeof(ParticleValue())` = 48 Byte
|
||||
"""
|
||||
struct ParticleValue{ParticleType <: AbstractParticleStateful, ValueType}
|
||||
p::ParticleType
|
||||
v::ValueType
|
||||
end
|
||||
|
||||
"""
|
||||
TBW
|
||||
|
||||
particle value + spin/pol info, only used on the external legs (u tasks)
|
||||
"""
|
||||
struct ParticleValueSP{ParticleType <: AbstractParticleStateful, SP <: AbstractSpinOrPolarization, ValueType}
|
||||
p::ParticleType
|
||||
v::ValueType
|
||||
sp::SP
|
||||
end
|
||||
|
||||
"""
|
||||
AbstractProcessDescription <: AbstractProblemInstance
|
||||
|
||||
Base type for particle scattering process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a scattering process in that model.
|
||||
|
||||
See also: [`parse_process`](@ref), [`AbstractProblemInstance`](@ref)
|
||||
"""
|
||||
abstract type AbstractProcessDescription end
|
||||
|
||||
|
||||
#TODO: i don't think giving this a base type is a good idea, the input type should just be returned of some function, allowing anything as an input type
|
||||
"""
|
||||
AbstractProcessInput
|
||||
|
||||
Base type for process inputs. An object of this type contains the input values (e.g. momenta) of the particles in a process.
|
||||
|
||||
See also: [`gen_process_input`](@ref)
|
||||
"""
|
||||
abstract type AbstractProcessInput end
|
||||
|
||||
"""
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: AbstractParticle, T2 <: AbstractParticle}
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the result particle type when the two given particles interact.
|
||||
"""
|
||||
function interaction_result end
|
||||
|
||||
"""
|
||||
types(::AbstractPhysicsModel)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref), returning a `Vector` of the available particle types in the model.
|
||||
"""
|
||||
function types end
|
||||
|
||||
"""
|
||||
in_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of incoming particles for the process per particle type.
|
||||
|
||||
|
||||
in_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all incoming particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function in_particles end
|
||||
|
||||
"""
|
||||
out_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of outgoing particles for the process per particle type.
|
||||
|
||||
|
||||
out_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all outgoing particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function out_particles end
|
||||
|
||||
"""
|
||||
get_particle(::AbstractProcessInput, t::Type, n::Int)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns the `n`th particle of type `t`.
|
||||
"""
|
||||
function get_particle end
|
||||
|
||||
"""
|
||||
parse_process(::AbstractString, ::AbstractPhysicsModel)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref).
|
||||
Returns a `ProcessDescription` object.
|
||||
"""
|
||||
function parse_process end
|
||||
|
||||
"""
|
||||
gen_process_input(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every specific [`AbstractProcessDescription`](@ref).
|
||||
Returns a randomly generated and valid corresponding `ProcessInput`.
|
||||
"""
|
||||
function gen_process_input end
|
||||
|
||||
"""
|
||||
model(::AbstractProcessDescription)
|
||||
model(::AbstractProcessInput)
|
||||
|
||||
Return the model of this process description or input.
|
||||
"""
|
||||
function model end
|
||||
|
||||
"""
|
||||
type_from_name(model::AbstractModel, name::String)
|
||||
|
||||
For a name of a particle in the given [`AbstractModel`](@ref), return the particle's [`Type`] and index as a tuple. The input string can be expetced to be of the form \"<name><index>\".
|
||||
"""
|
||||
function type_index_from_name end
|
||||
|
||||
"""
|
||||
part_from_x(type::Type, index::Int, x::AbstractProcessInput)
|
||||
|
||||
Return the [`ParticleValue`](@ref) of the given type of particle with the given `index` from the given process input.
|
||||
|
||||
Function is wrapped into a [`FunctionCall`](@ref) in [`gen_input_assignment_code`](@ref).
|
||||
"""
|
||||
part_from_x(type::Type, index::Int, x::AbstractProcessInput) =
|
||||
ParticleValue{type, ComplexF64}(get_particle(x, type, index), one(ComplexF64))
|
@@ -1,160 +0,0 @@
|
||||
|
||||
ComputeTaskQED_Sum() = ComputeTaskQED_Sum(0)
|
||||
|
||||
function _svector_from_type(processDescription::GenericQEDProcess, type, particles)
|
||||
if haskey(in_particles(processDescription), type)
|
||||
return SVector{in_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
|
||||
end
|
||||
if haskey(out_particles(processDescription), type)
|
||||
return SVector{out_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
gen_process_input(processDescription::GenericQEDProcess)
|
||||
|
||||
Return a ProcessInput of randomly generated [`QEDParticle`](@ref)s from a [`GenericQEDProcess`](@ref). The process description can be created manually or parsed from a string using [`parse_process`](@ref).
|
||||
|
||||
Note: This uses RAMBO to create a valid process with conservation of momentum and energy.
|
||||
"""
|
||||
function gen_process_input(processDescription::GenericQEDProcess)
|
||||
massSum = 0
|
||||
inputMasses = Vector{Float64}()
|
||||
for particle in incoming_particles(processDescription)
|
||||
massSum += mass(particle)
|
||||
push!(inputMasses, mass(particle))
|
||||
end
|
||||
outputMasses = Vector{Float64}()
|
||||
for particle in outgoing_particles(processDescription)
|
||||
massSum += mass(particle)
|
||||
push!(outputMasses, mass(particle))
|
||||
end
|
||||
|
||||
# add some extra random mass to allow for some momentum
|
||||
massSum += rand(rng[threadid()]) * (length(inputMasses) + length(outputMasses))
|
||||
|
||||
initial_momenta = generate_initial_moms(massSum, inputMasses)
|
||||
final_momenta = generate_physical_massive_moms(rng[threadid()], massSum, outputMasses)
|
||||
|
||||
processInput = PhaseSpacePoint(
|
||||
processDescription,
|
||||
PerturbativeQED(),
|
||||
PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()),
|
||||
tuple(initial_momenta...),
|
||||
tuple(final_momenta...),
|
||||
)
|
||||
|
||||
return processInput
|
||||
end
|
||||
|
||||
"""
|
||||
gen_graph(process_description::GenericQEDProcess)
|
||||
|
||||
For a given [`GenericQEDProcess`](@ref), return the [`DAG`](@ref) that computes it.
|
||||
"""
|
||||
function gen_graph(process_description::GenericQEDProcess)
|
||||
initial_diagram = FeynmanDiagram(process_description)
|
||||
diagrams = gen_diagrams(initial_diagram)
|
||||
|
||||
graph = DAG()
|
||||
|
||||
COMPLEX_SIZE = sizeof(ComplexF64)
|
||||
PARTICLE_VALUE_SIZE = 96.0
|
||||
|
||||
# TODO: Not all diagram outputs should always be summed at the end, if they differ by fermion exchange they need to be diffed
|
||||
# Should not matter for n-Photon Compton processes though
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(0)); track = false, invalidate_cache = false)
|
||||
global_data_out = insert_node!(graph, make_node(DataTask(COMPLEX_SIZE)); track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, sum_node, global_data_out; track = false, invalidate_cache = false)
|
||||
|
||||
# remember the data out nodes for connection
|
||||
dataOutNodes = Dict()
|
||||
|
||||
for particle in initial_diagram.particles
|
||||
# generate data in and U tasks
|
||||
data_in = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE), String(particle));
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
) # read particle data node
|
||||
compute_u = insert_node!(graph, make_node(ComputeTaskQED_U()); track = false, invalidate_cache = false) # compute U node
|
||||
data_out =
|
||||
insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)); track = false, invalidate_cache = false) # transfer data out from u (one ABCParticleValue object)
|
||||
|
||||
insert_edge!(graph, data_in, compute_u; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, compute_u, data_out; track = false, invalidate_cache = false)
|
||||
|
||||
# remember the data_out node for future edges
|
||||
dataOutNodes[String(particle)] = data_out
|
||||
end
|
||||
|
||||
# TODO: this should be parallelizable somewhat easily
|
||||
for diagram in diagrams
|
||||
tie = diagram.tie[]
|
||||
|
||||
# handle the vertices
|
||||
for vertices in diagram.vertices
|
||||
for vertex in vertices
|
||||
data_in1 = dataOutNodes[String(vertex.in1)]
|
||||
data_in2 = dataOutNodes[String(vertex.in2)]
|
||||
|
||||
compute_V = insert_node!(graph, make_node(ComputeTaskQED_V()); track = false, invalidate_cache = false) # compute vertex
|
||||
|
||||
insert_edge!(graph, data_in1, compute_V; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_in2, compute_V; track = false, invalidate_cache = false)
|
||||
|
||||
data_V_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE));
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, compute_V, data_V_out; track = false, invalidate_cache = false)
|
||||
|
||||
if (vertex.out == tie.in1 || vertex.out == tie.in2)
|
||||
# out particle is part of the tie -> there will be an S2 task with it later, don't make S1 task
|
||||
dataOutNodes[String(vertex.out)] = data_V_out
|
||||
continue
|
||||
end
|
||||
|
||||
# otherwise, add S1 task
|
||||
compute_S1 =
|
||||
insert_node!(graph, make_node(ComputeTaskQED_S1()); track = false, invalidate_cache = false) # compute propagator
|
||||
|
||||
insert_edge!(graph, data_V_out, compute_S1; track = false, invalidate_cache = false)
|
||||
|
||||
data_S1_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE));
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, compute_S1, data_S1_out; track = false, invalidate_cache = false)
|
||||
|
||||
# overrides potentially different nodes from previous diagrams, which is intentional
|
||||
dataOutNodes[String(vertex.out)] = data_S1_out
|
||||
end
|
||||
end
|
||||
|
||||
# handle the tie
|
||||
data_in1 = dataOutNodes[String(tie.in1)]
|
||||
data_in2 = dataOutNodes[String(tie.in2)]
|
||||
|
||||
compute_S2 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track = false, invalidate_cache = false)
|
||||
|
||||
data_S2 = insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)); track = false, invalidate_cache = false)
|
||||
|
||||
insert_edge!(graph, data_in1, compute_S2; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_in2, compute_S2; track = false, invalidate_cache = false)
|
||||
|
||||
insert_edge!(graph, compute_S2, data_S2; track = false, invalidate_cache = false)
|
||||
|
||||
insert_edge!(graph, data_S2, sum_node; track = false, invalidate_cache = false)
|
||||
add_child!(task(sum_node))
|
||||
end
|
||||
|
||||
return graph
|
||||
end
|
@@ -1,45 +0,0 @@
|
||||
"""
|
||||
parse_process(string::AbstractString, model::QEDModel)
|
||||
|
||||
Parse a string representation of a process, such as "ke->ke" into the corresponding [`QEDProcessDescription`](@ref).
|
||||
"""
|
||||
function parse_process(
|
||||
str::AbstractString,
|
||||
model::QEDModel,
|
||||
inphpol::AbstractDefinitePolarization = PolX(),
|
||||
inelspin::AbstractDefiniteSpin = SpinUp(),
|
||||
outphpol::AbstractDefinitePolarization = PolX(),
|
||||
outelspin::AbstractDefiniteSpin = SpinUp(),
|
||||
)
|
||||
if !(contains(str, "->"))
|
||||
throw("Did not find -> while parsing process \"$str\"")
|
||||
end
|
||||
|
||||
(in_str, out_str) = split(str, "->")
|
||||
|
||||
if (isempty(in_str) || isempty(out_str))
|
||||
throw("Process (\"$str\") input or output part is empty!")
|
||||
end
|
||||
|
||||
in_particles = Vector{AbstractParticleType}()
|
||||
out_particles = Vector{AbstractParticleType}()
|
||||
|
||||
for (particle_vector, s) in ((in_particles, in_str), (out_particles, out_str))
|
||||
for c in s
|
||||
if c == 'e'
|
||||
push!(particle_vector, Electron())
|
||||
elseif c == 'p'
|
||||
push!(particle_vector, Positron())
|
||||
elseif c == 'k'
|
||||
push!(particle_vector, Photon())
|
||||
else
|
||||
throw("Encountered unknown characters in the process \"$str\"")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
in_spin_pols = tuple([is_boson(in_particles[i]) ? inphpol : inelspin for i in eachindex(in_particles)]...)
|
||||
out_spin_pols = tuple([is_boson(out_particles[i]) ? outphpol : outelspin for i in eachindex(out_particles)]...)
|
||||
|
||||
return GenericQEDProcess(tuple(in_particles...), tuple(out_particles...), in_spin_pols, out_spin_pols)
|
||||
end
|
@@ -1,305 +0,0 @@
|
||||
using QEDprocesses
|
||||
using StaticArrays
|
||||
|
||||
const e = sqrt(4π / 137)
|
||||
|
||||
QEDbase.is_incoming(::Type{<:ParticleStateful{Incoming}}) = true
|
||||
QEDbase.is_outgoing(::Type{<:ParticleStateful{Outgoing}}) = true
|
||||
QEDbase.is_incoming(::Type{<:ParticleStateful{Outgoing}}) = false
|
||||
QEDbase.is_outgoing(::Type{<:ParticleStateful{Incoming}}) = false
|
||||
|
||||
QEDbase.particle_direction(::Type{<:ParticleStateful{DIR}}) where {DIR <: ParticleDirection} = DIR()
|
||||
QEDbase.particle_species(
|
||||
::Type{<:ParticleStateful{DIR, SPECIES}},
|
||||
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType} = SPECIES()
|
||||
|
||||
function spin_or_pol(
|
||||
process::GenericQEDProcess,
|
||||
type::Type{ParticleStateful{DIR, SPECIES, EL}},
|
||||
n::Int,
|
||||
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType, EL <: AbstractFourMomentum}
|
||||
i = 0
|
||||
c = n
|
||||
for p in particles(process, DIR())
|
||||
i += 1
|
||||
if p == SPECIES()
|
||||
c -= 1
|
||||
end
|
||||
if c == 0
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if c != 0 || n <= 0
|
||||
throw(InvalidInputError("could not get $n-th spin/pol of $(DIR()) $species, does not exist"))
|
||||
end
|
||||
|
||||
if DIR <: Incoming
|
||||
return process.incoming_spins_pols[i]
|
||||
elseif DIR <: Outgoing
|
||||
return process.outgoing_spins_pols[i]
|
||||
else
|
||||
throw(InvalidInputError("unknown direction $(DIR()) given"))
|
||||
end
|
||||
end
|
||||
|
||||
function input_type(p::GenericQEDProcess)
|
||||
in_t = QEDcore._assemble_tuple_type(incoming_particles(p), Incoming(), SFourMomentum)
|
||||
out_t = QEDcore._assemble_tuple_type(outgoing_particles(p), Outgoing(), SFourMomentum)
|
||||
return PhaseSpacePoint{
|
||||
typeof(p),
|
||||
PerturbativeQED,
|
||||
PhasespaceDefinition{SphericalCoordinateSystem, ElectronRestFrame},
|
||||
Tuple{in_t...},
|
||||
Tuple{out_t...},
|
||||
}
|
||||
end
|
||||
|
||||
ValueType = Union{BiSpinor, AdjointBiSpinor, DiracMatrix, SLorentzVector{Float64}, ComplexF64}
|
||||
|
||||
"""
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
|
||||
|
||||
For two given particle types that can interact, return the third.
|
||||
"""
|
||||
function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ParticleStateful, T2 <: ParticleStateful}
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
end
|
||||
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Incoming, Electron, EL}},
|
||||
::Type{ParticleStateful{Outgoing, Electron, EL}},
|
||||
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Incoming, Electron, EL}},
|
||||
::Type{ParticleStateful{Incoming, Positron, EL}},
|
||||
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Incoming, Electron, EL}},
|
||||
::Type{ParticleStateful{DIR, Photon, EL}},
|
||||
) where {EL <: AbstractFourMomentum, DIR <: ParticleDirection} = ParticleStateful{Outgoing, Electron, EL}
|
||||
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Outgoing, Electron, EL}},
|
||||
::Type{ParticleStateful{Incoming, Electron, EL}},
|
||||
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Outgoing, Electron, EL}},
|
||||
::Type{ParticleStateful{Outgoing, Positron, EL}},
|
||||
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Outgoing, Electron, EL}},
|
||||
::Type{ParticleStateful{DIR, Photon, EL}},
|
||||
) where {EL <: AbstractFourMomentum, DIR <: ParticleDirection} = ParticleStateful{Incoming, Electron, EL}
|
||||
|
||||
# antifermion mirror
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Incoming, Positron, EL}},
|
||||
t2::Type{<:ParticleStateful},
|
||||
) where {EL <: AbstractFourMomentum} = interaction_result(ParticleStateful{Outgoing, Electron, EL}, t2)
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Outgoing, Positron, EL}},
|
||||
t2::Type{<:ParticleStateful},
|
||||
) where {EL <: AbstractFourMomentum} = interaction_result(ParticleStateful{Incoming, Electron, EL}, t2)
|
||||
|
||||
# photon commutativity
|
||||
interaction_result(
|
||||
t1::Type{ParticleStateful{DIR, Photon, EL}},
|
||||
t2::Type{<:ParticleStateful},
|
||||
) where {EL <: AbstractFourMomentum, DIR <: ParticleDirection} = interaction_result(t2, t1)
|
||||
|
||||
# but prevent stack overflow
|
||||
function interaction_result(
|
||||
t1::Type{ParticleStateful{DIR1, Photon, EL}},
|
||||
t2::Type{ParticleStateful{DIR2, Photon, EL}},
|
||||
) where {DIR1 <: ParticleDirection, DIR2 <: ParticleDirection, EL <: AbstractFourMomentum}
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
end
|
||||
|
||||
"""
|
||||
propagation_result(t1::Type{T}) where {T <: QEDParticle}
|
||||
|
||||
Return the type of the inverted direction. E.g.
|
||||
"""
|
||||
propagation_result(::Type{ParticleStateful{Incoming, Electron, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Outgoing, Electron, EL}
|
||||
propagation_result(::Type{ParticleStateful{Outgoing, Electron, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Incoming, Electron, EL}
|
||||
propagation_result(::Type{ParticleStateful{Incoming, Positron, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Outgoing, Positron, EL}
|
||||
propagation_result(::Type{ParticleStateful{Outgoing, Positron, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Incoming, Positron, EL}
|
||||
propagation_result(::Type{ParticleStateful{Incoming, Photon, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Outgoing, Photon, EL}
|
||||
propagation_result(::Type{ParticleStateful{Outgoing, Photon, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Incoming, Photon, EL}
|
||||
|
||||
"""
|
||||
types(::QEDModel)
|
||||
|
||||
Return a Vector of the possible types of particle in the [`QEDModel`](@ref).
|
||||
"""
|
||||
function types(::QEDModel)
|
||||
return [
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum},
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum},
|
||||
ParticleStateful{Incoming, Positron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Positron, SFourMomentum},
|
||||
]
|
||||
end
|
||||
|
||||
# type piracy?
|
||||
String(::Type{Incoming}) = "Incoming"
|
||||
String(::Type{Outgoing}) = "Outgoing"
|
||||
|
||||
String(::Type{PolX}) = "polx"
|
||||
String(::Type{PolY}) = "poly"
|
||||
|
||||
String(::Type{SpinUp}) = "spinup"
|
||||
String(::Type{SpinDown}) = "spindown"
|
||||
|
||||
String(::Incoming) = "i"
|
||||
String(::Outgoing) = "o"
|
||||
|
||||
function String(::Type{<:ParticleStateful{DIR, Photon}}) where {DIR <: ParticleDirection}
|
||||
return "k"
|
||||
end
|
||||
function String(::Type{<:ParticleStateful{DIR, Electron}}) where {DIR <: ParticleDirection}
|
||||
return "e"
|
||||
end
|
||||
function String(::Type{<:ParticleStateful{DIR, Positron}}) where {DIR <: ParticleDirection}
|
||||
return "p"
|
||||
end
|
||||
|
||||
"""
|
||||
caninteract(T1::Type{<:ParticleStateful}, T2::Type{<:ParticleStateful})
|
||||
|
||||
For two given `ParticleStateful` types, return whether they can interact at a vertex. This is equivalent to `!issame(T1, T2)`.
|
||||
|
||||
See also: [`issame`](@ref) and [`interaction_result`](@ref)
|
||||
"""
|
||||
function caninteract(
|
||||
T1::Type{<:ParticleStateful{D1, S1}},
|
||||
T2::Type{<:ParticleStateful{D2, S2}},
|
||||
) where {D1 <: ParticleDirection, S1 <: AbstractParticleType, D2 <: ParticleDirection, S2 <: AbstractParticleType}
|
||||
if (T1 == T2)
|
||||
return false
|
||||
end
|
||||
if (S1 == Photon && S2 == Photon)
|
||||
return false
|
||||
end
|
||||
|
||||
for (P1, P2) in [(T1, T2), (T2, T1)]
|
||||
if (P1 <: ParticleStateful{Incoming, Electron} && P2 <: ParticleStateful{Outgoing, Positron})
|
||||
return false
|
||||
end
|
||||
if (P1 <: ParticleStateful{Outgoing, Electron} && P2 <: ParticleStateful{Incoming, Positron})
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function type_index_from_name(::QEDModel, name::String)
|
||||
if startswith(name, "ki")
|
||||
return (ParticleStateful{Incoming, Photon, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "ko")
|
||||
return (ParticleStateful{Outgoing, Photon, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "ei")
|
||||
return (ParticleStateful{Incoming, Electron, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "eo")
|
||||
return (ParticleStateful{Outgoing, Electron, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "pi")
|
||||
return (ParticleStateful{Incoming, Positron, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "po")
|
||||
return (ParticleStateful{Outgoing, Positron, SFourMomentum}, parse(Int, name[3:end]))
|
||||
else
|
||||
throw("Invalid name for a particle in the QED model")
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
issame(T1::Type{<:ParticleStateful}, T2::Type{<:ParticleStateful})
|
||||
|
||||
For two given `ParticleStateful` types, return whether they are equivalent for the purpose of a Feynman Diagram. That means e.g. an `Incoming` `AntiFermion` is the same as an `Outgoing` `Fermion`. This is equivalent to `!caninteract(T1, T2)`.
|
||||
|
||||
See also: [`caninteract`](@ref) and [`interaction_result`](@ref)
|
||||
"""
|
||||
function issame(T1::Type{<:ParticleStateful}, T2::Type{<:ParticleStateful})
|
||||
return !caninteract(T1, T2)
|
||||
end
|
||||
|
||||
"""
|
||||
QED_vertex()
|
||||
|
||||
Return the factor of a vertex in a QED feynman diagram.
|
||||
"""
|
||||
@inline function QED_vertex()::SLorentzVector{DiracMatrix}
|
||||
# Peskin-Schroeder notation
|
||||
return -1im * e * gamma()
|
||||
end
|
||||
|
||||
@inline function QED_inner_edge(p::ParticleStateful)
|
||||
return propagator(particle_species(p), momentum(p))
|
||||
end
|
||||
|
||||
"""
|
||||
QED_conserve_momentum(p1::ParticleStateful, p2::ParticleStateful)
|
||||
|
||||
Calculate and return a new particle from two given interacting ones at a vertex.
|
||||
"""
|
||||
function QED_conserve_momentum(
|
||||
p1::P1,
|
||||
p2::P2,
|
||||
) where {
|
||||
Dir1 <: ParticleDirection,
|
||||
Dir2 <: ParticleDirection,
|
||||
Species1 <: AbstractParticleType,
|
||||
Species2 <: AbstractParticleType,
|
||||
P1 <: ParticleStateful{Dir1, Species1},
|
||||
P2 <: ParticleStateful{Dir2, Species2},
|
||||
}
|
||||
P3 = interaction_result(P1, P2)
|
||||
p1_mom = momentum(p1)
|
||||
if (Dir1 <: Outgoing)
|
||||
p1_mom *= -1
|
||||
end
|
||||
p2_mom = momentum(p2)
|
||||
if (Dir2 <: Outgoing)
|
||||
p2_mom *= -1
|
||||
end
|
||||
|
||||
p3_mom = p1_mom + p2_mom
|
||||
if (particle_direction(P3) isa Incoming)
|
||||
return ParticleStateful(particle_direction(P3), particle_species(P3), -p3_mom)
|
||||
end
|
||||
return ParticleStateful(particle_direction(P3), particle_species(P3), p3_mom)
|
||||
end
|
||||
|
||||
"""
|
||||
model(::AbstractProcessDescription)
|
||||
|
||||
Return the model of this process description.
|
||||
"""
|
||||
model(::GenericQEDProcess) = QEDModel()
|
||||
model(::PhaseSpacePoint) = QEDModel()
|
||||
|
||||
function get_particle(
|
||||
input::PhaseSpacePoint,
|
||||
t::Type{ParticleStateful{DIR, SPECIES}},
|
||||
n::Int,
|
||||
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType}
|
||||
i = 0
|
||||
for p in particles(input, DIR())
|
||||
if p isa t
|
||||
i += 1
|
||||
if i == n
|
||||
return p
|
||||
end
|
||||
end
|
||||
end
|
||||
@assert false "Invalid type given"
|
||||
end
|
@@ -1,14 +0,0 @@
|
||||
using QEDprocesses
|
||||
|
||||
# add type overload for number_particles function
|
||||
@inline function QEDprocesses.number_particles(
|
||||
proc_def::QEDbase.AbstractProcessDefinition,
|
||||
::Type{PS},
|
||||
) where {
|
||||
DIR <: QEDbase.ParticleDirection,
|
||||
PT <: QEDbase.AbstractParticleType,
|
||||
EL <: AbstractFourMomentum,
|
||||
PS <: ParticleStateful{DIR, PT, EL},
|
||||
}
|
||||
return QEDprocesses.number_particles(proc_def, DIR(), PT())
|
||||
end
|
@@ -0,0 +1,10 @@
|
||||
|
||||
"""
|
||||
show(io::IO, particleValue::ParticleValue)
|
||||
|
||||
Pretty print a [`ParticleValue`](@ref), no newlines.
|
||||
"""
|
||||
function show(io::IO, particleValue::ParticleValue)
|
||||
print(io, "($(particleValue.p), value: $(particleValue.v))")
|
||||
return nothing
|
||||
end
|
||||
|
@@ -1,56 +1,41 @@
|
||||
using StaticArrays
|
||||
|
||||
construction_string(::Incoming) = "Incoming()"
|
||||
construction_string(::Outgoing) = "Outgoing()"
|
||||
"""
|
||||
compute(::ComputeTaskQED_P, data::QEDParticleValue)
|
||||
|
||||
construction_string(::Electron) = "Electron()"
|
||||
construction_string(::Positron) = "Positron()"
|
||||
construction_string(::Photon) = "Photon()"
|
||||
|
||||
construction_string(::PolX) = "PolX()"
|
||||
construction_string(::PolY) = "PolY()"
|
||||
construction_string(::SpinUp) = "SpinUp()"
|
||||
construction_string(::SpinDown) = "SpinDown()"
|
||||
|
||||
function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbol)
|
||||
(type, index) = type_index_from_name(QEDModel(), name)
|
||||
|
||||
return Meta.parse(
|
||||
"ParticleValueSP(
|
||||
$type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))),
|
||||
0.0im,
|
||||
$(construction_string(spin_or_pol(instance, type, index))),
|
||||
)",
|
||||
)
|
||||
Return the particle as is and initialize the Value.
|
||||
"""
|
||||
function compute(::ComputeTaskQED_P, data::QEDParticleValue{P}) where {P<:QEDParticle}
|
||||
# TODO do we actually need this for anything?
|
||||
return ParticleValue{P,DiracMatrix}(data.p, one(DiracMatrix))
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskQED_U, data::ParticleValueSP)
|
||||
compute(::ComputeTaskQED_U, data::QEDParticleValue)
|
||||
|
||||
Compute an outer edge. Return the particle value with the same particle and the value multiplied by an outer_edge factor.
|
||||
"""
|
||||
function compute(
|
||||
::ComputeTaskQED_U,
|
||||
data::ParticleValueSP{P, SP, V},
|
||||
) where {P <: ParticleStateful, V <: ValueType, SP <: AbstractSpinOrPolarization}
|
||||
::ComputeTaskQED_U, data::PV
|
||||
) where {P<:QEDParticle,PV<:QEDParticleValue{P}}
|
||||
part::P = data.p
|
||||
state = base_state(particle_species(part), particle_direction(part), momentum(part), SP())
|
||||
return ParticleValue{P, typeof(state)}(
|
||||
state = base_state(particle(part), direction(part), momentum(part), spin_or_pol(part))
|
||||
return ParticleValue{P,typeof(state)}(
|
||||
data.p,
|
||||
state, # will return a SLorentzVector{ComplexF64}, BiSpinor or AdjointBiSpinor
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskQED_V, data1::ParticleValue, data2::ParticleValue)
|
||||
compute(::ComputeTaskQED_V, data1::QEDParticleValue, data2::QEDParticleValue)
|
||||
|
||||
Compute a vertex. Preserve momentum and particle types (e + gamma->p etc.) to create resulting particle, multiply values together and times a vertex factor.
|
||||
"""
|
||||
function compute(
|
||||
::ComputeTaskQED_V,
|
||||
data1::ParticleValue{P1, V1},
|
||||
data2::ParticleValue{P2, V2},
|
||||
) where {P1 <: ParticleStateful, P2 <: ParticleStateful, V1 <: ValueType, V2 <: ValueType}
|
||||
::ComputeTaskQED_V, data1::PV1, data2::PV2
|
||||
) where {
|
||||
P1<:QEDParticle,P2<:QEDParticle,PV1<:QEDParticleValue{P1},PV2<:QEDParticleValue{P2}
|
||||
}
|
||||
p3 = QED_conserve_momentum(data1.p, data2.p)
|
||||
P3 = interaction_result(P1, P2)
|
||||
state = QED_vertex()
|
||||
@@ -65,12 +50,12 @@ function compute(
|
||||
state = state * data2.v
|
||||
end
|
||||
|
||||
dataOut = ParticleValue{P3, typeof(state)}(P3(momentum(p3)), state)
|
||||
dataOut = ParticleValue{P3,typeof(state)}(P3(momentum(p3)), state)
|
||||
return dataOut
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskQED_S2, data1::ParticleValue, data2::ParticleValue)
|
||||
compute(::ComputeTaskQED_S2, data1::QEDParticleValue, data2::QEDParticleValue)
|
||||
|
||||
Compute a final inner edge (2 input particles, no output particle).
|
||||
|
||||
@@ -79,19 +64,10 @@ For valid inputs, both input particles should have the same momenta at this poin
|
||||
12 FLOP.
|
||||
"""
|
||||
function compute(
|
||||
::ComputeTaskQED_S2,
|
||||
data1::ParticleValue{P1, V1},
|
||||
data2::ParticleValue{P2, V2},
|
||||
::ComputeTaskQED_S2, data1::ParticleValue{P1}, data2::ParticleValue{P2}
|
||||
) where {
|
||||
D1 <: ParticleDirection,
|
||||
D2 <: ParticleDirection,
|
||||
S1 <: Union{Electron, Positron},
|
||||
S2 <: Union{Electron, Positron},
|
||||
V1 <: ValueType,
|
||||
V2 <: ValueType,
|
||||
EL <: AbstractFourMomentum,
|
||||
P1 <: ParticleStateful{D1, S1, EL},
|
||||
P2 <: ParticleStateful{D2, S2, EL},
|
||||
P1<:Union{AntiFermionStateful,FermionStateful},
|
||||
P2<:Union{AntiFermionStateful,FermionStateful},
|
||||
}
|
||||
#@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
|
||||
|
||||
@@ -106,10 +82,8 @@ function compute(
|
||||
end
|
||||
|
||||
function compute(
|
||||
::ComputeTaskQED_S2,
|
||||
data1::ParticleValue{ParticleStateful{D1, Photon}, V1},
|
||||
data2::ParticleValue{ParticleStateful{D2, Photon}, V2},
|
||||
) where {D1 <: ParticleDirection, D2 <: ParticleDirection, V1 <: ValueType, V2 <: ValueType}
|
||||
::ComputeTaskQED_S2, data1::ParticleValue{P1}, data2::ParticleValue{P2}
|
||||
) where {P1<:PhotonStateful,P2<:PhotonStateful}
|
||||
# TODO: assert that data1 and data2 are opposites
|
||||
inner = QED_inner_edge(data1.p)
|
||||
# inner edge is just a scalar, data1 and data2 are photon states that are just Complex numbers here
|
||||
@@ -117,11 +91,11 @@ function compute(
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskQED_S1, data::ParticleValue)
|
||||
compute(::ComputeTaskQED_S1, data::QEDParticleValue)
|
||||
|
||||
Compute inner edge (1 input particle, 1 output particle).
|
||||
"""
|
||||
function compute(::ComputeTaskQED_S1, data::ParticleValue{P, V}) where {P <: ParticleStateful, V <: ValueType}
|
||||
function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P}) where {P<:QEDParticle}
|
||||
newP = propagation_result(P)
|
||||
new_p = newP(momentum(data.p))
|
||||
# inner edge is just a scalar, can multiply from either side
|
253
src/models/qed/create.jl
Normal file
253
src/models/qed/create.jl
Normal file
@@ -0,0 +1,253 @@
|
||||
|
||||
ComputeTaskQED_Sum() = ComputeTaskQED_Sum(0)
|
||||
|
||||
function _svector_from_type(
|
||||
processDescription::QEDProcessDescription, type::Type{T}, particles
|
||||
) where {DIR<:ParticleDirection,T<:QEDParticle{DIR}}
|
||||
if DIR <: Incoming
|
||||
l = 0
|
||||
for (k, v) in in_particles(processDescription)
|
||||
if T <: k
|
||||
l = v
|
||||
break
|
||||
end
|
||||
end
|
||||
return SVector{l,T}(filter(x -> typeof(x) <: T, particles))
|
||||
elseif DIR <: Outgoing
|
||||
l = 0
|
||||
for (k, v) in out_particles(processDescription)
|
||||
if T <: k
|
||||
l = v
|
||||
break
|
||||
end
|
||||
end
|
||||
return SVector{l,T}(filter(x -> typeof(x) <: T, particles))
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
gen_process_input(processDescription::QEDProcessDescription)
|
||||
|
||||
Return a ProcessInput of randomly generated [`QEDParticle`](@ref)s from a [`QEDProcessDescription`](@ref). The process description can be created manually or parsed from a string using [`parse_process`](@ref).
|
||||
|
||||
Note: This uses RAMBO to create a valid process with conservation of momentum and energy.
|
||||
"""
|
||||
function gen_process_input(processDescription::QEDProcessDescription)
|
||||
massSum = 0
|
||||
inputMasses = Vector{Float64}()
|
||||
for (particle, n) in processDescription.inParticles
|
||||
for _ in 1:n
|
||||
massSum += mass(particle)
|
||||
push!(inputMasses, mass(particle))
|
||||
end
|
||||
end
|
||||
outputMasses = Vector{Float64}()
|
||||
for (particle, n) in processDescription.outParticles
|
||||
for _ in 1:n
|
||||
massSum += mass(particle)
|
||||
push!(outputMasses, mass(particle))
|
||||
end
|
||||
end
|
||||
|
||||
# add some extra random mass to allow for some momentum
|
||||
massSum += rand(rng[threadid()]) * (length(inputMasses) + length(outputMasses))
|
||||
|
||||
particles = Vector{QEDParticle}()
|
||||
initialMomenta = generate_initial_moms(massSum, inputMasses)
|
||||
index = 1
|
||||
for (particle, n) in processDescription.inParticles
|
||||
for _ in 1:n
|
||||
mom = initialMomenta[index]
|
||||
push!(particles, particle(mom))
|
||||
index += 1
|
||||
end
|
||||
end
|
||||
|
||||
final_momenta = generate_physical_massive_moms(rng[threadid()], massSum, outputMasses)
|
||||
index = 1
|
||||
for (particle, n) in processDescription.outParticles
|
||||
for _ in 1:n
|
||||
push!(particles, particle(final_momenta[index]))
|
||||
index += 1
|
||||
end
|
||||
end
|
||||
|
||||
inFerms = _svector_from_type(
|
||||
processDescription, FermionStateful{Incoming,SpinUp}, particles
|
||||
)
|
||||
outFerms = _svector_from_type(
|
||||
processDescription, FermionStateful{Outgoing,SpinUp}, particles
|
||||
)
|
||||
inAntiferms = _svector_from_type(
|
||||
processDescription, AntiFermionStateful{Incoming,SpinUp}, particles
|
||||
)
|
||||
outAntiferms = _svector_from_type(
|
||||
processDescription, AntiFermionStateful{Outgoing,SpinUp}, particles
|
||||
)
|
||||
inPhotons = _svector_from_type(
|
||||
processDescription, PhotonStateful{Incoming,PolX}, particles
|
||||
)
|
||||
outPhotons = _svector_from_type(
|
||||
processDescription, PhotonStateful{Outgoing,PolX}, particles
|
||||
)
|
||||
|
||||
processInput = QEDProcessInput(
|
||||
processDescription,
|
||||
inFerms,
|
||||
outFerms,
|
||||
inAntiferms,
|
||||
outAntiferms,
|
||||
inPhotons,
|
||||
outPhotons,
|
||||
)
|
||||
|
||||
return processInput
|
||||
end
|
||||
|
||||
"""
|
||||
gen_graph(process_description::QEDProcessDescription)
|
||||
|
||||
For a given [`QEDProcessDescription`](@ref), return the [`DAG`](@ref) that computes it.
|
||||
"""
|
||||
function gen_graph(process_description::QEDProcessDescription)
|
||||
initial_diagram = FeynmanDiagram(process_description)
|
||||
diagrams = gen_diagrams(initial_diagram)
|
||||
|
||||
graph = DAG()
|
||||
|
||||
COMPLEX_SIZE = sizeof(ComplexF64)
|
||||
PARTICLE_VALUE_SIZE = 96.0
|
||||
|
||||
# TODO: Not all diagram outputs should always be summed at the end, if they differ by fermion exchange they need to be diffed
|
||||
# Should not matter for n-Photon Compton processes though
|
||||
sum_node = insert_node!(
|
||||
graph, make_node(ComputeTaskQED_Sum(0)); track=false, invalidate_cache=false
|
||||
)
|
||||
global_data_out = insert_node!(
|
||||
graph, make_node(DataTask(COMPLEX_SIZE)); track=false, invalidate_cache=false
|
||||
)
|
||||
insert_edge!(graph, sum_node, global_data_out; track=false, invalidate_cache=false)
|
||||
|
||||
# remember the data out nodes for connection
|
||||
dataOutNodes = Dict()
|
||||
|
||||
for particle in initial_diagram.particles
|
||||
# generate data in and U tasks
|
||||
data_in = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE), String(particle));
|
||||
track=false,
|
||||
invalidate_cache=false,
|
||||
) # read particle data node
|
||||
compute_u = insert_node!(
|
||||
graph, make_node(ComputeTaskQED_U()); track=false, invalidate_cache=false
|
||||
) # compute U node
|
||||
data_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE));
|
||||
track=false,
|
||||
invalidate_cache=false,
|
||||
) # transfer data out from u (one ABCParticleValue object)
|
||||
|
||||
insert_edge!(graph, data_in, compute_u; track=false, invalidate_cache=false)
|
||||
insert_edge!(graph, compute_u, data_out; track=false, invalidate_cache=false)
|
||||
|
||||
# remember the data_out node for future edges
|
||||
dataOutNodes[String(particle)] = data_out
|
||||
end
|
||||
|
||||
# TODO: this should be parallelizable somewhat easily
|
||||
for diagram in diagrams
|
||||
tie = diagram.tie[]
|
||||
|
||||
# handle the vertices
|
||||
for vertices in diagram.vertices
|
||||
for vertex in vertices
|
||||
data_in1 = dataOutNodes[String(vertex.in1)]
|
||||
data_in2 = dataOutNodes[String(vertex.in2)]
|
||||
|
||||
compute_V = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskQED_V());
|
||||
track=false,
|
||||
invalidate_cache=false,
|
||||
) # compute vertex
|
||||
|
||||
insert_edge!(
|
||||
graph, data_in1, compute_V; track=false, invalidate_cache=false
|
||||
)
|
||||
insert_edge!(
|
||||
graph, data_in2, compute_V; track=false, invalidate_cache=false
|
||||
)
|
||||
|
||||
data_V_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE));
|
||||
track=false,
|
||||
invalidate_cache=false,
|
||||
)
|
||||
|
||||
insert_edge!(
|
||||
graph, compute_V, data_V_out; track=false, invalidate_cache=false
|
||||
)
|
||||
|
||||
if (vertex.out == tie.in1 || vertex.out == tie.in2)
|
||||
# out particle is part of the tie -> there will be an S2 task with it later, don't make S1 task
|
||||
dataOutNodes[String(vertex.out)] = data_V_out
|
||||
continue
|
||||
end
|
||||
|
||||
# otherwise, add S1 task
|
||||
compute_S1 = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskQED_S1());
|
||||
track=false,
|
||||
invalidate_cache=false,
|
||||
) # compute propagator
|
||||
|
||||
insert_edge!(
|
||||
graph, data_V_out, compute_S1; track=false, invalidate_cache=false
|
||||
)
|
||||
|
||||
data_S1_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE));
|
||||
track=false,
|
||||
invalidate_cache=false,
|
||||
)
|
||||
|
||||
insert_edge!(
|
||||
graph, compute_S1, data_S1_out; track=false, invalidate_cache=false
|
||||
)
|
||||
|
||||
# overrides potentially different nodes from previous diagrams, which is intentional
|
||||
dataOutNodes[String(vertex.out)] = data_S1_out
|
||||
end
|
||||
end
|
||||
|
||||
# handle the tie
|
||||
data_in1 = dataOutNodes[String(tie.in1)]
|
||||
data_in2 = dataOutNodes[String(tie.in2)]
|
||||
|
||||
compute_S2 = insert_node!(
|
||||
graph, make_node(ComputeTaskQED_S2()); track=false, invalidate_cache=false
|
||||
)
|
||||
|
||||
data_S2 = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE));
|
||||
track=false,
|
||||
invalidate_cache=false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, data_in1, compute_S2; track=false, invalidate_cache=false)
|
||||
insert_edge!(graph, data_in2, compute_S2; track=false, invalidate_cache=false)
|
||||
|
||||
insert_edge!(graph, compute_S2, data_S2; track=false, invalidate_cache=false)
|
||||
|
||||
insert_edge!(graph, data_S2, sum_node; track=false, invalidate_cache=false)
|
||||
add_child!(task(sum_node))
|
||||
end
|
||||
|
||||
return graph
|
||||
end
|
@@ -8,10 +8,10 @@ import Base.show
|
||||
"""
|
||||
FeynmanParticle
|
||||
|
||||
Representation of a particle for use in [`FeynmanDiagram`](@ref)s. Consist of the `ParticleStateful` type and an id.
|
||||
Representation of a particle for use in [`FeynmanDiagram`](@ref)s. Consist of the [`QEDParticle`](@ref) type and an id.
|
||||
"""
|
||||
struct FeynmanParticle
|
||||
particle::Type{<:ParticleStateful}
|
||||
particle::Type{<:QEDParticle}
|
||||
id::Int
|
||||
end
|
||||
|
||||
@@ -45,27 +45,37 @@ The [`FeynmanTie`](@ref) represents the final inner edge of the diagram.
|
||||
"""
|
||||
struct FeynmanDiagram
|
||||
vertices::Vector{Set{FeynmanVertex}}
|
||||
tie::Ref{Union{FeynmanTie, Missing}}
|
||||
tie::Ref{Union{FeynmanTie,Missing}}
|
||||
particles::Vector{FeynmanParticle}
|
||||
type_ids::Dict{Type, Int64} # lut for number of used ids for a particle type
|
||||
type_ids::Dict{Type,Int64} # lut for number of used ids for a particle type
|
||||
end
|
||||
|
||||
"""
|
||||
FeynmanDiagram(pd::GenericQEDProcess)
|
||||
FeynmanDiagram(pd::QEDProcessDescription)
|
||||
|
||||
Create an initial [`FeynmanDiagram`](@ref) with only its initial particles set and no vertices or ties.
|
||||
|
||||
Use [`gen_diagrams`](@ref) to generate all possible diagrams from this one.
|
||||
"""
|
||||
function FeynmanDiagram(pd::GenericQEDProcess)
|
||||
function FeynmanDiagram(pd::QEDProcessDescription)
|
||||
parts = Vector{FeynmanParticle}()
|
||||
|
||||
ids = Dict{Type, Int64}()
|
||||
for type in types(model(pd))
|
||||
for i in 1:number_particles(pd, type)
|
||||
for (type, n) in pd.inParticles
|
||||
for i in 1:n
|
||||
push!(parts, FeynmanParticle(type, i))
|
||||
end
|
||||
ids[type] = number_particles(pd, type)
|
||||
end
|
||||
for (type, n) in pd.outParticles
|
||||
for i in 1:n
|
||||
push!(parts, FeynmanParticle(type, i))
|
||||
end
|
||||
end
|
||||
ids = Dict{Type,Int64}()
|
||||
for t in types(QEDModel())
|
||||
if (isincoming(t))
|
||||
ids[t] = get(pd.inParticles, t, 0)
|
||||
else
|
||||
ids[t] = get(pd.outParticles, t, 0)
|
||||
end
|
||||
end
|
||||
|
||||
return FeynmanDiagram([], missing, parts, ids)
|
||||
@@ -73,13 +83,17 @@ end
|
||||
|
||||
function particle_after_tie(p::FeynmanParticle, t::FeynmanTie)
|
||||
if p == t.in1 || p == t.in2
|
||||
return FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, -1) # placeholder particle and id for tied particles
|
||||
return FeynmanParticle(FermionStateful{Incoming,SpinUp}, -1) # placeholder particle and id for tied particles
|
||||
end
|
||||
return p
|
||||
end
|
||||
|
||||
function vertex_after_tie(v::FeynmanVertex, t::FeynmanTie)
|
||||
return FeynmanVertex(particle_after_tie(v.in1, t), particle_after_tie(v.in2, t), particle_after_tie(v.out, t))
|
||||
return FeynmanVertex(
|
||||
particle_after_tie(v.in1, t),
|
||||
particle_after_tie(v.in2, t),
|
||||
particle_after_tie(v.out, t),
|
||||
)
|
||||
end
|
||||
|
||||
function vertex_after_tie(v::FeynmanVertex, t::Missing)
|
||||
@@ -94,7 +108,9 @@ function vertex_set_after_tie(vs::Set{FeynmanVertex}, t::Missing)
|
||||
return vs
|
||||
end
|
||||
|
||||
function vertex_set_after_tie(vs::Set{FeynmanVertex}, t1::Union{FeynmanTie, Missing}, t2::Union{FeynmanTie, Missing})
|
||||
function vertex_set_after_tie(
|
||||
vs::Set{FeynmanVertex}, t1::Union{FeynmanTie,Missing}, t2::Union{FeynmanTie,Missing}
|
||||
)
|
||||
return Set{FeynmanVertex}(vertex_after_tie(vertex_after_tie(v, t1), t2) for v in vs)
|
||||
end
|
||||
|
||||
@@ -104,7 +120,7 @@ end
|
||||
Return a string representation of the [`FeynmanParticle`](@ref) in a format that is readable by [`type_index_from_name`](@ref).
|
||||
"""
|
||||
function String(p::FeynmanParticle)
|
||||
return "$(String(p.particle))$(String(particle_direction(p.particle)))$(p.id)"
|
||||
return "$(String(p.particle))$(String(direction(p.particle)))$(p.id)"
|
||||
end
|
||||
|
||||
function hash(v::FeynmanVertex)
|
||||
@@ -128,7 +144,8 @@ function ==(t1::FeynmanTie, t2::FeynmanTie)
|
||||
end
|
||||
|
||||
function ==(d1::FeynmanDiagram, d2::FeynmanDiagram)
|
||||
if (!ismissing(d1.tie[]) && ismissing(d2.tie[])) || (ismissing(d1.tie[]) && !ismissing(d2.tie[]))
|
||||
if (!ismissing(d1.tie[]) && ismissing(d2.tie[])) ||
|
||||
(ismissing(d1.tie[]) && !ismissing(d2.tie[]))
|
||||
return false
|
||||
end
|
||||
if d1.particles != d2.particles
|
||||
@@ -140,7 +157,8 @@ function ==(d1::FeynmanDiagram, d2::FeynmanDiagram)
|
||||
|
||||
# TODO can i prove that this works?
|
||||
for (v1, v2) in zip(d1.vertices, d2.vertices)
|
||||
if vertex_set_after_tie(v1, d1.tie[], d2.tie[]) != vertex_set_after_tie(v2, d1.tie[], d2.tie[])
|
||||
if vertex_set_after_tie(v1, d1.tie[], d2.tie[]) !=
|
||||
vertex_set_after_tie(v2, d1.tie[], d2.tie[])
|
||||
return false
|
||||
end
|
||||
end
|
||||
@@ -153,15 +171,17 @@ function ==(d1::FeynmanDiagram, d2::FeynmanDiagram)
|
||||
end
|
||||
|
||||
function copy(fd::FeynmanDiagram)
|
||||
return FeynmanDiagram(deepcopy(fd.vertices), copy(fd.tie[]), deepcopy(fd.particles), copy(fd.type_ids))
|
||||
return FeynmanDiagram(
|
||||
deepcopy(fd.vertices), copy(fd.tie[]), deepcopy(fd.particles), copy(fd.type_ids)
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
id_for_type(d::FeynmanDiagram, t::Type{<:ParticleStateful})
|
||||
id_for_type(d::FeynmanDiagram, t::Type{<:QEDParticle})
|
||||
|
||||
Return the highest id of any particle of the given type in the diagram + 1.
|
||||
"""
|
||||
function id_for_type(d::FeynmanDiagram, t::Type{<:ParticleStateful})
|
||||
function id_for_type(d::FeynmanDiagram, t::Type{<:QEDParticle})
|
||||
return d.type_ids[t] + 1
|
||||
end
|
||||
|
||||
@@ -220,7 +240,7 @@ end
|
||||
|
||||
Return a vector of the particles after applying the vertices and tie of the diagram up to the given level. If no level is given, apply all. The tie comes last and is its own "level".
|
||||
"""
|
||||
function get_particles(fd::FeynmanDiagram, level::Int = -1)
|
||||
function get_particles(fd::FeynmanDiagram, level::Int=-1)
|
||||
if level == -1
|
||||
level = length(fd.vertices) + 1
|
||||
end
|
||||
@@ -356,8 +376,14 @@ function possible_vertices(fd::FeynmanDiagram)
|
||||
p1 = particles[i]
|
||||
p2 = particles[j]
|
||||
if (caninteract(p1.particle, p2.particle))
|
||||
interaction_res = propagation_result(interaction_result(p1.particle, p2.particle))
|
||||
v = FeynmanVertex(p1, p2, FeynmanParticle(interaction_res, id_for_type(fd, interaction_res)))
|
||||
interaction_res = propagation_result(
|
||||
interaction_result(p1.particle, p2.particle)
|
||||
)
|
||||
v = FeynmanVertex(
|
||||
p1,
|
||||
p2,
|
||||
FeynmanParticle(interaction_res, id_for_type(fd, interaction_res)),
|
||||
)
|
||||
#@assert !(v.out in particles) "$v is in $fd"
|
||||
if !can_apply_vertex(fully_generated_particles, v)
|
||||
continue
|
||||
@@ -430,19 +456,18 @@ function remove_duplicates(compare_set::Set{FeynmanDiagram})
|
||||
return result
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
is_compton(fd::FeynmanDiagram)
|
||||
|
||||
Returns true iff the given feynman diagram is an (empty) diagram of a compton process like ke->k^ne
|
||||
"""
|
||||
function is_compton(fd::FeynmanDiagram)
|
||||
return fd.type_ids[ParticleStateful{Incoming, Electron, SFourMomentum}] == 1 &&
|
||||
fd.type_ids[ParticleStateful{Outgoing, Electron, SFourMomentum}] == 1 &&
|
||||
fd.type_ids[ParticleStateful{Incoming, Positron, SFourMomentum}] == 0 &&
|
||||
fd.type_ids[ParticleStateful{Outgoing, Positron, SFourMomentum}] == 0 &&
|
||||
fd.type_ids[ParticleStateful{Incoming, Photon, SFourMomentum}] >= 1 &&
|
||||
fd.type_ids[ParticleStateful{Outgoing, Photon, SFourMomentum}] >= 1
|
||||
return fd.type_ids[FermionStateful{Incoming}] == 1 &&
|
||||
fd.type_ids[FermionStateful{Outgoing}] == 1 &&
|
||||
fd.type_ids[AntiFermionStateful{Incoming}] == 0 &&
|
||||
fd.type_ids[AntiFermionStateful{Outgoing}] == 0 &&
|
||||
fd.type_ids[PhotonStateful{Incoming}] >= 1 &&
|
||||
fd.type_ids[PhotonStateful{Outgoing}] >= 1
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -452,19 +477,19 @@ Helper function for [`gen_compton_diagrams`](@Ref). Generates a single diagram f
|
||||
"""
|
||||
function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
|
||||
photons = vcat(
|
||||
[FeynmanParticle(ParticleStateful{Incoming, Photon, SFourMomentum}, i) for i in 1:n],
|
||||
[FeynmanParticle(ParticleStateful{Outgoing, Photon, SFourMomentum}, i) for i in 1:m],
|
||||
[FeynmanParticle(PhotonStateful{Incoming,PolX}, i) for i in 1:n],
|
||||
[FeynmanParticle(PhotonStateful{Outgoing,PolX}, i) for i in 1:m],
|
||||
)
|
||||
|
||||
new_diagram = FeynmanDiagram(
|
||||
[],
|
||||
missing,
|
||||
[inFerm, outFerm, photons...],
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum} => 1,
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum} => 1,
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum} => n,
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum} => m,
|
||||
Dict{Type,Int64}(
|
||||
FermionStateful{Incoming,SpinUp} => 1,
|
||||
FermionStateful{Outgoing,SpinUp} => 1,
|
||||
PhotonStateful{Incoming,PolX} => n,
|
||||
PhotonStateful{Outgoing,PolX} => m,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -476,9 +501,9 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
|
||||
while left_index <= right_index
|
||||
# left side
|
||||
v_left = FeynmanVertex(
|
||||
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations),
|
||||
FeynmanParticle(FermionStateful{Incoming,SpinUp}, iterations),
|
||||
photons[order[left_index]],
|
||||
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations + 1),
|
||||
FeynmanParticle(FermionStateful{Incoming,SpinUp}, iterations + 1),
|
||||
)
|
||||
left_index += 1
|
||||
add_vertex!(new_diagram, v_left)
|
||||
@@ -489,9 +514,9 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
|
||||
|
||||
# right side
|
||||
v_right = FeynmanVertex(
|
||||
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations),
|
||||
FeynmanParticle(FermionStateful{Outgoing,SpinUp}, iterations),
|
||||
photons[order[right_index]],
|
||||
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations + 1),
|
||||
FeynmanParticle(FermionStateful{Outgoing,SpinUp}, iterations + 1),
|
||||
)
|
||||
right_index -= 1
|
||||
add_vertex!(new_diagram, v_right)
|
||||
@@ -509,21 +534,23 @@ end
|
||||
|
||||
Helper function for [`gen_compton_diagrams`](@Ref). Generates a single diagram for the given order and n input and m output photons.
|
||||
"""
|
||||
function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
|
||||
function gen_compton_diagram_from_order_one_side(
|
||||
order::Vector{Int}, inFerm, outFerm, n::Int, m::Int
|
||||
)
|
||||
photons = vcat(
|
||||
[FeynmanParticle(ParticleStateful{Incoming, Photon, SFourMomentum}, i) for i in 1:n],
|
||||
[FeynmanParticle(ParticleStateful{Outgoing, Photon, SFourMomentum}, i) for i in 1:m],
|
||||
[FeynmanParticle(PhotonStateful{Incoming,PolX}, i) for i in 1:n],
|
||||
[FeynmanParticle(PhotonStateful{Outgoing,PolX}, i) for i in 1:m],
|
||||
)
|
||||
|
||||
new_diagram = FeynmanDiagram(
|
||||
[],
|
||||
missing,
|
||||
[inFerm, outFerm, photons...],
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum} => 1,
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum} => 1,
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum} => n,
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum} => m,
|
||||
Dict{Type,Int64}(
|
||||
FermionStateful{Incoming,SpinUp} => 1,
|
||||
FermionStateful{Outgoing,SpinUp} => 1,
|
||||
PhotonStateful{Incoming,PolX} => n,
|
||||
PhotonStateful{Outgoing,PolX} => m,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -535,9 +562,9 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
|
||||
while left_index <= right_index
|
||||
# left side
|
||||
v_left = FeynmanVertex(
|
||||
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations),
|
||||
FeynmanParticle(FermionStateful{Incoming,SpinUp}, iterations),
|
||||
photons[order[left_index]],
|
||||
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations + 1),
|
||||
FeynmanParticle(FermionStateful{Incoming,SpinUp}, iterations + 1),
|
||||
)
|
||||
left_index += 1
|
||||
add_vertex!(new_diagram, v_left)
|
||||
@@ -550,9 +577,9 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
|
||||
if (iterations == 1)
|
||||
# right side
|
||||
v_right = FeynmanVertex(
|
||||
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations),
|
||||
FeynmanParticle(FermionStateful{Outgoing,SpinUp}, iterations),
|
||||
photons[order[right_index]],
|
||||
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations + 1),
|
||||
FeynmanParticle(FermionStateful{Outgoing,SpinUp}, iterations + 1),
|
||||
)
|
||||
right_index -= 1
|
||||
add_vertex!(new_diagram, v_right)
|
||||
@@ -573,14 +600,17 @@ end
|
||||
Special case diagram generation for Compton processes, i.e., processes of the form k^ne->k^me
|
||||
"""
|
||||
function gen_compton_diagrams(n::Int, m::Int)
|
||||
inFerm = FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, 1)
|
||||
outFerm = FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, 1)
|
||||
inFerm = FeynmanParticle(FermionStateful{Incoming,SpinUp}, 1)
|
||||
outFerm = FeynmanParticle(FermionStateful{Outgoing,SpinUp}, 1)
|
||||
|
||||
perms = [permutations([i for i in 1:(n + m)])...]
|
||||
|
||||
diagrams = [Vector{FeynmanDiagram}() for i in 1:nthreads()]
|
||||
@threads for order in perms
|
||||
push!(diagrams[threadid()], gen_compton_diagram_from_order(order, inFerm, outFerm, n, m))
|
||||
push!(
|
||||
diagrams[threadid()],
|
||||
gen_compton_diagram_from_order(order, inFerm, outFerm, n, m),
|
||||
)
|
||||
end
|
||||
|
||||
return vcat(diagrams...)
|
||||
@@ -592,14 +622,17 @@ end
|
||||
Special case diagram generation for Compton processes, i.e., processes of the form k^ne->k^me, but generating from one end, yielding larger diagrams
|
||||
"""
|
||||
function gen_compton_diagrams_one_side(n::Int, m::Int)
|
||||
inFerm = FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, 1)
|
||||
outFerm = FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, 1)
|
||||
inFerm = FeynmanParticle(FermionStateful{Incoming,SpinUp}, 1)
|
||||
outFerm = FeynmanParticle(FermionStateful{Outgoing,SpinUp}, 1)
|
||||
|
||||
perms = [permutations([i for i in 1:(n + m)])...]
|
||||
|
||||
diagrams = [Vector{FeynmanDiagram}() for i in 1:nthreads()]
|
||||
@threads for order in perms
|
||||
push!(diagrams[threadid()], gen_compton_diagram_from_order_one_side(order, inFerm, outFerm, n, m))
|
||||
push!(
|
||||
diagrams[threadid()],
|
||||
gen_compton_diagram_from_order_one_side(order, inFerm, outFerm, n, m),
|
||||
)
|
||||
end
|
||||
|
||||
return vcat(diagrams...)
|
||||
@@ -612,15 +645,11 @@ From a given feynman diagram in its initial state, e.g. when created through the
|
||||
"""
|
||||
function gen_diagrams(fd::FeynmanDiagram)
|
||||
if is_compton(fd)
|
||||
return gen_compton_diagrams(
|
||||
fd.type_ids[ParticleStateful{Incoming, Photon, SFourMomentum}],
|
||||
fd.type_ids[ParticleStateful{Outgoing, Photon, SFourMomentum}],
|
||||
return gen_compton_diagrams_one_side(
|
||||
fd.type_ids[PhotonStateful{Incoming}], fd.type_ids[PhotonStateful{Outgoing}]
|
||||
)
|
||||
end
|
||||
|
||||
throw(error("Unimplemented for non-compton!"))
|
||||
|
||||
#=
|
||||
working = Set{FeynmanDiagram}()
|
||||
results = Set{FeynmanDiagram}()
|
||||
|
||||
@@ -659,5 +688,4 @@ function gen_diagrams(fd::FeynmanDiagram)
|
||||
end
|
||||
|
||||
return remove_duplicates(results)
|
||||
=#
|
||||
end
|
44
src/models/qed/parse.jl
Normal file
44
src/models/qed/parse.jl
Normal file
@@ -0,0 +1,44 @@
|
||||
|
||||
"""
|
||||
parse_process(string::AbstractString, model::QEDModel)
|
||||
|
||||
Parse a string representation of a process, such as "ke->ke" into the corresponding [`QEDProcessDescription`](@ref).
|
||||
"""
|
||||
function parse_process(str::AbstractString, model::QEDModel)
|
||||
inParticles = Dict{Type, Int}()
|
||||
outParticles = Dict{Type, Int}()
|
||||
|
||||
if !(contains(str, "->"))
|
||||
throw("Did not find -> while parsing process \"$str\"")
|
||||
end
|
||||
|
||||
(inStr, outStr) = split(str, "->")
|
||||
|
||||
if (isempty(inStr) || isempty(outStr))
|
||||
throw("Process (\"$str\") input or output part is empty!")
|
||||
end
|
||||
|
||||
for t in types(model)
|
||||
if (isincoming(t))
|
||||
inCount = count(x -> x == String(t)[1], inStr)
|
||||
|
||||
if inCount != 0
|
||||
inParticles[t] = inCount
|
||||
end
|
||||
end
|
||||
if (isoutgoing(t))
|
||||
outCount = count(x -> x == String(t)[1], outStr)
|
||||
if outCount != 0
|
||||
outParticles[t] = outCount
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if length(inStr) != sum(values(inParticles))
|
||||
throw("Encountered unknown characters in the input part of process \"$str\"")
|
||||
elseif length(outStr) != sum(values(outParticles))
|
||||
throw("Encountered unknown characters in the output part of process \"$str\"")
|
||||
end
|
||||
|
||||
return QEDProcessDescription(inParticles, outParticles)
|
||||
end
|
532
src/models/qed/particle.jl
Normal file
532
src/models/qed/particle.jl
Normal file
@@ -0,0 +1,532 @@
|
||||
using QEDprocesses
|
||||
using StaticArrays
|
||||
import QEDbase.mass
|
||||
|
||||
# TODO check
|
||||
const e = sqrt(4π / 137)
|
||||
|
||||
"""
|
||||
QEDModel <: AbstractPhysicsModel
|
||||
|
||||
Singleton definition for identification of the QED-Model.
|
||||
"""
|
||||
struct QEDModel <: AbstractPhysicsModel end
|
||||
|
||||
"""
|
||||
QEDParticle
|
||||
|
||||
Base type for all particles in the [`QEDModel`](@ref).
|
||||
|
||||
Its template parameter specifies the particle's direction.
|
||||
|
||||
The concrete types contain singletons of the types that they are, like `Photon` and `Electron` from QEDbase, and their state descriptions.
|
||||
"""
|
||||
abstract type QEDParticle{Direction<:ParticleDirection} <: AbstractParticle end
|
||||
|
||||
"""
|
||||
QEDProcessDescription <: AbstractProcessDescription
|
||||
|
||||
A description of a process in the QED-Model. Contains the input and output particles.
|
||||
|
||||
See also: [`in_particles`](@ref), [`out_particles`](@ref), [`parse_process`](@ref)
|
||||
"""
|
||||
struct QEDProcessDescription <: AbstractProcessDescription
|
||||
inParticles::Dict{Type{<:QEDParticle{Incoming}},Int}
|
||||
outParticles::Dict{Type{<:QEDParticle{Outgoing}},Int}
|
||||
end
|
||||
|
||||
QEDParticleValue{ParticleType<:QEDParticle} = Union{
|
||||
ParticleValue{ParticleType,BiSpinor},
|
||||
ParticleValue{ParticleType,AdjointBiSpinor},
|
||||
ParticleValue{ParticleType,DiracMatrix},
|
||||
ParticleValue{ParticleType,SLorentzVector{Float64}},
|
||||
ParticleValue{ParticleType,ComplexF64},
|
||||
}
|
||||
|
||||
"""
|
||||
PhotonStateful <: QEDParticle
|
||||
|
||||
A photon of the [`QEDModel`](@ref) with its state.
|
||||
"""
|
||||
struct PhotonStateful{Direction<:ParticleDirection,Pol<:AbstractDefinitePolarization} <:
|
||||
QEDParticle{Direction}
|
||||
momentum::SFourMomentum
|
||||
|
||||
function PhotonStateful{Direction,Pol}(
|
||||
mom::SFourMomentum
|
||||
) where {Direction<:ParticleDirection,Pol<:AbstractDefinitePolarization}
|
||||
return new{Direction,Pol}(mom)
|
||||
end
|
||||
|
||||
function PhotonStateful{Direction}(
|
||||
mom::SFourMomentum
|
||||
) where {Direction<:ParticleDirection}
|
||||
return new{Direction,AbstractDefinitePolarization}(mom)
|
||||
end
|
||||
|
||||
function PhotonStateful{Direction}(
|
||||
mom::SFourMomentum, pol::AbstractDefinitePolarization
|
||||
) where {Direction<:ParticleDirection}
|
||||
return new{Direction,typeof(pol)}(mom)
|
||||
end
|
||||
|
||||
function PhotonStateful{Dir,Pol}(ph::PhotonStateful) where {Dir,Pol}
|
||||
return new{Dir,Pol}(ph.momentum)
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
FermionStateful <: QEDParticle
|
||||
|
||||
A fermion of the [`QEDModel`](@ref) with its state.
|
||||
"""
|
||||
struct FermionStateful{Direction<:ParticleDirection,Spin<:AbstractDefiniteSpin} <:
|
||||
QEDParticle{Direction}
|
||||
momentum::SFourMomentum
|
||||
# TODO: mass for electron/muon/tauon representation?
|
||||
|
||||
function FermionStateful{Direction,Spin}(
|
||||
mom::SFourMomentum
|
||||
) where {Direction<:ParticleDirection,Spin<:AbstractDefiniteSpin}
|
||||
return new{Direction,Spin}(mom)
|
||||
end
|
||||
|
||||
function FermionStateful{Direction}(
|
||||
mom::SFourMomentum
|
||||
) where {Direction<:ParticleDirection}
|
||||
return new{Direction,AbstractDefiniteSpin}(mom)
|
||||
end
|
||||
|
||||
function FermionStateful{Direction}(
|
||||
mom::SFourMomentum, spin::AbstractDefiniteSpin
|
||||
) where {Direction<:ParticleDirection}
|
||||
return new{Direction,typeof(spin)}(mom)
|
||||
end
|
||||
|
||||
function FermionStateful{Dir,Spin}(f::FermionStateful) where {Dir,Spin}
|
||||
return new{Dir,Spin}(f.momentum)
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
AntiFermionStateful <: QEDParticle
|
||||
|
||||
An anti-fermion of the [`QEDModel`](@ref) with its state.
|
||||
"""
|
||||
struct AntiFermionStateful{Direction<:ParticleDirection,Spin<:AbstractDefiniteSpin} <:
|
||||
QEDParticle{Direction}
|
||||
momentum::SFourMomentum
|
||||
# TODO: mass for electron/muon/tauon representation?
|
||||
|
||||
function AntiFermionStateful{Direction,Spin}(
|
||||
mom::SFourMomentum
|
||||
) where {Direction<:ParticleDirection,Spin<:AbstractDefiniteSpin}
|
||||
return new{Direction,Spin}(mom)
|
||||
end
|
||||
|
||||
function AntiFermionStateful{Direction}(
|
||||
mom::SFourMomentum
|
||||
) where {Direction<:ParticleDirection}
|
||||
return new{Direction,AbstractDefiniteSpin}(mom)
|
||||
end
|
||||
|
||||
function AntiFermionStateful{Direction}(
|
||||
mom::SFourMomentum, spin::AbstractDefiniteSpin
|
||||
) where {Direction<:ParticleDirection}
|
||||
return new{Direction,typeof(spin)}(mom)
|
||||
end
|
||||
|
||||
function AntiFermionStateful{Dir,Spin}(f::AntiFermionStateful) where {Dir,Spin}
|
||||
return new{Dir,Spin}(f.momentum)
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
|
||||
|
||||
For two given particle types that can interact, return the third.
|
||||
"""
|
||||
function interaction_result(
|
||||
t1::Type{T1}, t2::Type{T2}
|
||||
) where {T1<:QEDParticle,T2<:QEDParticle}
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
end
|
||||
|
||||
function interaction_result(
|
||||
::Type{FermionStateful{Incoming,Spin1}}, ::Type{FermionStateful{Outgoing,Spin2}}
|
||||
) where {Spin1,Spin2}
|
||||
return PhotonStateful{Incoming,PolX}
|
||||
end
|
||||
function interaction_result(
|
||||
::Type{FermionStateful{Incoming,Spin1}}, ::Type{AntiFermionStateful{Incoming,Spin2}}
|
||||
) where {Spin1,Spin2}
|
||||
return PhotonStateful{Incoming,PolX}
|
||||
end
|
||||
function interaction_result(
|
||||
::Type{FermionStateful{Incoming,Spin1}}, ::Type{<:PhotonStateful}
|
||||
) where {Spin1}
|
||||
return FermionStateful{Outgoing,SpinUp}
|
||||
end
|
||||
|
||||
function interaction_result(
|
||||
::Type{FermionStateful{Outgoing,Spin1}}, ::Type{FermionStateful{Incoming,Spin2}}
|
||||
) where {Spin1,Spin2}
|
||||
return PhotonStateful{Incoming,PolX}
|
||||
end
|
||||
function interaction_result(
|
||||
::Type{FermionStateful{Outgoing,Spin1}}, ::Type{AntiFermionStateful{Outgoing,Spin2}}
|
||||
) where {Spin1,Spin2}
|
||||
return PhotonStateful{Incoming,PolX}
|
||||
end
|
||||
function interaction_result(
|
||||
::Type{FermionStateful{Outgoing,Spin1}}, ::Type{<:PhotonStateful}
|
||||
) where {Spin1}
|
||||
return FermionStateful{Incoming,SpinUp}
|
||||
end
|
||||
|
||||
# antifermion mirror
|
||||
function interaction_result(
|
||||
::Type{AntiFermionStateful{Incoming,Spin}}, t2::Type{<:QEDParticle}
|
||||
) where {Spin}
|
||||
return interaction_result(FermionStateful{Outgoing,Spin}, t2)
|
||||
end
|
||||
function interaction_result(
|
||||
::Type{AntiFermionStateful{Outgoing,Spin}}, t2::Type{<:QEDParticle}
|
||||
) where {Spin}
|
||||
return interaction_result(FermionStateful{Incoming,Spin}, t2)
|
||||
end
|
||||
|
||||
# photon commutativity
|
||||
function interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:QEDParticle})
|
||||
return interaction_result(t2, t1)
|
||||
end
|
||||
|
||||
# but prevent stack overflow
|
||||
function interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:PhotonStateful})
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
end
|
||||
|
||||
"""
|
||||
propagation_result(t1::Type{T}) where {T <: QEDParticle}
|
||||
|
||||
Return the type of the inverted direction. E.g.
|
||||
"""
|
||||
propagation_result(
|
||||
::Type{FermionStateful{Incoming,Spin}}
|
||||
) where {Spin<:AbstractDefiniteSpin} = FermionStateful{Outgoing,Spin}
|
||||
function propagation_result(
|
||||
::Type{FermionStateful{Outgoing,Spin}}
|
||||
) where {Spin<:AbstractDefiniteSpin}
|
||||
return FermionStateful{Incoming,Spin}
|
||||
end
|
||||
function propagation_result(
|
||||
::Type{AntiFermionStateful{Incoming,Spin}}
|
||||
) where {Spin<:AbstractDefiniteSpin}
|
||||
return AntiFermionStateful{Outgoing,Spin}
|
||||
end
|
||||
function propagation_result(
|
||||
::Type{AntiFermionStateful{Outgoing,Spin}}
|
||||
) where {Spin<:AbstractDefiniteSpin}
|
||||
return AntiFermionStateful{Incoming,Spin}
|
||||
end
|
||||
function propagation_result(
|
||||
::Type{PhotonStateful{Incoming,Pol}}
|
||||
) where {Pol<:AbstractDefinitePolarization}
|
||||
return PhotonStateful{Outgoing,Pol}
|
||||
end
|
||||
function propagation_result(
|
||||
::Type{PhotonStateful{Outgoing,Pol}}
|
||||
) where {Pol<:AbstractDefinitePolarization}
|
||||
return PhotonStateful{Incoming,Pol}
|
||||
end
|
||||
|
||||
"""
|
||||
types(::QEDModel)
|
||||
|
||||
Return a Vector of the possible types of particle in the [`QEDModel`](@ref).
|
||||
"""
|
||||
function types(::QEDModel)
|
||||
return [
|
||||
PhotonStateful{Incoming},
|
||||
PhotonStateful{Outgoing},
|
||||
FermionStateful{Incoming},
|
||||
FermionStateful{Outgoing},
|
||||
AntiFermionStateful{Incoming},
|
||||
AntiFermionStateful{Outgoing},
|
||||
]
|
||||
end
|
||||
|
||||
# type piracy?
|
||||
String(::Type{Incoming}) = "Incoming"
|
||||
String(::Type{Outgoing}) = "Outgoing"
|
||||
|
||||
String(::Type{PolX}) = "polx"
|
||||
String(::Type{PolY}) = "poly"
|
||||
|
||||
String(::Type{SpinUp}) = "spinup"
|
||||
String(::Type{SpinDown}) = "spindown"
|
||||
|
||||
String(::Incoming) = "i"
|
||||
String(::Outgoing) = "o"
|
||||
|
||||
function String(::Type{<:PhotonStateful})
|
||||
return "k"
|
||||
end
|
||||
function String(::Type{<:FermionStateful})
|
||||
return "e"
|
||||
end
|
||||
function String(::Type{<:AntiFermionStateful})
|
||||
return "p"
|
||||
end
|
||||
|
||||
function unique_name(::Type{PhotonStateful{Dir,Pol}}) where {Dir,Pol}
|
||||
return String(PhotonStateful) * String(Dir) * String(Pol)
|
||||
end
|
||||
function unique_name(::Type{FermionStateful{Dir,Spin}}) where {Dir,Spin}
|
||||
return String(FermionStateful) * String(Dir) * String(Spin)
|
||||
end
|
||||
function unique_name(::Type{AntiFermionStateful{Dir,Spin}}) where {Dir,Spin}
|
||||
return String(AntiFermionStateful) * String(Dir) * String(Spin)
|
||||
end
|
||||
|
||||
@inline particle(::PhotonStateful) = Photon()
|
||||
@inline particle(::FermionStateful) = Electron()
|
||||
@inline particle(::AntiFermionStateful) = Positron()
|
||||
|
||||
@inline particle(::Type{<:PhotonStateful}) = Photon()
|
||||
@inline particle(::Type{<:FermionStateful}) = Electron()
|
||||
@inline particle(::Type{<:AntiFermionStateful}) = Positron()
|
||||
|
||||
@inline momentum(p::PhotonStateful)::SFourMomentum = p.momentum
|
||||
@inline momentum(p::FermionStateful)::SFourMomentum = p.momentum
|
||||
@inline momentum(p::AntiFermionStateful)::SFourMomentum = p.momentum
|
||||
|
||||
@inline spin_or_pol(
|
||||
p::PhotonStateful{Dir,Pol}
|
||||
) where {Dir,Pol<:AbstractDefinitePolarization} = Pol()
|
||||
@inline spin_or_pol(p::FermionStateful{Dir,Spin}) where {Dir,Spin<:AbstractDefiniteSpin} =
|
||||
Spin()
|
||||
@inline spin_or_pol(
|
||||
p::AntiFermionStateful{Dir,Spin}
|
||||
) where {Dir,Spin<:AbstractDefiniteSpin} = Spin()
|
||||
|
||||
@inline direction(
|
||||
::Type{P}
|
||||
) where {
|
||||
P<:Union{
|
||||
FermionStateful{Incoming},AntiFermionStateful{Incoming},PhotonStateful{Incoming}
|
||||
},
|
||||
} = Incoming()
|
||||
@inline direction(
|
||||
::Type{P}
|
||||
) where {
|
||||
P<:Union{
|
||||
FermionStateful{Outgoing},AntiFermionStateful{Outgoing},PhotonStateful{Outgoing}
|
||||
},
|
||||
} = Outgoing()
|
||||
|
||||
@inline direction(
|
||||
::P
|
||||
) where {
|
||||
P<:Union{
|
||||
FermionStateful{Incoming},AntiFermionStateful{Incoming},PhotonStateful{Incoming}
|
||||
},
|
||||
} = Incoming()
|
||||
@inline direction(
|
||||
::P
|
||||
) where {
|
||||
P<:Union{
|
||||
FermionStateful{Outgoing},AntiFermionStateful{Outgoing},PhotonStateful{Outgoing}
|
||||
},
|
||||
} = Outgoing()
|
||||
|
||||
@inline isincoming(::QEDParticle{Incoming}) = true
|
||||
@inline isincoming(::QEDParticle{Outgoing}) = false
|
||||
@inline isoutgoing(::QEDParticle{Incoming}) = false
|
||||
@inline isoutgoing(::QEDParticle{Outgoing}) = true
|
||||
|
||||
@inline isincoming(::Type{<:QEDParticle{Incoming}}) = true
|
||||
@inline isincoming(::Type{<:QEDParticle{Outgoing}}) = false
|
||||
@inline isoutgoing(::Type{<:QEDParticle{Incoming}}) = false
|
||||
@inline isoutgoing(::Type{<:QEDParticle{Outgoing}}) = true
|
||||
|
||||
@inline mass(::Type{<:FermionStateful}) = 1.0
|
||||
@inline mass(::Type{<:AntiFermionStateful}) = 1.0
|
||||
@inline mass(::Type{<:PhotonStateful}) = 0.0
|
||||
|
||||
@inline invert_momentum(p::FermionStateful{Dir,Spin}) where {Dir,Spin} =
|
||||
FermionStateful{Dir,Spin}(-p.momentum, p.spin)
|
||||
@inline invert_momentum(p::AntiFermionStateful{Dir,Spin}) where {Dir,Spin} =
|
||||
AntiFermionStateful{Dir,Spin}(-p.momentum, p.spin)
|
||||
@inline invert_momentum(k::PhotonStateful{Dir,Spin}) where {Dir,Spin} =
|
||||
PhotonStateful{Dir,Spin}(-k.momentum, k.polarization)
|
||||
|
||||
"""
|
||||
caninteract(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
|
||||
|
||||
For two given [`QEDParticle`](@ref) types, return whether they can interact at a vertex. This is equivalent to `!issame(T1, T2)`.
|
||||
|
||||
See also: [`issame`](@ref) and [`interaction_result`](@ref)
|
||||
"""
|
||||
function caninteract(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
|
||||
if (T1 == T2)
|
||||
return false
|
||||
end
|
||||
if (T1 <: PhotonStateful && T2 <: PhotonStateful)
|
||||
return false
|
||||
end
|
||||
|
||||
for (P1, P2) in [(T1, T2), (T2, T1)]
|
||||
if (P1 <: FermionStateful{Incoming} && P2 <: AntiFermionStateful{Outgoing})
|
||||
return false
|
||||
end
|
||||
if (P1 <: FermionStateful{Outgoing} && P2 <: AntiFermionStateful{Incoming})
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function type_index_from_name(::QEDModel, name::String)
|
||||
if startswith(name, "ki")
|
||||
return (PhotonStateful{Incoming,PolX}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "ko")
|
||||
return (PhotonStateful{Outgoing,PolX}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "ei")
|
||||
return (FermionStateful{Incoming,SpinUp}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "eo")
|
||||
return (FermionStateful{Outgoing,SpinUp}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "pi")
|
||||
return (AntiFermionStateful{Incoming,SpinUp}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "po")
|
||||
return (AntiFermionStateful{Outgoing,SpinUp}, parse(Int, name[3:end]))
|
||||
else
|
||||
throw("Invalid name for a particle in the QED model")
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
issame(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
|
||||
|
||||
For two given [`QEDParticle`](@ref) types, return whether they are equivalent for the purpose of a Feynman Diagram. That means e.g. an `Incoming` `AntiFermion` is the same as an `Outgoing` `Fermion`. This is equivalent to `!caninteract(T1, T2)`.
|
||||
|
||||
See also: [`caninteract`](@ref) and [`interaction_result`](@ref)
|
||||
"""
|
||||
function issame(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
|
||||
return !caninteract(T1, T2)
|
||||
end
|
||||
|
||||
"""
|
||||
QED_vertex()
|
||||
|
||||
Return the factor of a vertex in a QED feynman diagram.
|
||||
"""
|
||||
@inline function QED_vertex()::SLorentzVector{DiracMatrix}
|
||||
# Peskin-Schroeder notation
|
||||
return -1im * e * gamma()
|
||||
end
|
||||
|
||||
@inline function QED_inner_edge(p::QEDParticle)
|
||||
return QEDbase.propagator(particle(p), p.momentum)
|
||||
end
|
||||
|
||||
"""
|
||||
QED_conserve_momentum(p1::QEDParticle, p2::QEDParticle)
|
||||
|
||||
Calculate and return a new particle from two given interacting ones at a vertex.
|
||||
"""
|
||||
function QED_conserve_momentum(
|
||||
p1::P1, p2::P2
|
||||
) where {
|
||||
Dir1<:ParticleDirection,
|
||||
Dir2<:ParticleDirection,
|
||||
SpinPol1<:AbstractSpinOrPolarization,
|
||||
SpinPol2<:AbstractSpinOrPolarization,
|
||||
P1<:Union{
|
||||
FermionStateful{Dir1,SpinPol1},
|
||||
AntiFermionStateful{Dir1,SpinPol1},
|
||||
PhotonStateful{Dir1,SpinPol1},
|
||||
},
|
||||
P2<:Union{
|
||||
FermionStateful{Dir2,SpinPol2},
|
||||
AntiFermionStateful{Dir2,SpinPol2},
|
||||
PhotonStateful{Dir2,SpinPol2},
|
||||
},
|
||||
}
|
||||
P3 = interaction_result(P1, P2)
|
||||
p1_mom = p1.momentum
|
||||
if (Dir1 <: Outgoing)
|
||||
p1_mom *= -1
|
||||
end
|
||||
p2_mom = p2.momentum
|
||||
if (Dir2 <: Outgoing)
|
||||
p2_mom *= -1
|
||||
end
|
||||
|
||||
p3_mom = p1_mom + p2_mom
|
||||
if (typeof(direction(P3)) <: Incoming)
|
||||
return P3(-p3_mom)
|
||||
end
|
||||
return P3(p3_mom)
|
||||
end
|
||||
|
||||
"""
|
||||
QEDProcessInput <: AbstractProcessInput
|
||||
|
||||
Input for a QED Process. Contains the [`QEDProcessDescription`](@ref) of the process it is an input for, and the values of the in and out particles.
|
||||
|
||||
See also: [`gen_process_input`](@ref)
|
||||
"""
|
||||
struct QEDProcessInput{N1,N2,N3,N4,N5,N6,S1,S2,S3,S4,P1,P2} <: AbstractProcessInput
|
||||
process::QEDProcessDescription
|
||||
inFerms::SVector{N1,FermionStateful{Incoming,S1}}
|
||||
outFerms::SVector{N2,FermionStateful{Outgoing,S2}}
|
||||
inAntiferms::SVector{N3,AntiFermionStateful{Incoming,S3}}
|
||||
outAntiferms::SVector{N4,AntiFermionStateful{Outgoing,S4}}
|
||||
inPhotons::SVector{N5,PhotonStateful{Incoming,P1}}
|
||||
outPhotons::SVector{N6,PhotonStateful{Outgoing,P2}}
|
||||
end
|
||||
|
||||
"""
|
||||
model(::AbstractProcessDescription)
|
||||
|
||||
Return the model of this process description.
|
||||
"""
|
||||
model(::QEDProcessDescription) = QEDModel()
|
||||
model(::QEDProcessInput) = QEDModel()
|
||||
|
||||
function copy(process::QEDProcessDescription)
|
||||
return QEDProcessDescription(copy(process.inParticles), copy(process.outParticles))
|
||||
end
|
||||
|
||||
function ==(p1::QEDProcessDescription, p2::QEDProcessDescription)
|
||||
return p1.inParticles == p2.inParticles && p1.outParticles == p2.outParticles
|
||||
end
|
||||
|
||||
function in_particles(process::QEDProcessDescription)
|
||||
return process.inParticles
|
||||
end
|
||||
|
||||
function out_particles(process::QEDProcessDescription)
|
||||
return process.outParticles
|
||||
end
|
||||
|
||||
function get_particle(
|
||||
input::QEDProcessInput, t::Type{Particle}, n::Int
|
||||
)::Particle where {Particle}
|
||||
if (t <: FermionStateful{Incoming})
|
||||
return input.inFerms[n]
|
||||
elseif (t <: FermionStateful{Outgoing})
|
||||
return input.outFerms[n]
|
||||
elseif (t <: AntiFermionStateful{Incoming})
|
||||
return input.inAntiferms[n]
|
||||
elseif (t <: AntiFermionStateful{Outgoing})
|
||||
return input.outAntiferms[n]
|
||||
elseif (t <: PhotonStateful{Incoming})
|
||||
return input.inPhotons[n]
|
||||
elseif (t <: PhotonStateful{Outgoing})
|
||||
return input.outPhotons[n]
|
||||
end
|
||||
@assert false "Invalid type given"
|
||||
end
|
@@ -1,9 +1,8 @@
|
||||
|
||||
#=
|
||||
"""
|
||||
show(io::IO, process::GenericQEDProcess)
|
||||
show(io::IO, process::QEDProcessDescription)
|
||||
|
||||
Pretty print an [`GenericQEDProcess`](@ref) (no newlines).
|
||||
Pretty print an [`QEDProcessDescription`](@ref) (no newlines).
|
||||
|
||||
```jldoctest
|
||||
julia> using MetagraphOptimization
|
||||
@@ -15,7 +14,7 @@ julia> print(parse_process("kk->ep", QEDModel()))
|
||||
QED Process: 'kk->ep'
|
||||
```
|
||||
"""
|
||||
function show(io::IO, process::GenericQEDProcess)
|
||||
function show(io::IO, process::QEDProcessDescription)
|
||||
# types() gives the types in order (QED) instead of random like keys() would
|
||||
print(io, "QED Process: \'")
|
||||
for type in types(QEDModel())
|
||||
@@ -35,7 +34,7 @@ end
|
||||
|
||||
|
||||
"""
|
||||
String(process::GenericQEDProcess)
|
||||
String(process::QEDProcessDescription)
|
||||
|
||||
Create a short string suitable as a filename or similar, describing the given process.
|
||||
|
||||
@@ -65,9 +64,7 @@ function String(process::QEDProcessDescription)
|
||||
end
|
||||
return str
|
||||
end
|
||||
=#
|
||||
|
||||
#=
|
||||
"""
|
||||
show(io::IO, processInput::QEDProcessInput)
|
||||
|
||||
@@ -95,9 +92,7 @@ function show(io::IO, processInput::QEDProcessInput)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
=#
|
||||
|
||||
#=
|
||||
"""
|
||||
show(io::IO, particle::T) where {T <: QEDParticle}
|
||||
|
||||
@@ -107,14 +102,13 @@ function show(io::IO, particle::T) where {T <: QEDParticle}
|
||||
print(io, "$(String(typeof(particle))): $(particle.momentum)")
|
||||
return nothing
|
||||
end
|
||||
=#
|
||||
|
||||
"""
|
||||
show(io::IO, particle::FeynmanParticle)
|
||||
|
||||
Pretty print a [`FeynmanParticle`](@ref) (no newlines).
|
||||
"""
|
||||
show(io::IO, p::FeynmanParticle) = print(io, "$(String(p.particle))_$(String(particle_direction(p.particle)))_$(p.id)")
|
||||
show(io::IO, p::FeynmanParticle) = print(io, "$(String(p.particle))_$(String(direction(p.particle)))_$(p.id)")
|
||||
|
||||
"""
|
||||
show(io::IO, particle::FeynmanVertex)
|
@@ -1,10 +1,3 @@
|
||||
"""
|
||||
QEDModel <: AbstractPhysicsModel
|
||||
|
||||
Singleton definition for identification of the QED-Model.
|
||||
"""
|
||||
struct QEDModel <: AbstractPhysicsModel end
|
||||
|
||||
"""
|
||||
ComputeTaskQED_S1 <: AbstractComputeTask
|
||||
|
@@ -1,6 +1,6 @@
|
||||
|
||||
DataTaskNode(t::AbstractDataTask, name = "") =
|
||||
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, name)
|
||||
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
|
||||
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
t, # task
|
||||
Vector{Node}(), # parents
|
||||
@@ -8,6 +8,7 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
UUIDs.uuid1(rng[threadid()]), # id
|
||||
missing, # node reduction
|
||||
missing, # node split
|
||||
Vector{NodeFusion}(), # node fusions
|
||||
missing, # device
|
||||
)
|
||||
|
||||
|
@@ -30,6 +30,7 @@ Any node that transfers data and does no computation.
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
|
||||
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
|
||||
"""
|
||||
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
|
||||
@@ -50,6 +51,9 @@ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
|
||||
# the NodeSplit involving this node, if it exists
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# the node fusion involving this node, if it exists
|
||||
nodeFusion::Union{Operation, Missing}
|
||||
|
||||
# for input nodes we need a name for the node to distinguish between them
|
||||
name::String
|
||||
end
|
||||
@@ -66,6 +70,7 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
|
||||
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
|
||||
"""
|
||||
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
|
||||
@@ -77,6 +82,9 @@ mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
|
||||
nodeReduction::Union{Operation, Missing}
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
|
||||
nodeFusions::Vector{<:Operation}
|
||||
|
||||
# the device this node is assigned to execute on
|
||||
device::Union{AbstractDevice, Missing}
|
||||
end
|
||||
|
@@ -29,6 +29,17 @@ function is_valid_node(graph::DAG, node::Node)
|
||||
@assert is_valid(graph, node.nodeSplit)
|
||||
end=#
|
||||
|
||||
if !(typeof(task(node)) <: FusedComputeTask)
|
||||
# the remaining checks are only necessary for fused compute tasks
|
||||
return true
|
||||
end
|
||||
|
||||
# every child must be in some input of the task
|
||||
for child in node.children
|
||||
str = Symbol(to_var_name(child.id))
|
||||
@assert (str in task(node).t1_inputs) || (str in task(node).t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(task(node).t1_inputs)\nt2_inputs: $(task(node).t2_inputs)"
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
@@ -42,6 +53,9 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
|
||||
function is_valid(graph::DAG, node::ComputeTaskNode)
|
||||
@assert is_valid_node(graph, node)
|
||||
|
||||
#=for nf in node.nodeFusions
|
||||
@assert is_valid(graph, nf)
|
||||
end=#
|
||||
return true
|
||||
end
|
||||
|
||||
@@ -55,5 +69,8 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
|
||||
function is_valid(graph::DAG, node::DataTaskNode)
|
||||
@assert is_valid_node(graph, node)
|
||||
|
||||
#=if !ismissing(node.nodeFusion)
|
||||
@assert is_valid(graph, node.nodeFusion)
|
||||
end=#
|
||||
return true
|
||||
end
|
||||
|
@@ -26,6 +26,21 @@ function apply_operation!(graph::DAG, operation::Operation)
|
||||
return error("Unknown operation type!")
|
||||
end
|
||||
|
||||
"""
|
||||
apply_operation!(graph::DAG, operation::NodeFusion)
|
||||
|
||||
Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node_fusion!`](@ref).
|
||||
|
||||
Return an [`AppliedNodeFusion`](@ref) object generated from the graph's [`Diff`](@ref).
|
||||
"""
|
||||
function apply_operation!(graph::DAG, operation::NodeFusion)
|
||||
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
|
||||
|
||||
graph.properties += GraphProperties(diff)
|
||||
|
||||
return AppliedNodeFusion(operation, diff)
|
||||
end
|
||||
|
||||
"""
|
||||
apply_operation!(graph::DAG, operation::NodeReduction)
|
||||
|
||||
@@ -65,10 +80,20 @@ function revert_operation!(graph::DAG, operation::AppliedOperation)
|
||||
return error("Unknown operation type!")
|
||||
end
|
||||
|
||||
"""
|
||||
revert_operation!(graph::DAG, operation::AppliedNodeFusion)
|
||||
|
||||
Revert the applied node fusion on the graph. Return the original [`NodeFusion`](@ref) operation.
|
||||
"""
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
end
|
||||
|
||||
"""
|
||||
revert_operation!(graph::DAG, operation::AppliedNodeReduction)
|
||||
|
||||
Revert the applied node reduction on the graph. Return the original [`NodeReduction`](@ref) operation.
|
||||
Revert the applied node fusion on the graph. Return the original [`NodeReduction`](@ref) operation.
|
||||
"""
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
|
||||
revert_diff!(graph, operation.diff)
|
||||
@@ -78,7 +103,7 @@ end
|
||||
"""
|
||||
revert_operation!(graph::DAG, operation::AppliedNodeSplit)
|
||||
|
||||
Revert the applied node split on the graph. Return the original [`NodeSplit`](@ref) operation.
|
||||
Revert the applied node fusion on the graph. Return the original [`NodeSplit`](@ref) operation.
|
||||
"""
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
|
||||
revert_diff!(graph, operation.diff)
|
||||
@@ -107,11 +132,88 @@ function revert_diff!(graph::DAG, diff::Diff)
|
||||
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
|
||||
end
|
||||
|
||||
for (node, t) in diff.updatedChildren
|
||||
# node must be fused compute task at this point
|
||||
@assert typeof(task(node)) <: FusedComputeTask
|
||||
|
||||
node.task = t
|
||||
end
|
||||
|
||||
graph.properties -= GraphProperties(diff)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
|
||||
Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph.
|
||||
|
||||
For details see [`NodeFusion`](@ref).
|
||||
"""
|
||||
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
@assert is_valid_node_fusion_input(graph, n1, n2, n3)
|
||||
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
# save children and parents
|
||||
n1Children = copy(children(n1))
|
||||
n3Parents = copy(parents(n3))
|
||||
|
||||
n1Task = copy(task(n1))
|
||||
n3Task = copy(task(n3))
|
||||
|
||||
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
|
||||
n1Inputs = Vector{Symbol}()
|
||||
for child in n1Children
|
||||
push!(n1Inputs, Symbol(to_var_name(child.id)))
|
||||
end
|
||||
|
||||
# remove the edges and nodes that will be replaced by the fused node
|
||||
remove_edge!(graph, n1, n2)
|
||||
remove_edge!(graph, n2, n3)
|
||||
remove_node!(graph, n1)
|
||||
remove_node!(graph, n2)
|
||||
|
||||
# get n3's children now so it automatically excludes n2
|
||||
n3Children = copy(children(n3))
|
||||
|
||||
n3Inputs = Vector{Symbol}()
|
||||
for child in n3Children
|
||||
push!(n3Inputs, Symbol(to_var_name(child.id)))
|
||||
end
|
||||
|
||||
remove_node!(graph, n3)
|
||||
|
||||
# create new node with the fused compute task
|
||||
newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
|
||||
insert_node!(graph, newNode)
|
||||
|
||||
for child in n1Children
|
||||
remove_edge!(graph, child, n1)
|
||||
insert_edge!(graph, child, newNode)
|
||||
end
|
||||
|
||||
for child in n3Children
|
||||
remove_edge!(graph, child, n3)
|
||||
if !(child in n1Children)
|
||||
insert_edge!(graph, child, newNode)
|
||||
end
|
||||
end
|
||||
|
||||
for parent in n3Parents
|
||||
remove_edge!(graph, n3, parent)
|
||||
insert_edge!(graph, newNode, parent)
|
||||
|
||||
# important! update the parent node's child names in case they are fused compute tasks
|
||||
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
|
||||
update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(newNode.id)))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
end
|
||||
|
||||
"""
|
||||
node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
@@ -163,6 +265,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
# this has to be done for all parents, even the ones of n1 because they can be duplicate
|
||||
prevChild = newParentsChildNames[parent]
|
||||
update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
@@ -204,6 +307,8 @@ function node_split!(
|
||||
for child in n1Children
|
||||
insert_edge!(graph, child, nCopy)
|
||||
end
|
||||
|
||||
update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(nCopy.id)))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
|
@@ -1,5 +1,60 @@
|
||||
# These are functions for "cleaning" nodes, i.e. regenerating the possible operations for a node
|
||||
|
||||
"""
|
||||
find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
|
||||
Find node fusions involving the given data node. The function pushes the found [`NodeFusion`](@ref) (if any) everywhere it needs to be and returns nothing.
|
||||
|
||||
Does nothing if the node already has a node fusion set. Since it's a data node, only one node fusion can be possible with it.
|
||||
"""
|
||||
function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
# if there is already a fusion here, skip to avoid duplicates
|
||||
if !ismissing(node.nodeFusion)
|
||||
return nothing
|
||||
end
|
||||
|
||||
if length(parents(node)) != 1 || length(children(node)) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
child_node = first(children(node))
|
||||
parent_node = first(parents(node))
|
||||
|
||||
if !(child_node in graph) || !(parent_node in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end
|
||||
|
||||
if length(parents(child_node)) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
nf = NodeFusion((child_node, node, parent_node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(child_node.nodeFusions, nf)
|
||||
node.nodeFusion = nf
|
||||
push!(parent_node.nodeFusions, nf)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
|
||||
Find node fusions involving the given compute node. The function pushes the found [`NodeFusion`](@ref)s (if any) everywhere they need to be and returns nothing.
|
||||
"""
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
# just find fusions in neighbouring DataTaskNodes
|
||||
for child in children(node)
|
||||
find_fusions!(graph, child)
|
||||
end
|
||||
|
||||
for parent in parents(node)
|
||||
find_fusions!(graph, parent)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
find_reductions!(graph::DAG, node::Node)
|
||||
|
||||
@@ -66,7 +121,7 @@ end
|
||||
"""
|
||||
clean_node!(graph::DAG, node::Node)
|
||||
|
||||
Sort this node's parent and child sets, then find reductions and splits involving it. Needs to be called after the node was changed in some way.
|
||||
Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way.
|
||||
"""
|
||||
function clean_node!(
|
||||
graph::DAG,
|
||||
@@ -74,6 +129,7 @@ function clean_node!(
|
||||
) where {TaskType <: AbstractTask}
|
||||
sort_node!(node)
|
||||
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
find_splits!(graph, node)
|
||||
|
||||
|
@@ -2,6 +2,26 @@
|
||||
|
||||
using Base.Threads
|
||||
|
||||
"""
|
||||
insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
|
||||
|
||||
Insert the given node fusion into its input nodes' operation caches. For the compute nodes, locking via the given `locks` is employed to have safe multi-threading. For a large set of nodes, contention on the locks should be very small.
|
||||
"""
|
||||
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
|
||||
n1 = nf.input[1]
|
||||
n2 = nf.input[2]
|
||||
n3 = nf.input[3]
|
||||
|
||||
lock(locks[n1]) do
|
||||
return push!(nf.input[1].nodeFusions, nf)
|
||||
end
|
||||
n2.nodeFusion = nf
|
||||
lock(locks[n3]) do
|
||||
return push!(nf.input[3].nodeFusions, nf)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
insert_operation!(nf::NodeReduction)
|
||||
|
||||
@@ -52,6 +72,41 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
|
||||
|
||||
Insert the node fusions into the graph and the nodes' caches. Employs multithreading for speedup.
|
||||
"""
|
||||
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
|
||||
total_len = 0
|
||||
for vec in nodeFusions
|
||||
total_len += length(vec)
|
||||
end
|
||||
sizehint!(operations.nodeFusions, total_len)
|
||||
|
||||
t = @task for vec in nodeFusions
|
||||
union!(operations.nodeFusions, Set(vec))
|
||||
end
|
||||
schedule(t)
|
||||
|
||||
locks = Dict{ComputeTaskNode, SpinLock}()
|
||||
for n in graph.nodes
|
||||
if (typeof(n) <: ComputeTaskNode)
|
||||
locks[n] = SpinLock()
|
||||
end
|
||||
end
|
||||
|
||||
@threads for vec in nodeFusions
|
||||
for op in vec
|
||||
insert_operation!(op, locks)
|
||||
end
|
||||
end
|
||||
|
||||
wait(t)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplits}})
|
||||
|
||||
@@ -88,6 +143,7 @@ Generate all possible operations on the graph. Used initially when the graph is
|
||||
Safely inserts all the found operations into the graph and its nodes.
|
||||
"""
|
||||
function generate_operations(graph::DAG)
|
||||
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
|
||||
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
|
||||
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
|
||||
|
||||
@@ -143,6 +199,31 @@ function generate_operations(graph::DAG)
|
||||
# remove duplicates
|
||||
nr_task = @spawn nr_insertion!(graph.possibleOperations, generatedReductions)
|
||||
|
||||
# --- find possible node fusions ---
|
||||
@threads for node in nodeArray
|
||||
if (typeof(node) <: DataTaskNode)
|
||||
if length(parents(node)) != 1
|
||||
# data node can only have a single parent
|
||||
continue
|
||||
end
|
||||
parent_node = first(parents(node))
|
||||
|
||||
if length(children(node)) != 1
|
||||
# this node is an entry node or has multiple children which should not be possible
|
||||
continue
|
||||
end
|
||||
child_node = first(children(node))
|
||||
if (length(parents(child_node)) != 1)
|
||||
continue
|
||||
end
|
||||
|
||||
push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
|
||||
end
|
||||
end
|
||||
|
||||
# launch thread for node fusion insertion
|
||||
nf_task = @spawn nf_insertion!(graph, graph.possibleOperations, generatedFusions)
|
||||
|
||||
# find possible node splits
|
||||
@threads for node in nodeArray
|
||||
if (can_split(node))
|
||||
@@ -156,6 +237,7 @@ function generate_operations(graph::DAG)
|
||||
empty!(graph.dirtyNodes)
|
||||
|
||||
wait(nr_task)
|
||||
wait(nf_task)
|
||||
wait(ns_task)
|
||||
|
||||
return nothing
|
||||
|
@@ -2,7 +2,8 @@ import Base.iterate
|
||||
|
||||
const _POSSIBLE_OPERATIONS_FIELDS = fieldnames(PossibleOperations)
|
||||
|
||||
_POIteratorStateType = NamedTuple{(:result, :state), Tuple{Union{NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
|
||||
_POIteratorStateType =
|
||||
NamedTuple{(:result, :state), Tuple{Union{NodeFusion, NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
|
||||
|
||||
@inline function iterate(possibleOperations::PossibleOperations)::Union{Nothing, _POIteratorStateType}
|
||||
for fieldname in _POSSIBLE_OPERATIONS_FIELDS
|
||||
|
@@ -4,6 +4,11 @@
|
||||
Print a string representation of the set of possible operations to io.
|
||||
"""
|
||||
function show(io::IO, ops::PossibleOperations)
|
||||
print(io, length(ops.nodeFusions))
|
||||
println(io, " Node Fusions: ")
|
||||
for nf in ops.nodeFusions
|
||||
println(io, " - ", nf)
|
||||
end
|
||||
print(io, length(ops.nodeReductions))
|
||||
println(io, " Node Reductions: ")
|
||||
for nr in ops.nodeReductions
|
||||
@@ -37,3 +42,17 @@ function show(io::IO, op::NodeSplit)
|
||||
print(io, "NS: ")
|
||||
return print(io, task(op.input))
|
||||
end
|
||||
|
||||
"""
|
||||
show(io::IO, op::NodeFusion)
|
||||
|
||||
Print a string representation of the node fusion to io.
|
||||
"""
|
||||
function show(io::IO, op::NodeFusion)
|
||||
print(io, "NF: ")
|
||||
print(io, task(op.input[1]))
|
||||
print(io, "->")
|
||||
print(io, task(op.input[2]))
|
||||
print(io, "->")
|
||||
return print(io, task(op.input[3]))
|
||||
end
|
||||
|
@@ -20,6 +20,45 @@ See also: [`revert_operation!`](@ref).
|
||||
"""
|
||||
abstract type AppliedOperation end
|
||||
|
||||
"""
|
||||
NodeFusion <: Operation
|
||||
|
||||
The NodeFusion operation. Represents the fusing of a chain of compute node -> data node -> compute node.
|
||||
|
||||
After the node fusion is applied, the graph has 2 fewer nodes and edges, and a new [`FusedComputeTask`](@ref) with the two input compute nodes as parts.
|
||||
|
||||
# Requirements for successful application
|
||||
|
||||
A chain of (n1, n2, n3) can be fused if:
|
||||
- All nodes are in the graph.
|
||||
- (n1, n2) is an edge in the graph.
|
||||
- (n2, n3) is an edge in the graph.
|
||||
- n2 has exactly one parent (n3) and exactly one child (n1).
|
||||
- n1 has exactly one parent (n2).
|
||||
|
||||
[`is_valid_node_fusion_input`](@ref) can be used to `@assert` these requirements.
|
||||
|
||||
See also: [`can_fuse`](@ref)
|
||||
"""
|
||||
struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
|
||||
Operation
|
||||
input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
|
||||
end
|
||||
|
||||
"""
|
||||
AppliedNodeFusion <: AppliedOperation
|
||||
|
||||
The applied version of the [`NodeFusion`](@ref).
|
||||
"""
|
||||
struct AppliedNodeFusion{
|
||||
TaskType1 <: AbstractComputeTask,
|
||||
TaskType2 <: AbstractDataTask,
|
||||
TaskType3 <: AbstractComputeTask,
|
||||
} <: AppliedOperation
|
||||
operation::NodeFusion{TaskType1, TaskType2, TaskType3}
|
||||
diff::Diff
|
||||
end
|
||||
|
||||
"""
|
||||
NodeReduction <: Operation
|
||||
|
||||
|
@@ -4,7 +4,7 @@
|
||||
Return whether `operations` is empty, i.e. all of its fields are empty.
|
||||
"""
|
||||
function isempty(operations::PossibleOperations)
|
||||
return isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
|
||||
return isempty(operations.nodeFusions) && isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -13,7 +13,21 @@ end
|
||||
Return a named tuple with the number of each of the operation types as a named tuple. The fields are named the same as the [`PossibleOperations`](@ref)'.
|
||||
"""
|
||||
function length(operations::PossibleOperations)
|
||||
return (nodeReductions = length(operations.nodeReductions), nodeSplits = length(operations.nodeSplits))
|
||||
return (
|
||||
nodeFusions = length(operations.nodeFusions),
|
||||
nodeReductions = length(operations.nodeReductions),
|
||||
nodeSplits = length(operations.nodeSplits),
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
delete!(operations::PossibleOperations, op::NodeFusion)
|
||||
|
||||
Delete the given node fusion from the possible operations.
|
||||
"""
|
||||
function delete!(operations::PossibleOperations, op::NodeFusion)
|
||||
delete!(operations.nodeFusions, op)
|
||||
return operations
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -36,6 +50,24 @@ function delete!(operations::PossibleOperations, op::NodeSplit)
|
||||
return operations
|
||||
end
|
||||
|
||||
"""
|
||||
can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
|
||||
Return whether the given nodes can be fused. See [`NodeFusion`](@ref) for the requirements.
|
||||
"""
|
||||
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
if !is_child(n1, n2) || !is_child(n2, n3)
|
||||
# the checks are redundant but maybe a good sanity check
|
||||
return false
|
||||
end
|
||||
|
||||
if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
can_reduce(n1::Node, n2::Node)
|
||||
|
||||
@@ -104,6 +136,23 @@ function ==(op1::Operation, op2::Operation)
|
||||
return false
|
||||
end
|
||||
|
||||
"""
|
||||
==(op1::NodeFusion, op2::NodeFusion)
|
||||
|
||||
Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
|
||||
"""
|
||||
function ==(
|
||||
op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
|
||||
op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
|
||||
) where {
|
||||
ComputeTaskType1 <: AbstractComputeTask,
|
||||
DataTaskType <: AbstractDataTask,
|
||||
ComputeTaskType2 <: AbstractComputeTask,
|
||||
}
|
||||
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
|
||||
return op1.input[2] == op2.input[2]
|
||||
end
|
||||
|
||||
"""
|
||||
==(op1::NodeReduction, op2::NodeReduction)
|
||||
|
||||
|
@@ -2,6 +2,43 @@
|
||||
# should be called with @assert
|
||||
# the functions throw their own errors though, to still have helpful error messages
|
||||
|
||||
"""
|
||||
is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
|
||||
Assert for a gven node fusion input whether the nodes can be fused. For the requirements of a node fusion see [`NodeFusion`](@ref).
|
||||
|
||||
Intended for use with `@assert` or `@test`.
|
||||
"""
|
||||
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
|
||||
throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
|
||||
end
|
||||
|
||||
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
if length(n2.parents) > 1
|
||||
throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
|
||||
end
|
||||
if length(n2.children) > 1
|
||||
throw(AssertionError("[Node Fusion] The given data node has more than one child"))
|
||||
end
|
||||
if length(n1.parents) > 1
|
||||
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
|
||||
end
|
||||
|
||||
@assert is_valid(graph, n1)
|
||||
@assert is_valid(graph, n2)
|
||||
@assert is_valid(graph, n3)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
@@ -94,3 +131,16 @@ function is_valid(graph::DAG, ns::NodeSplit)
|
||||
#@assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!"
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
is_valid(graph::DAG, nr::NodeFusion)
|
||||
|
||||
Assert for a given [`NodeFusion`](@ref) whether it is a valid operation in the graph.
|
||||
|
||||
Intended for use with `@assert` or `@test`.
|
||||
"""
|
||||
function is_valid(graph::DAG, nf::NodeFusion)
|
||||
@assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
|
||||
#@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
|
||||
return true
|
||||
end
|
||||
|
36
src/optimization/fuse.jl
Normal file
36
src/optimization/fuse.jl
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
FusionOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeFusion`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeFusion`](@ref)s in the graph.
|
||||
|
||||
See also: [`SplitOptimizer`](@ref), [`ReductionOptimizer`](@ref)
|
||||
"""
|
||||
struct FusionOptimizer <: AbstractOptimizer end
|
||||
|
||||
function optimize_step!(optimizer::FusionOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if fixpoint_reached(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, first(operations.nodeFusions))
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::FusionOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
return isempty(operations.nodeFusions)
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::FusionOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function String(::FusionOptimizer)
|
||||
return "fusion_optimizer"
|
||||
end
|
@@ -26,12 +26,16 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG)
|
||||
if rand(r, Bool)
|
||||
# push
|
||||
|
||||
# choose one of split/reduce
|
||||
option = rand(r, 1:2)
|
||||
if option == 1 && !isempty(operations.nodeReductions)
|
||||
# choose one of fuse/split/reduce
|
||||
# TODO refactor fusions so they actually work
|
||||
option = rand(r, 2:3)
|
||||
if option == 1 && !isempty(operations.nodeFusions)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeFusions)))
|
||||
return true
|
||||
elseif option == 2 && !isempty(operations.nodeReductions)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeReductions)))
|
||||
return true
|
||||
elseif option == 2 && !isempty(operations.nodeSplits)
|
||||
elseif option == 3 && !isempty(operations.nodeSplits)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeSplits)))
|
||||
return true
|
||||
end
|
||||
|
@@ -3,7 +3,7 @@
|
||||
|
||||
An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
|
||||
|
||||
See also: [`SplitOptimizer`](@ref)
|
||||
See also: [`FusionOptimizer`](@ref), [`SplitOptimizer`](@ref)
|
||||
"""
|
||||
struct ReductionOptimizer <: AbstractOptimizer end
|
||||
|
||||
|
@@ -3,7 +3,7 @@
|
||||
|
||||
An optimizer that simply applies an available [`NodeSplit`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeSplit`](@ref)s in the graph.
|
||||
|
||||
See also: [`ReductionOptimizer`](@ref)
|
||||
See also: [`FusionOptimizer`](@ref), [`ReductionOptimizer`](@ref)
|
||||
"""
|
||||
struct SplitOptimizer <: AbstractOptimizer end
|
||||
|
||||
|
@@ -4,7 +4,7 @@
|
||||
|
||||
A greedy implementation of a scheduler, creating a topological ordering of nodes and naively balancing them onto the different devices.
|
||||
"""
|
||||
struct GreedyScheduler <: AbstractScheduler end
|
||||
struct GreedyScheduler end
|
||||
|
||||
function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
|
||||
nodeQueue = PriorityQueue{Node, Int}()
|
||||
|
@@ -1,10 +1,10 @@
|
||||
|
||||
"""
|
||||
AbstractScheduler
|
||||
Scheduler
|
||||
|
||||
Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
|
||||
"""
|
||||
abstract type AbstractScheduler end
|
||||
abstract type Scheduler end
|
||||
|
||||
"""
|
||||
schedule_dag(::Scheduler, ::DAG, ::Machine)
|
||||
|
@@ -5,11 +5,10 @@ using StaticArrays
|
||||
|
||||
Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on.
|
||||
"""
|
||||
struct FunctionCall{VectorType <: AbstractVector, N}
|
||||
struct FunctionCall{VectorType <: AbstractVector, M}
|
||||
func::Function
|
||||
# TODO: this should be a tuple
|
||||
value_arguments::SVector{N, Any} # value arguments for the function call, will be prepended to the other arguments
|
||||
arguments::VectorType # symbols of the inputs to the function call
|
||||
arguments::VectorType
|
||||
additional_arguments::SVector{M, Any} # additional arguments (as values) for the function call, will be prepended to the other arguments
|
||||
return_symbol::Symbol
|
||||
device::AbstractDevice
|
||||
end
|
||||
|
@@ -1,20 +1,38 @@
|
||||
using StaticArrays
|
||||
|
||||
"""
|
||||
compute(t::FusedComputeTask, data)
|
||||
|
||||
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
|
||||
"""
|
||||
function compute(t::FusedComputeTask, data...)
|
||||
inter = compute(t.first_task)
|
||||
return compute(t.second_task, inter, data2...)
|
||||
end
|
||||
|
||||
"""
|
||||
get_function_call(n::Node)
|
||||
get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
|
||||
|
||||
For a node or a task together with necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task.
|
||||
|
||||
For ordinary compute or data tasks the vector will contain exactly one element.
|
||||
For ordinary compute or data tasks the vector will contain exactly one element, for a [`FusedComputeTask`](@ref) there can be any number of tasks greater 1.
|
||||
"""
|
||||
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
|
||||
# sort out the symbols to the correct tasks
|
||||
return [
|
||||
get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
|
||||
get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
|
||||
]
|
||||
end
|
||||
|
||||
function get_function_call(
|
||||
t::CompTask,
|
||||
device::AbstractDevice,
|
||||
inSymbols::AbstractVector,
|
||||
outSymbol::Symbol,
|
||||
) where {CompTask <: AbstractComputeTask}
|
||||
return [FunctionCall(compute, SVector{1, Any}(t), inSymbols, outSymbol, device)]
|
||||
return [FunctionCall(compute, inSymbols, SVector{1, Any}(t), outSymbol, device)]
|
||||
end
|
||||
|
||||
function get_function_call(node::ComputeTaskNode)
|
||||
@@ -46,8 +64,8 @@ function get_function_call(node::DataTaskNode)
|
||||
return [
|
||||
FunctionCall(
|
||||
unpack_identity,
|
||||
SVector{0, Any}(),
|
||||
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
|
||||
SVector{0, Any}(),
|
||||
Symbol(to_var_name(node.id)),
|
||||
first(children(node)).device,
|
||||
),
|
||||
@@ -59,8 +77,8 @@ function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
|
||||
|
||||
return FunctionCall(
|
||||
unpack_identity,
|
||||
SVector{0, Any}(),
|
||||
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
|
||||
SVector{0, Any}(),
|
||||
Symbol(to_var_name(node.id)),
|
||||
device,
|
||||
)
|
||||
|
@@ -11,3 +11,22 @@ copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks
|
||||
Return a copy of the given compute task.
|
||||
"""
|
||||
copy(t::AbstractComputeTask) = typeof(t)()
|
||||
|
||||
"""
|
||||
copy(t::FusedComputeTask)
|
||||
|
||||
Return a copy of th egiven [`FusedComputeTask`](@ref).
|
||||
"""
|
||||
function copy(t::FusedComputeTask)
|
||||
return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
|
||||
end
|
||||
|
||||
function FusedComputeTask(
|
||||
T1::Type{<:AbstractComputeTask},
|
||||
T2::Type{<:AbstractComputeTask},
|
||||
t1_inputs::Vector{String},
|
||||
t1_output::String,
|
||||
t2_inputs::Vector{String},
|
||||
)
|
||||
return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
|
||||
end
|
||||
|
@@ -3,21 +3,28 @@
|
||||
|
||||
Fallback implementation of the compute function of a compute task, throwing an error.
|
||||
"""
|
||||
function compute end
|
||||
function compute(t::AbstractTask, data...)
|
||||
return error("Need to implement compute()")
|
||||
end
|
||||
|
||||
"""
|
||||
compute_effort(t::AbstractTask)
|
||||
|
||||
Fallback implementation of the compute effort of a task, throwing an error.
|
||||
"""
|
||||
function compute_effort end
|
||||
function compute_effort(t::AbstractTask)::Float64
|
||||
# default implementation using compute
|
||||
return error("Need to implement compute_effort()")
|
||||
end
|
||||
|
||||
"""
|
||||
data(t::AbstractTask)
|
||||
|
||||
Fallback implementation of the data of a task, throwing an error.
|
||||
"""
|
||||
function data end
|
||||
function data(t::AbstractTask)::Float64
|
||||
return error("Need to implement data()")
|
||||
end
|
||||
|
||||
"""
|
||||
compute_effort(t::AbstractDataTask)
|
||||
@@ -47,9 +54,34 @@ Return the number of children of a data task (always 1).
|
||||
"""
|
||||
children(::DataTask) = 1
|
||||
|
||||
"""
|
||||
children(t::FusedComputeTask)
|
||||
|
||||
Return the number of children of a FusedComputeTask.
|
||||
"""
|
||||
function children(t::FusedComputeTask)
|
||||
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
|
||||
end
|
||||
|
||||
"""
|
||||
data(t::AbstractComputeTask)
|
||||
|
||||
Return the data of a compute task, always zero, regardless of the specific task.
|
||||
"""
|
||||
data(t::AbstractComputeTask)::Float64 = 0.0
|
||||
|
||||
"""
|
||||
compute_effort(t::FusedComputeTask)
|
||||
|
||||
Return the compute effort of a fused compute task.
|
||||
"""
|
||||
function compute_effort(t::FusedComputeTask)::Float64
|
||||
return compute_effort(t.first_task) + compute_effort(t.second_task)
|
||||
end
|
||||
|
||||
"""
|
||||
get_types(::FusedComputeTask{T1, T2})
|
||||
|
||||
Return a tuple of a the fused compute task's components' types.
|
||||
"""
|
||||
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))
|
||||
|
@@ -27,3 +27,21 @@ Task representing a specific data transfer.
|
||||
struct DataTask <: AbstractDataTask
|
||||
data::Float64
|
||||
end
|
||||
|
||||
"""
|
||||
FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
|
||||
|
||||
A fused compute task made up of the computation of first `T1` and then `T2`.
|
||||
|
||||
Also see: [`get_types`](@ref).
|
||||
"""
|
||||
struct FusedComputeTask <: AbstractComputeTask
|
||||
first_task::AbstractComputeTask
|
||||
second_task::AbstractComputeTask
|
||||
# the names of the inputs for T1
|
||||
t1_inputs::Vector{Symbol}
|
||||
# output name of T1
|
||||
t1_output::Symbol
|
||||
# t2_inputs doesn't include the output of t1, that's implicit
|
||||
t2_inputs::Vector{Symbol}
|
||||
end
|
||||
|
@@ -1,6 +1,3 @@
|
||||
using Roots
|
||||
using ForwardDiff
|
||||
|
||||
"""
|
||||
noop()
|
||||
|
||||
@@ -74,6 +71,9 @@ function mem(graph::DAG)
|
||||
size += sizeof(graph.operationsToApply)
|
||||
|
||||
size += sizeof(graph.possibleOperations)
|
||||
for op in graph.possibleOperations.nodeFusions
|
||||
size += mem(op)
|
||||
end
|
||||
for op in graph.possibleOperations.nodeReductions
|
||||
size += mem(op)
|
||||
end
|
||||
@@ -249,6 +249,7 @@ end
|
||||
|
||||
first_derivative(func) = x -> ForwardDiff.derivative(func, float(x))
|
||||
|
||||
|
||||
function generate_physical_massive_moms(rng, ss, masses; x0 = 0.1)
|
||||
n = length(masses)
|
||||
massless_moms = generate_physical_massless_moms(rng, ss, n)
|
||||
|
10
test/Project.toml
Normal file
10
test/Project.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[deps]
|
||||
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
|
||||
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
|
||||
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
|
||||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
@@ -3,15 +3,48 @@ using Random
|
||||
|
||||
RNG = Random.MersenneTwister(321)
|
||||
|
||||
function test_known_graph(name::String, n)
|
||||
function test_known_graph(name::String, n, fusion_test = true)
|
||||
@testset "Test $name Graph ($n)" begin
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "$name.txt"), ABCModel())
|
||||
props = get_properties(graph)
|
||||
|
||||
if (fusion_test)
|
||||
test_node_fusion(graph)
|
||||
end
|
||||
test_random_walk(RNG, graph, n)
|
||||
end
|
||||
end
|
||||
|
||||
function test_node_fusion(g::DAG)
|
||||
@testset "Test Node Fusion" begin
|
||||
props = get_properties(g)
|
||||
|
||||
options = get_operations(g)
|
||||
|
||||
nodes_number = length(g.nodes)
|
||||
data = props.data
|
||||
compute_effort = props.computeEffort
|
||||
|
||||
while !isempty(options.nodeFusions)
|
||||
fusion = first(options.nodeFusions)
|
||||
|
||||
@test typeof(fusion) <: NodeFusion
|
||||
|
||||
push_operation!(g, fusion)
|
||||
|
||||
props = get_properties(g)
|
||||
@test props.data < data
|
||||
@test props.computeEffort == compute_effort
|
||||
|
||||
nodes_number = length(g.nodes)
|
||||
data = props.data
|
||||
compute_effort = props.computeEffort
|
||||
|
||||
options = get_operations(g)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function test_random_walk(RNG, g::DAG, n::Int64)
|
||||
@testset "Test Random Walk ($n)" begin
|
||||
# the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
|
||||
@@ -27,11 +60,13 @@ function test_random_walk(RNG, g::DAG, n::Int64)
|
||||
# push
|
||||
opt = get_operations(g)
|
||||
|
||||
# choose one of split/reduce
|
||||
option = rand(RNG, 1:2)
|
||||
if option == 1 && !isempty(opt.nodeReductions)
|
||||
# choose one of fuse/split/reduce
|
||||
option = rand(RNG, 1:3)
|
||||
if option == 1 && !isempty(opt.nodeFusions)
|
||||
push_operation!(g, rand(RNG, collect(opt.nodeFusions)))
|
||||
elseif option == 2 && !isempty(opt.nodeReductions)
|
||||
push_operation!(g, rand(RNG, collect(opt.nodeReductions)))
|
||||
elseif option == 2 && !isempty(opt.nodeSplits)
|
||||
elseif option == 3 && !isempty(opt.nodeSplits)
|
||||
push_operation!(g, rand(RNG, collect(opt.nodeSplits)))
|
||||
else
|
||||
i = i - 1
|
||||
@@ -56,4 +91,4 @@ end
|
||||
|
||||
test_known_graph("AB->AB", 10000)
|
||||
test_known_graph("AB->ABBB", 10000)
|
||||
test_known_graph("AB->ABBBBB", 1000)
|
||||
test_known_graph("AB->ABBBBB", 1000, false)
|
||||
|
@@ -61,14 +61,17 @@ insert_edge!(graph, CD, C1C, track = false)
|
||||
|
||||
opt = get_operations(graph)
|
||||
|
||||
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
|
||||
|
||||
#println("Initial State:\n", opt)
|
||||
|
||||
nr = first(opt.nodeReductions)
|
||||
@test Set(nr.input) == Set([B1C_1, B1C_2])
|
||||
push_operation!(graph, nr)
|
||||
opt = get_operations(graph)
|
||||
|
||||
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
|
||||
#println("After 1 Node Reduction:\n", opt)
|
||||
|
||||
nr = first(opt.nodeReductions)
|
||||
@test Set(nr.input) == Set([B1D_1, B1D_2])
|
||||
@@ -77,16 +80,19 @@ opt = get_operations(graph)
|
||||
|
||||
@test is_valid(graph)
|
||||
|
||||
@test length(opt) == (nodeReductions = 0, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 4, nodeReductions = 0, nodeSplits = 1)
|
||||
#println("After 2 Node Reductions:\n", opt)
|
||||
|
||||
pop_operation!(graph)
|
||||
|
||||
opt = get_operations(graph)
|
||||
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
|
||||
#println("After reverting the second Node Reduction:\n", opt)
|
||||
|
||||
reset_graph!(graph)
|
||||
|
||||
opt = get_operations(graph)
|
||||
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
|
||||
#println("After reverting to the initial state:\n", opt)
|
||||
|
||||
@test is_valid(graph)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
using SafeTestsets
|
||||
#=
|
||||
|
||||
@safetestset "Utility Unit Tests " begin
|
||||
include("unit_tests_utility.jl")
|
||||
end
|
||||
@@ -30,7 +30,6 @@ end
|
||||
@safetestset "Graph Unit Tests " begin
|
||||
include("unit_tests_graph.jl")
|
||||
end
|
||||
=#
|
||||
@safetestset "Execution Unit Tests " begin
|
||||
include("unit_tests_execution.jl")
|
||||
end
|
||||
|
@@ -1,5 +1,5 @@
|
||||
using MetagraphOptimization
|
||||
using QEDcore
|
||||
using QEDbase
|
||||
|
||||
import MetagraphOptimization.interaction_result
|
||||
|
||||
|
@@ -1,5 +1,16 @@
|
||||
using MetagraphOptimization
|
||||
|
||||
function test_op_specific(estimator, graph, nf::NodeFusion)
|
||||
estimate = operation_effect(estimator, graph, nf)
|
||||
data_reduce = data(nf.input[2].task)
|
||||
|
||||
@test isapprox(estimate.data, -data_reduce)
|
||||
@test isapprox(estimate.computeEffort, 0; atol = eps(Float64))
|
||||
@test isapprox(estimate.computeIntensity, 0; atol = eps(Float64))
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function test_op_specific(estimator, graph, nr::NodeReduction)
|
||||
estimate = operation_effect(estimator, graph, nr)
|
||||
|
||||
@@ -63,9 +74,13 @@ end
|
||||
|
||||
@testset "Operation Cost" begin
|
||||
ops = get_operations(graph)
|
||||
nfs = copy(ops.nodeFusions)
|
||||
nrs = copy(ops.nodeReductions)
|
||||
nss = copy(ops.nodeSplits)
|
||||
|
||||
for nf in nfs
|
||||
test_op(estimator, graph, nf)
|
||||
end
|
||||
for nr in nrs
|
||||
test_op(estimator, graph, nr)
|
||||
end
|
||||
|
@@ -1,5 +1,5 @@
|
||||
using MetagraphOptimization
|
||||
using QEDcore
|
||||
using QEDbase
|
||||
using AccurateArithmetic
|
||||
using Random
|
||||
using UUIDs
|
||||
@@ -63,14 +63,8 @@ machine = Machine(
|
||||
)
|
||||
|
||||
process_2_2 = ABCProcessDescription(
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Incoming, ParticleA, SFourMomentum} => 1,
|
||||
ParticleStateful{Incoming, ParticleB, SFourMomentum} => 1,
|
||||
),
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Outgoing, ParticleA, SFourMomentum} => 1,
|
||||
ParticleStateful{Outgoing, ParticleB, SFourMomentum} => 1,
|
||||
),
|
||||
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
|
||||
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
|
||||
)
|
||||
|
||||
particles_2_2 = ABCProcessInput(
|
||||
@@ -112,14 +106,8 @@ end
|
||||
end
|
||||
|
||||
process_2_4 = ABCProcessDescription(
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Incoming, ParticleA, SFourMomentum} => 1,
|
||||
ParticleStateful{Incoming, ParticleB, SFourMomentum} => 1,
|
||||
),
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Outgoing, ParticleA, SFourMomentum} => 1,
|
||||
ParticleStateful{Outgoing, ParticleB, SFourMomentum} => 3,
|
||||
),
|
||||
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
|
||||
Dict{Type, Int64}(ParticleA => 1, ParticleB => 3),
|
||||
)
|
||||
particles_2_4 = gen_process_input(process_2_4)
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
|
||||
@@ -148,6 +136,105 @@ TODO: fix precision(?) issues
|
||||
end
|
||||
=#
|
||||
|
||||
@testset "AB->AB large sum fusion" begin
|
||||
for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@testset "AB->AB large sum fusion" begin
|
||||
for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "AB->AB fusion edge case" begin
|
||||
for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push two fusions with ComputeTaskABC_V
|
||||
for _ in 1:2
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, ComputeTaskABC_V)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# push fusions until the end
|
||||
cont = true
|
||||
while cont
|
||||
cont = false
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
cont = true
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "$(process) after random walk" for process in ["ke->ke", "ke->kke", "ke->kkke"]
|
||||
process = parse_process("ke->kkke", QEDModel())
|
||||
inputs = [gen_process_input(process) for _ in 1:100]
|
||||
|
@@ -13,7 +13,7 @@ graph = MetagraphOptimization.DAG()
|
||||
@test length(graph.operationsToApply) == 0
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
|
||||
@test length(get_operations(graph)) == (nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(get_operations(graph)) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
|
||||
|
||||
# s to output (exit node)
|
||||
d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
|
||||
@@ -133,10 +133,13 @@ insert_edge!(graph, s0, d_exit, track = false)
|
||||
@test length(siblings(s0)) == 1
|
||||
|
||||
operations = get_operations(graph)
|
||||
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
|
||||
@test sum(length(operations)) == 10
|
||||
|
||||
@test operations == get_operations(graph)
|
||||
nf = first(operations.nodeFusions)
|
||||
|
||||
properties = get_properties(graph)
|
||||
@test properties.computeEffort == 28
|
||||
@@ -145,19 +148,54 @@ properties = get_properties(graph)
|
||||
@test properties.noNodes == 26
|
||||
@test properties.noEdges == 25
|
||||
|
||||
push_operation!(graph, nf)
|
||||
# **does not immediately apply the operation**
|
||||
|
||||
@test length(graph.nodes) == 26
|
||||
@test length(graph.appliedOperations) == 0
|
||||
@test length(graph.operationsToApply) == 1
|
||||
@test first(graph.operationsToApply) == nf
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
|
||||
|
||||
# this applies pending operations
|
||||
properties = get_properties(graph)
|
||||
|
||||
@test length(graph.nodes) == 24
|
||||
@test length(graph.appliedOperations) == 1
|
||||
@test length(graph.operationsToApply) == 0
|
||||
@test length(graph.dirtyNodes) != 0
|
||||
@test properties.noNodes == 24
|
||||
@test properties.noEdges == 23
|
||||
@test properties.computeEffort == 28
|
||||
@test properties.data < 62
|
||||
@test properties.computeIntensity > 28 / 62
|
||||
|
||||
operations = get_operations(graph)
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
|
||||
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(operations) == (nodeFusions = 9, nodeReductions = 0, nodeSplits = 0)
|
||||
@test !isempty(operations)
|
||||
|
||||
possibleNF = 9
|
||||
while !isempty(operations.nodeFusions)
|
||||
push_operation!(graph, first(operations.nodeFusions))
|
||||
global operations = get_operations(graph)
|
||||
global possibleNF = possibleNF - 1
|
||||
@test length(operations) == (nodeFusions = possibleNF, nodeReductions = 0, nodeSplits = 0)
|
||||
end
|
||||
|
||||
@test isempty(operations)
|
||||
|
||||
@test length(operations) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
@test length(graph.nodes) == 26
|
||||
@test length(graph.appliedOperations) == 0
|
||||
@test length(graph.nodes) == 6
|
||||
@test length(graph.appliedOperations) == 10
|
||||
@test length(graph.operationsToApply) == 0
|
||||
|
||||
reset_graph!(graph)
|
||||
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
@test length(graph.dirtyNodes) == 26
|
||||
@test length(graph.nodes) == 26
|
||||
@test length(graph.appliedOperations) == 0
|
||||
@test length(graph.operationsToApply) == 0
|
||||
@@ -170,6 +208,6 @@ properties = get_properties(graph)
|
||||
@test properties.computeIntensity ≈ 28 / 62
|
||||
|
||||
operations = get_operations(graph)
|
||||
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
|
||||
|
||||
@test is_valid(graph)
|
||||
|
@@ -6,7 +6,8 @@ RNG = Random.MersenneTwister(0)
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
|
||||
|
||||
# create the optimizers
|
||||
FIXPOINT_OPTIMIZERS = [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer()]
|
||||
FIXPOINT_OPTIMIZERS =
|
||||
[GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer(), FusionOptimizer()]
|
||||
NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
|
||||
|
||||
@testset "Optimizer $optimizer" for optimizer in vcat(NO_FIXPOINT_OPTIMIZERS, FIXPOINT_OPTIMIZERS)
|
||||
|
@@ -1,8 +1,7 @@
|
||||
using MetagraphOptimization
|
||||
|
||||
using QEDcore
|
||||
|
||||
import MetagraphOptimization.gen_diagrams
|
||||
import MetagraphOptimization.isincoming
|
||||
import MetagraphOptimization.types
|
||||
|
||||
|
||||
@@ -10,12 +9,25 @@ model = QEDModel()
|
||||
compton = ("Compton Scattering", parse_process("ke->ke", model), 2)
|
||||
compton_3 = ("3-Photon Compton Scattering", parse_process("kkke->ke", QEDModel()), 24)
|
||||
compton_4 = ("4-Photon Compton Scattering", parse_process("kkkke->ke", QEDModel()), 120)
|
||||
bhabha = ("Bhabha Scattering", parse_process("ep->ep", model), 2)
|
||||
moller = ("Møller Scattering", parse_process("ee->ee", model), 2)
|
||||
pair_production = ("Pair production", parse_process("kk->ep", model), 2)
|
||||
pair_annihilation = ("Pair annihilation", parse_process("ep->kk", model), 2)
|
||||
trident = ("Trident", parse_process("ke->epe", model), 8)
|
||||
|
||||
@testset "Known Processes" begin
|
||||
@testset "$name" for (name, process, n) in [compton, compton_3, compton_4]
|
||||
@testset "$name" for (name, process, n) in
|
||||
[compton, bhabha, moller, pair_production, pair_annihilation, trident, compton_3, compton_4]
|
||||
initial_diagram = FeynmanDiagram(process)
|
||||
n_particles = number_incoming_particles(process) + number_outgoing_particles(process)
|
||||
|
||||
n_particles = 0
|
||||
for type in types(model)
|
||||
if (isincoming(type))
|
||||
n_particles += get(process.inParticles, type, 0)
|
||||
else
|
||||
n_particles += get(process.outParticles, type, 0)
|
||||
end
|
||||
end
|
||||
@test n_particles == length(initial_diagram.particles)
|
||||
@test ismissing(initial_diagram.tie[])
|
||||
@test isempty(initial_diagram.vertices)
|
||||
|
@@ -1,6 +1,5 @@
|
||||
using MetagraphOptimization
|
||||
using QEDbase
|
||||
using QEDcore
|
||||
using QEDprocesses
|
||||
using StatsBase # for countmap
|
||||
using Random
|
||||
@@ -10,6 +9,8 @@ import MetagraphOptimization.caninteract
|
||||
import MetagraphOptimization.issame
|
||||
import MetagraphOptimization.interaction_result
|
||||
import MetagraphOptimization.propagation_result
|
||||
import MetagraphOptimization.direction
|
||||
import MetagraphOptimization.spin_or_pol
|
||||
import MetagraphOptimization.QED_vertex
|
||||
|
||||
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
|
||||
@@ -17,32 +18,32 @@ def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
|
||||
RNG = Random.MersenneTwister(0)
|
||||
|
||||
testparticleTypes = [
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum},
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum},
|
||||
ParticleStateful{Incoming, Positron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Positron, SFourMomentum},
|
||||
PhotonStateful{Incoming,PolX},
|
||||
PhotonStateful{Outgoing,PolX},
|
||||
FermionStateful{Incoming,SpinUp},
|
||||
FermionStateful{Outgoing,SpinUp},
|
||||
AntiFermionStateful{Incoming,SpinUp},
|
||||
AntiFermionStateful{Outgoing,SpinUp},
|
||||
]
|
||||
|
||||
testparticleTypesPropagated = [
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum},
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum},
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Positron, SFourMomentum},
|
||||
ParticleStateful{Incoming, Positron, SFourMomentum},
|
||||
PhotonStateful{Outgoing,PolX},
|
||||
PhotonStateful{Incoming,PolX},
|
||||
FermionStateful{Outgoing,SpinUp},
|
||||
FermionStateful{Incoming,SpinUp},
|
||||
AntiFermionStateful{Outgoing,SpinUp},
|
||||
AntiFermionStateful{Incoming,SpinUp},
|
||||
]
|
||||
|
||||
function compton_groundtruth(input::PhaseSpacePoint)
|
||||
function compton_groundtruth(input::QEDProcessInput)
|
||||
# p1k1 -> p2k2
|
||||
# formula: −(ie)^2 (u(p2) slashed(ε1) S(p2 − k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
|
||||
|
||||
p1 = momentum(psp, Incoming(), 2)
|
||||
p2 = momentum(psp, Outgoing(), 2)
|
||||
p1 = input.inFerms[1]
|
||||
p2 = input.outFerms[1]
|
||||
|
||||
k1 = momentum(psp, Incoming(), 1)
|
||||
k2 = momentum(psp, Outgoing(), 1)
|
||||
k1 = input.inPhotons[1]
|
||||
k2 = input.outPhotons[1]
|
||||
|
||||
u_p1 = base_state(Electron(), Incoming(), p1.momentum, spin_or_pol(p1))
|
||||
u_p2 = base_state(Electron(), Outgoing(), p2.momentum, spin_or_pol(p2))
|
||||
@@ -86,46 +87,95 @@ end
|
||||
@test issame(typeof(resultParticle), interaction_result(p1, p2))
|
||||
|
||||
totalMom = zero(SFourMomentum)
|
||||
for (p, mom) in [(p1, momentum(testParticle1)), (p2, momentum(testParticle2)), (p3, momentum(resultParticle))]
|
||||
if (typeof(particle_direction(p)) <: Incoming)
|
||||
for (p, mom) in [
|
||||
(p1, testParticle1.momentum),
|
||||
(p2, testParticle2.momentum),
|
||||
(p3, resultParticle.momentum),
|
||||
]
|
||||
if (typeof(direction(p)) <: Incoming)
|
||||
totalMom += mom
|
||||
else
|
||||
totalMom -= mom
|
||||
end
|
||||
end
|
||||
|
||||
@test isapprox(totalMom, zero(SFourMomentum); atol = sqrt(eps()))
|
||||
@test isapprox(totalMom, zero(SFourMomentum); atol=sqrt(eps()))
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Propagation Result" begin
|
||||
for (p, propResult) in zip(testparticleTypes, testparticleTypesPropagated)
|
||||
@test issame(propagation_result(p), propResult)
|
||||
@test particle_direction(propagation_result(p)(def_momentum)) != particle_direction(p(def_momentum))
|
||||
@test direction(propagation_result(p)(def_momentum)) != direction(p(def_momentum))
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Parse Process" begin
|
||||
@testset "Order invariance" begin
|
||||
@test parse_process("ke->ke", QEDModel()) == parse_process("ek->ke", QEDModel())
|
||||
@test parse_process("ke->ke", QEDModel()) == parse_process("ek->ek", QEDModel())
|
||||
@test parse_process("ke->ke", QEDModel()) == parse_process("ke->ek", QEDModel())
|
||||
|
||||
@test parse_process("kkke->eep", QEDModel()) ==
|
||||
parse_process("kkek->epe", QEDModel())
|
||||
end
|
||||
|
||||
@testset "Known processes" begin
|
||||
proc = parse_process("ke->ke", QEDModel())
|
||||
@test incoming_particles(proc) == (Photon(), Electron())
|
||||
@test outgoing_particles(proc) == (Photon(), Electron())
|
||||
compton_process = QEDProcessDescription(
|
||||
Dict{Type,Int}(
|
||||
PhotonStateful{Incoming,PolX} => 1, FermionStateful{Incoming,SpinUp} => 1
|
||||
),
|
||||
Dict{Type,Int}(
|
||||
PhotonStateful{Outgoing,PolX} => 1, FermionStateful{Outgoing,SpinUp} => 1
|
||||
),
|
||||
)
|
||||
|
||||
proc = parse_process("kp->kp", QEDModel())
|
||||
@test incoming_particles(proc) == (Photon(), Positron())
|
||||
@test outgoing_particles(proc) == (Photon(), Positron())
|
||||
@test parse_process("ke->ke", QEDModel()) == compton_process
|
||||
|
||||
proc = parse_process("ke->eep", QEDModel())
|
||||
@test incoming_particles(proc) == (Photon(), Electron())
|
||||
@test outgoing_particles(proc) == (Electron(), Electron(), Positron())
|
||||
positron_compton_process = QEDProcessDescription(
|
||||
Dict{Type,Int}(
|
||||
PhotonStateful{Incoming,PolX} => 1,
|
||||
AntiFermionStateful{Incoming,SpinUp} => 1,
|
||||
),
|
||||
Dict{Type,Int}(
|
||||
PhotonStateful{Outgoing,PolX} => 1,
|
||||
AntiFermionStateful{Outgoing,SpinUp} => 1,
|
||||
),
|
||||
)
|
||||
|
||||
proc = parse_process("kk->pe", QEDModel())
|
||||
@test incoming_particles(proc) == (Photon(), Photon())
|
||||
@test outgoing_particles(proc) == (Positron(), Electron())
|
||||
@test parse_process("kp->kp", QEDModel()) == positron_compton_process
|
||||
|
||||
proc = parse_process("pe->kk", QEDModel())
|
||||
@test incoming_particles(proc) == (Positron(), Electron())
|
||||
@test outgoing_particles(proc) == (Photon(), Photon())
|
||||
trident_process = QEDProcessDescription(
|
||||
Dict{Type,Int}(
|
||||
PhotonStateful{Incoming,PolX} => 1, FermionStateful{Incoming,SpinUp} => 1
|
||||
),
|
||||
Dict{Type,Int}(
|
||||
FermionStateful{Outgoing,SpinUp} => 2,
|
||||
AntiFermionStateful{Outgoing,SpinUp} => 1,
|
||||
),
|
||||
)
|
||||
|
||||
@test parse_process("ke->eep", QEDModel()) == trident_process
|
||||
|
||||
pair_production_process = QEDProcessDescription(
|
||||
Dict{Type,Int}(PhotonStateful{Incoming,PolX} => 2),
|
||||
Dict{Type,Int}(
|
||||
FermionStateful{Outgoing,SpinUp} => 1,
|
||||
AntiFermionStateful{Outgoing,SpinUp} => 1,
|
||||
),
|
||||
)
|
||||
|
||||
@test parse_process("kk->pe", QEDModel()) == pair_production_process
|
||||
|
||||
pair_annihilation_process = QEDProcessDescription(
|
||||
Dict{Type,Int}(
|
||||
FermionStateful{Incoming,SpinUp} => 1,
|
||||
AntiFermionStateful{Incoming,SpinUp} => 1,
|
||||
),
|
||||
Dict{Type,Int}(PhotonStateful{Outgoing,PolX} => 2),
|
||||
)
|
||||
|
||||
@test parse_process("pe->kk", QEDModel()) == pair_annihilation_process
|
||||
end
|
||||
end
|
||||
|
||||
@@ -136,12 +186,36 @@ end
|
||||
|
||||
for i in 1:100
|
||||
input = gen_process_input(process)
|
||||
@test isapprox(sum(momenta(input, Incoming())), sum(momenta(input, Outgoing())); atol = sqrt(eps()))
|
||||
@test length(input.inFerms) ==
|
||||
get(process.inParticles, FermionStateful{Incoming,SpinUp}, 0)
|
||||
@test length(input.inAntiferms) ==
|
||||
get(process.inParticles, AntiFermionStateful{Incoming,SpinUp}, 0)
|
||||
@test length(input.inPhotons) ==
|
||||
get(process.inParticles, PhotonStateful{Incoming,PolX}, 0)
|
||||
@test length(input.outFerms) ==
|
||||
get(process.outParticles, FermionStateful{Outgoing,SpinUp}, 0)
|
||||
@test length(input.outAntiferms) ==
|
||||
get(process.outParticles, AntiFermionStateful{Outgoing,SpinUp}, 0)
|
||||
@test length(input.outPhotons) ==
|
||||
get(process.outParticles, PhotonStateful{Outgoing,PolX}, 0)
|
||||
|
||||
@test isapprox(
|
||||
sum([
|
||||
getfield.(input.inFerms, :momentum)...,
|
||||
getfield.(input.inAntiferms, :momentum)...,
|
||||
getfield.(input.inPhotons, :momentum)...,
|
||||
]),
|
||||
sum([
|
||||
getfield.(input.outFerms, :momentum)...,
|
||||
getfield.(input.outAntiferms, :momentum)...,
|
||||
getfield.(input.outPhotons, :momentum)...,
|
||||
]);
|
||||
atol=sqrt(eps()),
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
#=
|
||||
@testset "Compton" begin
|
||||
import MetagraphOptimization.insert_node!
|
||||
import MetagraphOptimization.insert_edge!
|
||||
@@ -307,4 +381,3 @@ end
|
||||
@test isapprox(compute_function.(input), reduced_compute_function.(input))
|
||||
end
|
||||
end
|
||||
=#
|
Reference in New Issue
Block a user