Compare commits
No commits in common. "remove_fusion" and "main" have entirely different histories.
remove_fus
...
main
4
.gitattributes
vendored
4
.gitattributes
vendored
@ -1,5 +1,3 @@
|
||||
input/AB->ABBBBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
|
||||
input/AB->ABBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs
|
||||
*.gif filter=lfs diff=lfs merge=lfs
|
||||
*.jld2 filter=lfs diff=lfs merge=lfs
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
|
10
Project.toml
10
Project.toml
@ -14,22 +14,16 @@ JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
|
||||
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
|
||||
NumaAllocators = "21436f30-1b4a-4f08-87af-e26101bb5379"
|
||||
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
|
||||
QEDcore = "35dc0263-cb5f-4c33-a114-1d7f54ab753e"
|
||||
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
|
||||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
|
||||
|
||||
[extras]
|
||||
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
|
||||
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
|
||||
QEDcore = "35dc0263-cb5f-4c33-a114-1d7f54ab753e"
|
||||
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
|
||||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[targets]
|
||||
test = ["SafeTestsets", "Test", "QEDbase", "QEDcore", "QEDprocesses"]
|
||||
test = ["Test"]
|
||||
|
@ -5,7 +5,7 @@
|
||||
## Package Features
|
||||
- Read a DAG from a file
|
||||
- Analyze its properties
|
||||
- Mute the graph using the operations NodeReduction and NodeSplit
|
||||
- Mute the graph using the operations NodeFusion, NodeReduction and NodeSplit
|
||||
|
||||
## Coming Soon:
|
||||
- Add Code Generation from finished DAG
|
||||
|
@ -1,194 +0,0 @@
|
||||
ENV["UCX_ERROR_SIGNALS"] = "SIGILL,SIGBUS,SIGFPE"
|
||||
|
||||
using MetagraphOptimization
|
||||
using QEDbase
|
||||
using QEDcore
|
||||
using QEDprocesses
|
||||
using Random
|
||||
using UUIDs
|
||||
|
||||
using CUDA
|
||||
|
||||
using NamedDims
|
||||
using CSV
|
||||
using JLD2
|
||||
using FlexiMaps
|
||||
|
||||
RNG = Random.MersenneTwister(123)
|
||||
|
||||
function mock_machine()
|
||||
return Machine(
|
||||
[
|
||||
MetagraphOptimization.NumaNode(
|
||||
0,
|
||||
1,
|
||||
MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),
|
||||
-1.0,
|
||||
UUIDs.uuid1(),
|
||||
),
|
||||
],
|
||||
[-1.0;;],
|
||||
)
|
||||
end
|
||||
|
||||
function congruent_input_momenta(processDescription::GenericQEDProcess, omega::Number)
|
||||
# generate an input sample for given e + nk -> e' + k' process, where the nk are equal
|
||||
inputMasses = Vector{Float64}()
|
||||
for particle in incoming_particles(processDescription)
|
||||
push!(inputMasses, mass(particle))
|
||||
end
|
||||
outputMasses = Vector{Float64}()
|
||||
for particle in outgoing_particles(processDescription)
|
||||
push!(outputMasses, mass(particle))
|
||||
end
|
||||
|
||||
initial_momenta = [
|
||||
i == length(inputMasses) ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
|
||||
i in 1:length(inputMasses)
|
||||
]
|
||||
|
||||
ss = sqrt(sum(initial_momenta) * sum(initial_momenta))
|
||||
final_momenta = MetagraphOptimization.generate_physical_massive_moms(RNG, ss, outputMasses)
|
||||
|
||||
return (tuple(initial_momenta...), tuple(final_momenta...))
|
||||
end
|
||||
|
||||
# theta ∈ [0, 2π] and phi ∈ [0, 2π]
|
||||
function congruent_input_momenta_scenario_2(
|
||||
processDescription::GenericQEDProcess,
|
||||
omega::Number,
|
||||
theta::Number,
|
||||
phi::Number,
|
||||
)
|
||||
# -------------
|
||||
# same as above
|
||||
|
||||
# generate an input sample for given e + nk -> e' + k' process, where the nk are equal
|
||||
inputMasses = Vector{Float64}()
|
||||
for particle in incoming_particles(processDescription)
|
||||
push!(inputMasses, mass(particle))
|
||||
end
|
||||
outputMasses = Vector{Float64}()
|
||||
for particle in outgoing_particles(processDescription)
|
||||
push!(outputMasses, mass(particle))
|
||||
end
|
||||
|
||||
initial_momenta = [
|
||||
i == length(inputMasses) ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
|
||||
i in 1:length(inputMasses)
|
||||
]
|
||||
|
||||
ss = sqrt(sum(initial_momenta) * sum(initial_momenta))
|
||||
|
||||
# up to here
|
||||
# ----------
|
||||
|
||||
# now calculate the final_momenta from omega, cos_theta and phi
|
||||
n = number_particles(processDescription, Incoming(), Photon())
|
||||
|
||||
cos_theta = cos(theta)
|
||||
omega_prime = (n * omega) / (1 + n * omega * (1 - cos_theta))
|
||||
|
||||
k_prime =
|
||||
omega_prime * SFourMomentum(1, sqrt(1 - cos_theta^2) * cos(phi), sqrt(1 - cos_theta^2) * sin(phi), cos_theta)
|
||||
p_prime = sum(initial_momenta) - k_prime
|
||||
|
||||
final_momenta = (k_prime, p_prime)
|
||||
|
||||
return (tuple(initial_momenta...), tuple(final_momenta...))
|
||||
|
||||
end
|
||||
|
||||
function build_psp(processDescription::GenericQEDProcess, momenta)
|
||||
return PhaseSpacePoint(
|
||||
processDescription,
|
||||
PerturbativeQED(),
|
||||
PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()),
|
||||
momenta[1],
|
||||
momenta[2],
|
||||
)
|
||||
end
|
||||
|
||||
# hack to fix stacksize for threading
|
||||
with_stacksize(f, n) = fetch(schedule(Task(f, n)))
|
||||
|
||||
# scenario 2
|
||||
N = 1024 # thetas
|
||||
M = 1024 # phis
|
||||
K = 64 # omegas
|
||||
|
||||
thetas = collect(LinRange(0, 2π, N))
|
||||
phis = collect(LinRange(0, 2π, M))
|
||||
omegas = collect(maprange(log, 2e-2, 2e-7, K))
|
||||
|
||||
for photons in 1:5
|
||||
# temp process to generate momenta
|
||||
println("Generating $(K*N*M) inputs for $photons photons (Scenario 2 grid walk)...")
|
||||
temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
|
||||
|
||||
input_momenta =
|
||||
Array{typeof(congruent_input_momenta_scenario_2(temp_process, omegas[1], thetas[1], phis[1]))}(undef, (K, N, M))
|
||||
|
||||
Threads.@threads for k in 1:K
|
||||
Threads.@threads for i in 1:N
|
||||
Threads.@threads for j in 1:M
|
||||
input_momenta[k, i, j] = congruent_input_momenta_scenario_2(temp_process, omegas[k], thetas[i], phis[j])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
cu_results = CuArray{Float64}(undef, size(input_momenta))
|
||||
fill!(cu_results, 0.0)
|
||||
|
||||
i = 1
|
||||
for (in_pol, in_spin, out_pol, out_spin) in
|
||||
Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
|
||||
|
||||
print(
|
||||
"[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
|
||||
)
|
||||
process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
|
||||
|
||||
inputs = Array{typeof(build_psp(process, input_momenta[1, 1, 1]))}(undef, (K, N, M))
|
||||
#println("input_momenta: $input_momenta")
|
||||
Threads.@threads for k in 1:K
|
||||
Threads.@threads for i in 1:N
|
||||
Threads.@threads for j in 1:M
|
||||
inputs[k, i, j] = build_psp(process, input_momenta[k, i, j])
|
||||
end
|
||||
end
|
||||
end
|
||||
cu_inputs = CuArray(inputs)
|
||||
|
||||
print("Preparing graph... ")
|
||||
graph = gen_graph(process)
|
||||
optimize_to_fixpoint!(ReductionOptimizer(), graph)
|
||||
print("Preparing function... ")
|
||||
kernel! = get_cuda_kernel(graph, process, mock_machine())
|
||||
#func = get_compute_function(graph, process, mock_machine())
|
||||
|
||||
print("Calculating... ")
|
||||
ts = 32
|
||||
bs = Int64(length(cu_inputs) / 32)
|
||||
|
||||
outputs = CuArray{ComplexF64}(undef, size(cu_inputs))
|
||||
|
||||
@cuda threads = ts blocks = bs always_inline = true kernel!(cu_inputs, outputs, length(cu_inputs))
|
||||
CUDA.device_synchronize()
|
||||
cu_results += abs2.(outputs)
|
||||
|
||||
println("Done.")
|
||||
i += 1
|
||||
end
|
||||
|
||||
println("Writing results")
|
||||
|
||||
out_ph_moms = getindex.(getindex.(input_momenta, 2), 1)
|
||||
out_el_moms = getindex.(getindex.(input_momenta, 2), 2)
|
||||
|
||||
results = NamedDimsArray{(:omegas, :thetas, :phis)}(Array(cu_results))
|
||||
println("Named results array: $(typeof(results))")
|
||||
|
||||
@save "$(photons)_congruent_photons_grid.jld2" omegas thetas phis results
|
||||
end
|
60
examples/plot_chain.jl
Normal file
60
examples/plot_chain.jl
Normal file
@ -0,0 +1,60 @@
|
||||
using MetagraphOptimization
|
||||
using Plots
|
||||
using Random
|
||||
|
||||
function gen_plot(filepath)
|
||||
name = basename(filepath)
|
||||
name, _ = splitext(name)
|
||||
|
||||
filepath = joinpath(@__DIR__, "../input/", filepath)
|
||||
if !isfile(filepath)
|
||||
println("File ", filepath, " does not exist, skipping")
|
||||
return
|
||||
end
|
||||
|
||||
g = parse_dag(filepath, ABCModel())
|
||||
|
||||
Random.seed!(1)
|
||||
|
||||
println("Random Walking... ")
|
||||
|
||||
x = Vector{Float64}()
|
||||
y = Vector{Float64}()
|
||||
|
||||
for i in 1:30
|
||||
print("\r", i)
|
||||
# push
|
||||
opt = get_operations(g)
|
||||
|
||||
# choose one of fuse/split/reduce
|
||||
option = rand(1:3)
|
||||
if option == 1 && !isempty(opt.nodeFusions)
|
||||
push_operation!(g, rand(collect(opt.nodeFusions)))
|
||||
println("NF")
|
||||
elseif option == 2 && !isempty(opt.nodeReductions)
|
||||
push_operation!(g, rand(collect(opt.nodeReductions)))
|
||||
println("NR")
|
||||
elseif option == 3 && !isempty(opt.nodeSplits)
|
||||
push_operation!(g, rand(collect(opt.nodeSplits)))
|
||||
println("NS")
|
||||
else
|
||||
i = i - 1
|
||||
end
|
||||
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.computeEffort)
|
||||
end
|
||||
|
||||
println("\rDone.")
|
||||
|
||||
plot([x[1], x[2]], [y[1], y[2]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
|
||||
# Create lines connecting the reference point to each data point
|
||||
for i in 3:length(x)
|
||||
plot!([x[i - 1], x[i]], [y[i - 1], y[i]], linestyle = :solid, linewidth = 1, color = :red)
|
||||
end
|
||||
|
||||
return gui()
|
||||
end
|
||||
|
||||
gen_plot("AB->ABBB.txt")
|
96
examples/plot_star.jl
Normal file
96
examples/plot_star.jl
Normal file
@ -0,0 +1,96 @@
|
||||
using MetagraphOptimization
|
||||
using Plots
|
||||
using Random
|
||||
|
||||
function gen_plot(filepath)
|
||||
name = basename(filepath)
|
||||
name, _ = splitext(name)
|
||||
|
||||
filepath = joinpath(@__DIR__, "../input/", filepath)
|
||||
if !isfile(filepath)
|
||||
println("File ", filepath, " does not exist, skipping")
|
||||
return
|
||||
end
|
||||
|
||||
g = parse_dag(filepath, ABCModel())
|
||||
|
||||
Random.seed!(1)
|
||||
|
||||
println("Random Walking... ")
|
||||
|
||||
for i in 1:30
|
||||
print("\r", i)
|
||||
# push
|
||||
opt = get_operations(g)
|
||||
|
||||
# choose one of fuse/split/reduce
|
||||
option = rand(1:3)
|
||||
if option == 1 && !isempty(opt.nodeFusions)
|
||||
push_operation!(g, rand(collect(opt.nodeFusions)))
|
||||
println("NF")
|
||||
elseif option == 2 && !isempty(opt.nodeReductions)
|
||||
push_operation!(g, rand(collect(opt.nodeReductions)))
|
||||
println("NR")
|
||||
elseif option == 3 && !isempty(opt.nodeSplits)
|
||||
push_operation!(g, rand(collect(opt.nodeSplits)))
|
||||
println("NS")
|
||||
else
|
||||
i = i - 1
|
||||
end
|
||||
end
|
||||
|
||||
println("\rDone.")
|
||||
|
||||
|
||||
|
||||
|
||||
props = get_properties(g)
|
||||
x0 = props.data
|
||||
y0 = props.computeEffort
|
||||
|
||||
x = Vector{Float64}()
|
||||
y = Vector{Float64}()
|
||||
names = Vector{String}()
|
||||
|
||||
opt = get_operations(g)
|
||||
for op in opt.nodeFusions
|
||||
push_operation!(g, op)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.computeEffort)
|
||||
pop_operation!(g)
|
||||
|
||||
push!(names, "NF: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
|
||||
end
|
||||
for op in opt.nodeReductions
|
||||
push_operation!(g, op)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.computeEffort)
|
||||
pop_operation!(g)
|
||||
|
||||
push!(names, "NR: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
|
||||
end
|
||||
for op in opt.nodeSplits
|
||||
push_operation!(g, op)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.computeEffort)
|
||||
pop_operation!(g)
|
||||
|
||||
push!(names, "NS: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
|
||||
end
|
||||
|
||||
plot([x0, x[1]], [y0, y[1]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
|
||||
# Create lines connecting the reference point to each data point
|
||||
for i in 2:length(x)
|
||||
plot!([x0, x[i]], [y0, y[i]], linestyle = :solid, linewidth = 1, color = :red)
|
||||
end
|
||||
#scatter!(x, y, label=names)
|
||||
|
||||
print(names)
|
||||
|
||||
return gui()
|
||||
end
|
||||
|
||||
gen_plot("AB->ABBB.txt")
|
BIN
images/contour_plot_congruent_in_photons.gif
(Stored with Git LFS)
BIN
images/contour_plot_congruent_in_photons.gif
(Stored with Git LFS)
Binary file not shown.
@ -1,5 +1,5 @@
|
||||
# Optimizer Plots
|
||||
|
||||
Plots of FusionOptimizer (deprecated), ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.
|
||||
Plots of FusionOptimizer, ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.
|
||||
|
||||
Benchmarked using `notebooks/optimizers.ipynb`.
|
||||
|
@ -413,7 +413,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Julia 1.10.4",
|
||||
"display_name": "Julia 1.10.2",
|
||||
"language": "julia",
|
||||
"name": "julia-1.10"
|
||||
},
|
||||
@ -421,7 +421,7 @@
|
||||
"file_extension": ".jl",
|
||||
"mimetype": "application/julia",
|
||||
"name": "julia",
|
||||
"version": "1.10.4"
|
||||
"version": "1.10.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
File diff suppressed because one or more lines are too long
@ -54,6 +54,8 @@
|
||||
"cu_inputs = CuArray(inputs)\n",
|
||||
"optimizer = RandomWalkOptimizer(MersenneTwister(0))# GreedyOptimizer(GlobalMetricEstimator())\n",
|
||||
"\n",
|
||||
"#done: split, reduce, fuse, greedy\n",
|
||||
"\n",
|
||||
"process_str_short = \"qed_k3\"\n",
|
||||
"optim_str = \"Random Walk Optimization\"\n",
|
||||
"optim_str_short=\"random\"\n",
|
||||
|
BIN
results/1_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/1_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
BIN
results/2_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/2_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
BIN
results/3_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/3_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
BIN
results/4_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/4_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
BIN
results/5_congruent_photons_grid.jld2
(Stored with Git LFS)
BIN
results/5_congruent_photons_grid.jld2
(Stored with Git LFS)
Binary file not shown.
@ -6,8 +6,6 @@ A module containing tools to work on DAGs.
|
||||
module MetagraphOptimization
|
||||
|
||||
using QEDbase
|
||||
using QEDcore
|
||||
using QEDprocesses
|
||||
|
||||
# graph types
|
||||
export DAG
|
||||
@ -19,6 +17,7 @@ export AbstractTask
|
||||
export AbstractComputeTask
|
||||
export AbstractDataTask
|
||||
export DataTask
|
||||
export FusedComputeTask
|
||||
export PossibleOperations
|
||||
export GraphProperties
|
||||
|
||||
@ -43,6 +42,7 @@ export is_valid, is_scheduled
|
||||
# graph operation related
|
||||
export Operation
|
||||
export AppliedOperation
|
||||
export NodeFusion
|
||||
export NodeReduction
|
||||
export NodeSplit
|
||||
export push_operation!
|
||||
@ -64,7 +64,8 @@ export ComputeTaskABC_Sum
|
||||
|
||||
# QED model
|
||||
export FeynmanDiagram, FeynmanVertex, FeynmanTie, FeynmanParticle
|
||||
export GenericQEDProcess, QEDModel
|
||||
export PhotonStateful, FermionStateful, AntiFermionStateful
|
||||
export QEDParticle, QEDProcessDescription, QEDProcessInput, QEDModel
|
||||
export ComputeTaskQED_P
|
||||
export ComputeTaskQED_S1
|
||||
export ComputeTaskQED_S2
|
||||
@ -86,7 +87,7 @@ export GlobalMetricEstimator, CDCost
|
||||
|
||||
# optimization
|
||||
export AbstractOptimizer, GreedyOptimizer, RandomWalkOptimizer
|
||||
export ReductionOptimizer, SplitOptimizer
|
||||
export ReductionOptimizer, SplitOptimizer, FusionOptimizer
|
||||
export optimize_step!, optimize!
|
||||
export fixpoint_reached, optimize_to_fixpoint!
|
||||
|
||||
@ -98,6 +99,9 @@ export ==, in, show, isempty, delete!, length
|
||||
|
||||
export bytes_to_human_readable
|
||||
|
||||
# TODO: this is probably not good
|
||||
import QEDprocesses.compute
|
||||
|
||||
import Base.length
|
||||
import Base.show
|
||||
import Base.==
|
||||
@ -110,6 +114,7 @@ import Base.delete!
|
||||
import Base.insert!
|
||||
import Base.collect
|
||||
|
||||
|
||||
include("devices/interface.jl")
|
||||
include("task/type.jl")
|
||||
include("node/type.jl")
|
||||
@ -162,30 +167,28 @@ include("optimization/interface.jl")
|
||||
include("optimization/greedy.jl")
|
||||
include("optimization/random_walk.jl")
|
||||
include("optimization/reduce.jl")
|
||||
include("optimization/fuse.jl")
|
||||
include("optimization/split.jl")
|
||||
|
||||
include("models/interface.jl")
|
||||
include("models/print.jl")
|
||||
|
||||
include("models/physics_models/interface.jl")
|
||||
include("models/abc/types.jl")
|
||||
include("models/abc/particle.jl")
|
||||
include("models/abc/compute.jl")
|
||||
include("models/abc/create.jl")
|
||||
include("models/abc/properties.jl")
|
||||
include("models/abc/parse.jl")
|
||||
include("models/abc/print.jl")
|
||||
|
||||
include("models/physics_models/abc/types.jl")
|
||||
include("models/physics_models/abc/particle.jl")
|
||||
include("models/physics_models/abc/compute.jl")
|
||||
include("models/physics_models/abc/create.jl")
|
||||
include("models/physics_models/abc/properties.jl")
|
||||
include("models/physics_models/abc/parse.jl")
|
||||
include("models/physics_models/abc/print.jl")
|
||||
|
||||
include("models/physics_models/qed/utility.jl")
|
||||
include("models/physics_models/qed/types.jl")
|
||||
include("models/physics_models/qed/particle.jl")
|
||||
include("models/physics_models/qed/diagrams.jl")
|
||||
include("models/physics_models/qed/compute.jl")
|
||||
include("models/physics_models/qed/create.jl")
|
||||
include("models/physics_models/qed/properties.jl")
|
||||
include("models/physics_models/qed/parse.jl")
|
||||
include("models/physics_models/qed/print.jl")
|
||||
include("models/qed/types.jl")
|
||||
include("models/qed/particle.jl")
|
||||
include("models/qed/diagrams.jl")
|
||||
include("models/qed/compute.jl")
|
||||
include("models/qed/create.jl")
|
||||
include("models/qed/properties.jl")
|
||||
include("models/qed/parse.jl")
|
||||
include("models/qed/print.jl")
|
||||
|
||||
include("devices/measure.jl")
|
||||
include("devices/detect.jl")
|
||||
@ -194,7 +197,7 @@ include("devices/impl.jl")
|
||||
include("devices/numa/impl.jl")
|
||||
include("devices/cuda/impl.jl")
|
||||
include("devices/rocm/impl.jl")
|
||||
#include("devices/oneapi/impl.jl")
|
||||
include("devices/oneapi/impl.jl")
|
||||
|
||||
include("scheduler/interface.jl")
|
||||
include("scheduler/greedy.jl")
|
||||
|
@ -1,70 +1,72 @@
|
||||
"""
|
||||
get_compute_function(graph::DAG, instance, machine::Machine)
|
||||
get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
|
||||
Return a function of signature `compute_<id>(input::input_type(instance))`, which will return the result of the DAG computation on the given input.
|
||||
Return a function of signature `compute_<id>(input::AbstractProcessInput)`, which will return the result of the DAG computation on the given input.
|
||||
"""
|
||||
function get_compute_function(graph::DAG, instance, machine::Machine)
|
||||
tape = gen_tape(graph, instance, machine)
|
||||
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
tape = gen_tape(graph, process, machine)
|
||||
|
||||
initCaches = Expr(:block, tape.initCachesCode...)
|
||||
assignInputs = Expr(:block, tape.inputAssignCode...)
|
||||
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
|
||||
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
|
||||
|
||||
functionId = to_var_name(UUIDs.uuid1(rng[1]))
|
||||
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
|
||||
expr = Meta.parse(
|
||||
"function compute_$(functionId)(data_input::$(input_type(instance))) $(initCaches); $(assignInputs); $code; return $resSym; end",
|
||||
"function compute_$(functionId)(data_input::AbstractProcessInput) $(initCaches); $(assignInputs); $code; return $resSym; end",
|
||||
)
|
||||
|
||||
return expr
|
||||
func = eval(expr)
|
||||
|
||||
return func
|
||||
end
|
||||
|
||||
"""
|
||||
get_cuda_kernel(graph::DAG, instance, machine::Machine)
|
||||
get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
|
||||
Return a function of signature `compute_<id>(input::CuVector, output::CuVector, n::Int64)`, which will return the result of the DAG computation of the input on the given output variable.
|
||||
"""
|
||||
function get_cuda_kernel(graph::DAG, instance, machine::Machine)
|
||||
tape = gen_tape(graph, instance, machine)
|
||||
function get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
tape = gen_tape(graph, process, machine)
|
||||
|
||||
initCaches = Expr(:block, tape.initCachesCode...)
|
||||
assignInputs = Expr(:block, tape.inputAssignCode...)
|
||||
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
|
||||
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
|
||||
|
||||
functionId = to_var_name(UUIDs.uuid1(rng[1]))
|
||||
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
|
||||
expr = Meta.parse(
|
||||
"function compute_$(functionId)(input_vector, output_vector, n::Int64)
|
||||
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
|
||||
if (id > n)
|
||||
return
|
||||
end
|
||||
@inline data_input = input_vector[id]
|
||||
$(initCaches)
|
||||
$(assignInputs)
|
||||
$code
|
||||
@inline output_vector[id] = $resSym
|
||||
return nothing
|
||||
end"
|
||||
)
|
||||
expr = Meta.parse("function compute_$(functionId)(input_vector, output_vector, n::Int64)
|
||||
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
|
||||
if (id > n)
|
||||
return
|
||||
end
|
||||
@inline data_input = input_vector[id]
|
||||
$(initCaches)
|
||||
$(assignInputs)
|
||||
$code
|
||||
@inline output_vector[id] = $resSym
|
||||
return nothing
|
||||
end")
|
||||
|
||||
return expr
|
||||
func = eval(expr)
|
||||
|
||||
return func
|
||||
end
|
||||
|
||||
"""
|
||||
execute(graph::DAG, instance, machine::Machine, input)
|
||||
execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
|
||||
|
||||
Execute the code of the given `graph` on the given input values.
|
||||
Execute the code of the given `graph` on the given input particles.
|
||||
|
||||
This is essentially shorthand for
|
||||
```julia
|
||||
tape = gen_tape(graph, instance, machine)
|
||||
tape = gen_tape(graph, process, machine)
|
||||
return execute_tape(tape, input)
|
||||
```
|
||||
|
||||
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
|
||||
"""
|
||||
function execute(graph::DAG, instance, machine::Machine, input)
|
||||
tape = gen_tape(graph, instance, machine)
|
||||
function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
|
||||
tape = gen_tape(graph, process, machine)
|
||||
return execute_tape(tape, input)
|
||||
end
|
||||
|
@ -1,11 +1,10 @@
|
||||
# TODO: do this with macros
|
||||
function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
|
||||
cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]])
|
||||
return nothing
|
||||
end
|
||||
|
||||
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
|
||||
cache[fc.return_symbol] = fc.func(fc.value_arguments[1], cache[fc.arguments[1]])
|
||||
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]])
|
||||
return nothing
|
||||
end
|
||||
|
||||
@ -15,12 +14,12 @@ function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {
|
||||
end
|
||||
|
||||
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{2}}
|
||||
cache[fc.return_symbol] = fc.func(fc.value_arguments[1], cache[fc.arguments[1]], cache[fc.arguments[2]])
|
||||
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]], cache[fc.arguments[2]])
|
||||
return nothing
|
||||
end
|
||||
|
||||
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT}
|
||||
cache[fc.return_symbol] = fc.func(fc.value_arguments[1], getindex.(Ref(cache), fc.arguments)...)
|
||||
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], getindex.(Ref(cache), fc.arguments)...)
|
||||
return nothing
|
||||
end
|
||||
|
||||
@ -32,7 +31,7 @@ Execute the given [`FunctionCall`](@ref) on the dictionary.
|
||||
Several more specialized versions of this function exist to reduce vector unrolling work for common cases.
|
||||
"""
|
||||
function call_fc(fc::FunctionCall{VectorT, M}, cache::Dict{Symbol, Any}) where {VectorT, M}
|
||||
cache[fc.return_symbol] = fc.func(fc.value_arguments..., getindex.(Ref(cache), fc.arguments)...)
|
||||
cache[fc.return_symbol] = fc.func(fc.additional_arguments..., getindex.(Ref(cache), fc.arguments)...)
|
||||
return nothing
|
||||
end
|
||||
|
||||
@ -48,8 +47,12 @@ end
|
||||
For a given function call, return an expression evaluating it.
|
||||
"""
|
||||
function expr_from_fc(fc::FunctionCall{VectorT, M}) where {VectorT, M}
|
||||
func_call =
|
||||
Expr(:call, Symbol(fc.func), fc.value_arguments..., eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...)
|
||||
func_call = Expr(
|
||||
:call,
|
||||
Symbol(fc.func),
|
||||
fc.additional_arguments...,
|
||||
eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...,
|
||||
)
|
||||
|
||||
expr = :($(eval(gen_access_expr(fc.device, fc.return_symbol))) = $func_call)
|
||||
return expr
|
||||
@ -70,32 +73,51 @@ function gen_cache_init_code(machine::Machine)
|
||||
return initializeCaches
|
||||
end
|
||||
|
||||
"""
|
||||
part_from_x(type::Type, index::Int, x::AbstractProcessInput)
|
||||
|
||||
Return the [`ParticleValue`](@ref) of the given type of particle with the given `index` from the given process input.
|
||||
|
||||
Function is wrapped into a [`FunctionCall`](@ref) in [`gen_input_assignment_code`](@ref).
|
||||
"""
|
||||
part_from_x(type::Type, index::Int, x::AbstractProcessInput) =
|
||||
ParticleValue{type, ComplexF64}(get_particle(x, type, index), one(ComplexF64))
|
||||
|
||||
"""
|
||||
gen_input_assignment_code(
|
||||
inputSymbols::Dict{String, Vector{Symbol}},
|
||||
instance::AbstractProblemInstance,
|
||||
processDescription::AbstractProcessDescription,
|
||||
machine::Machine,
|
||||
problemInputSymbol::Symbol = :data_input,
|
||||
processInputSymbol::Symbol = :input,
|
||||
)
|
||||
|
||||
Return a `Vector{Expr}` doing the input assignments from the given `problemInputSymbol` onto the `inputSymbols`.
|
||||
Return a `Vector{Expr}` doing the input assignments from the given `processInputSymbol` onto the `inputSymbols`.
|
||||
"""
|
||||
function gen_input_assignment_code(
|
||||
inputSymbols::Dict{String, Vector{Symbol}},
|
||||
instance,
|
||||
processDescription::AbstractProcessDescription,
|
||||
machine::Machine,
|
||||
problemInputSymbol::Symbol = :data_input,
|
||||
processInputSymbol::Symbol = :input,
|
||||
)
|
||||
assignInputs = Vector{Expr}()
|
||||
@assert length(inputSymbols) >=
|
||||
sum(values(in_particles(processDescription))) + sum(values(out_particles(processDescription))) "Number of input Symbols is smaller than the number of particles in the process description"
|
||||
|
||||
assignInputs = Vector{FunctionCall}()
|
||||
for (name, symbols) in inputSymbols
|
||||
(type, index) = type_index_from_name(model(processDescription), name)
|
||||
# make a function for this, since we can't use anonymous functions in the FunctionCall
|
||||
|
||||
for symbol in symbols
|
||||
device = entry_device(machine)
|
||||
push!(
|
||||
assignInputs,
|
||||
Meta.parse(
|
||||
"$(eval(gen_access_expr(device, symbol))) = $(input_expr(instance, name, problemInputSymbol))",
|
||||
FunctionCall(
|
||||
# x is the process input
|
||||
part_from_x,
|
||||
SVector{1, Symbol}(processInputSymbol),
|
||||
SVector{2, Any}(type, index),
|
||||
symbol,
|
||||
device,
|
||||
),
|
||||
)
|
||||
end
|
||||
@ -105,14 +127,14 @@ function gen_input_assignment_code(
|
||||
end
|
||||
|
||||
"""
|
||||
gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine, scheduler::AbstractScheduler = GreedyScheduler())
|
||||
gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
|
||||
Generate the code for a given graph. The return value is a [`Tape`](@ref).
|
||||
|
||||
See also: [`execute`](@ref), [`execute_tape`](@ref)
|
||||
"""
|
||||
function gen_tape(graph::DAG, instance, machine::Machine, scheduler::AbstractScheduler = GreedyScheduler())
|
||||
schedule = schedule_dag(scheduler, graph, machine)
|
||||
function gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
schedule = schedule_dag(GreedyScheduler(), graph, machine)
|
||||
|
||||
# get inSymbols
|
||||
inputSyms = Dict{String, Vector{Symbol}}()
|
||||
@ -128,24 +150,23 @@ function gen_tape(graph::DAG, instance, machine::Machine, scheduler::AbstractSch
|
||||
outSym = Symbol(to_var_name(get_exit_node(graph).id))
|
||||
|
||||
initCaches = gen_cache_init_code(machine)
|
||||
assignInputs = gen_input_assignment_code(inputSyms, instance, machine, :data_input)
|
||||
assignInputs = gen_input_assignment_code(inputSyms, process, machine, :input)
|
||||
|
||||
return Tape{input_type(instance)}(initCaches, assignInputs, schedule, inputSyms, outSym, Dict(), instance, machine)
|
||||
return Tape(initCaches, assignInputs, schedule, inputSyms, outSym, Dict(), process, machine)
|
||||
end
|
||||
|
||||
"""
|
||||
execute_tape(tape::Tape, input::Input) where {Input}
|
||||
execute_tape(tape::Tape, input::AbstractProcessInput)
|
||||
|
||||
Execute the given tape with the given input.
|
||||
|
||||
For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.
|
||||
"""
|
||||
function execute_tape(tape::Tape, input)
|
||||
function execute_tape(tape::Tape, input::AbstractProcessInput)
|
||||
cache = Dict{Symbol, Any}()
|
||||
cache[:data_input] = input
|
||||
cache[:input] = input
|
||||
# simply execute all the code snippets here
|
||||
@assert typeof(input) == input_type(tape.instance)
|
||||
# TODO: `@assert` that input fits the tape.instance
|
||||
# TODO: `@assert` that process input fits the tape.process
|
||||
for expr in tape.initCachesCode
|
||||
@eval $expr
|
||||
end
|
||||
|
@ -1,21 +1,19 @@
|
||||
|
||||
"""
|
||||
Tape{INPUT}
|
||||
Tape
|
||||
|
||||
TODO: update docs
|
||||
- `INPUT` the input type of the problem instance
|
||||
|
||||
- `code::Vector{Expr}`: The julia expression containing the code for the whole graph.
|
||||
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
|
||||
- `outputSymbol::Symbol`: The symbol of the final calculated value
|
||||
"""
|
||||
struct Tape{INPUT}
|
||||
struct Tape
|
||||
initCachesCode::Vector{Expr}
|
||||
inputAssignCode::Vector{Expr}
|
||||
inputAssignCode::Vector{FunctionCall}
|
||||
computeCode::Vector{FunctionCall}
|
||||
inputSymbols::Dict{String, Vector{Symbol}}
|
||||
outputSymbol::Symbol
|
||||
cache::Dict{Symbol, Any}
|
||||
instance::Any
|
||||
process::AbstractProcessDescription
|
||||
machine::Machine
|
||||
end
|
||||
|
@ -10,7 +10,8 @@ Representation of a [`DAG`](@ref)'s cost as estimated by the [`GlobalMetricEstim
|
||||
|
||||
|
||||
!!! note
|
||||
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
|
||||
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
|
||||
For example, for node fusions this will always be 0, since the computeEffort is zero.
|
||||
It will still work as intended when adding/subtracting to/from a `graph_cost` estimate.
|
||||
"""
|
||||
const CDCost = NamedTuple{(:data, :computeEffort, :computeIntensity), Tuple{Float64, Float64, Float64}}
|
||||
@ -54,6 +55,10 @@ function graph_cost(estimator::GlobalMetricEstimator, graph::DAG)
|
||||
)::CDCost
|
||||
end
|
||||
|
||||
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeFusion)
|
||||
return (data = -data(operation.input[2].task), computeEffort = 0.0, computeIntensity = 0.0)::CDCost
|
||||
end
|
||||
|
||||
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeReduction)
|
||||
s = length(operation.input) - 1
|
||||
return (
|
||||
|
@ -169,6 +169,66 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::FusedComputeTask, before, after)
|
||||
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
|
||||
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
|
||||
|
||||
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
|
||||
replace!(task.t1_inputs, before => after)
|
||||
replace!(task.t2_inputs, before => after)
|
||||
|
||||
# recursively descend down the tree, but only in the tasks where we're replacing things
|
||||
if replacedIn1 > 0
|
||||
replace_children!(task.first_task, before, after)
|
||||
end
|
||||
if replacedIn2 > 0
|
||||
replace_children!(task.second_task, before, after)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::AbstractTask, before, after)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
|
||||
# only need to update fused compute tasks
|
||||
if !(typeof(task(n)) <: FusedComputeTask)
|
||||
return nothing
|
||||
end
|
||||
|
||||
taskBefore = copy(task(n))
|
||||
|
||||
#=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
|
||||
println("------------------ Nothing to replace!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in children(n)
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end=#
|
||||
|
||||
replace_children!(task(n), child_before, child_after)
|
||||
|
||||
#=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
|
||||
println("------------------ Did not replace anything!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in children(n)
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end=#
|
||||
|
||||
# keep track
|
||||
if (track)
|
||||
push!(graph.diff.updatedChildren, (n, taskBefore))
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
get_snapshot_diff(graph::DAG)
|
||||
|
||||
@ -180,6 +240,31 @@ function get_snapshot_diff(graph::DAG)
|
||||
return swapfield!(graph, :diff, Diff())
|
||||
end
|
||||
|
||||
"""
|
||||
invalidate_caches!(graph::DAG, operation::NodeFusion)
|
||||
|
||||
Invalidate the operation caches for a given [`NodeFusion`](@ref).
|
||||
|
||||
This deletes the operation from the graph's possible operations and from the involved nodes' own operation caches.
|
||||
"""
|
||||
function invalidate_caches!(graph::DAG, operation::NodeFusion)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
for n in [1, 3]
|
||||
for i in eachindex(operation.input[n].nodeFusions)
|
||||
if operation == operation.input[n].nodeFusions[i]
|
||||
splice!(operation.input[n].nodeFusions, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
operation.input[2].nodeFusion = missing
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
invalidate_caches!(graph::DAG, operation::NodeReduction)
|
||||
|
||||
@ -226,6 +311,9 @@ function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
|
||||
if !ismissing(node.nodeSplit)
|
||||
invalidate_caches!(graph, node.nodeSplit)
|
||||
end
|
||||
while !isempty(node.nodeFusions)
|
||||
invalidate_caches!(graph, pop!(node.nodeFusions))
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
@ -241,5 +329,8 @@ function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
|
||||
if !ismissing(node.nodeSplit)
|
||||
invalidate_caches!(graph, node.nodeSplit)
|
||||
end
|
||||
if !ismissing(node.nodeFusion)
|
||||
invalidate_caches!(graph, node.nodeFusion)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
@ -7,6 +7,7 @@ A struct storing all possible operations on a [`DAG`](@ref).
|
||||
To get the [`PossibleOperations`](@ref) on a [`DAG`](@ref), use [`get_operations`](@ref).
|
||||
"""
|
||||
mutable struct PossibleOperations
|
||||
nodeFusions::Set{NodeFusion}
|
||||
nodeReductions::Set{NodeReduction}
|
||||
nodeSplits::Set{NodeSplit}
|
||||
end
|
||||
@ -51,7 +52,7 @@ end
|
||||
Construct and return an empty [`PossibleOperations`](@ref) object.
|
||||
"""
|
||||
function PossibleOperations()
|
||||
return PossibleOperations(Set{NodeReduction}(), Set{NodeSplit}())
|
||||
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -40,11 +40,19 @@ function is_valid(graph::DAG)
|
||||
for ns in graph.possibleOperations.nodeSplits
|
||||
@assert is_valid(graph, ns)
|
||||
end
|
||||
for nf in graph.possibleOperations.nodeFusions
|
||||
@assert is_valid(graph, nf)
|
||||
end
|
||||
|
||||
for node in graph.dirtyNodes
|
||||
@assert node in graph "Dirty Node is not part of the graph!"
|
||||
@assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!"
|
||||
@assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!"
|
||||
if (typeof(node) <: DataTaskNode)
|
||||
@assert ismissing(node.nodeFusion) "Dirty DataTaskNode has a Node Fusion!"
|
||||
elseif (typeof(node) <: ComputeTaskNode)
|
||||
@assert isempty(node.nodeFusions) "Dirty ComputeTaskNode has Node Fusions!"
|
||||
end
|
||||
end
|
||||
|
||||
@assert is_connected(graph) "Graph is not connected!"
|
||||
|
@ -1,17 +1,6 @@
|
||||
using AccurateArithmetic
|
||||
using StaticArrays
|
||||
|
||||
function input_expr(instance::ABCProcessDescription, name::String, psp_symbol::Symbol)
|
||||
(type, index) = type_index_from_name(ABCModel(), name)
|
||||
|
||||
return Meta.parse(
|
||||
"ABCParticleValue(
|
||||
$type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))),
|
||||
0.0im,
|
||||
)",
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskABC_P, data::ABCParticleValue)
|
||||
|
||||
@ -19,7 +8,7 @@ Return the particle and value as is.
|
||||
|
||||
0 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskABC_P, data::ABCParticleValue{P})::ABCParticleValue{P} where {P}
|
||||
function compute(::ComputeTaskABC_P, data::ABCParticleValue{P})::ABCParticleValue{P} where {P <: ABCParticle}
|
||||
return data
|
||||
end
|
||||
|
||||
@ -30,7 +19,7 @@ Compute an outer edge. Return the particle value with the same particle and the
|
||||
|
||||
1 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskABC_U, data::ABCParticleValue{P})::ABCParticleValue{P} where {P}
|
||||
function compute(::ComputeTaskABC_U, data::ABCParticleValue{P})::ABCParticleValue{P} where {P <: ABCParticle}
|
||||
return ABCParticleValue{P}(data.p, data.v * ABC_outer_edge(data.p))
|
||||
end
|
||||
|
||||
@ -45,7 +34,7 @@ function compute(
|
||||
::ComputeTaskABC_V,
|
||||
data1::ABCParticleValue{P1},
|
||||
data2::ABCParticleValue{P2},
|
||||
)::ABCParticleValue where {P1, P2}
|
||||
)::ABCParticleValue where {P1 <: ABCParticle, P2 <: ABCParticle}
|
||||
p3 = ABC_conserve_momentum(data1.p, data2.p)
|
||||
dataOut = ABCParticleValue{typeof(p3)}(p3, data1.v * ABC_vertex() * data2.v)
|
||||
return dataOut
|
||||
@ -60,7 +49,11 @@ For valid inputs, both input particles should have the same momenta at this poin
|
||||
|
||||
12 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskABC_S2, data1::ParticleValue{P}, data2::ParticleValue{P})::Float64 where {P}
|
||||
function compute(
|
||||
::ComputeTaskABC_S2,
|
||||
data1::ParticleValue{P},
|
||||
data2::ParticleValue{P},
|
||||
)::Float64 where {P <: ABCParticle}
|
||||
#=
|
||||
@assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
|
||||
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
|
||||
@ -92,6 +85,12 @@ Linearly many FLOP with growing data.
|
||||
"""
|
||||
function compute(::ComputeTaskABC_Sum, data...)::Float64
|
||||
return sum_kbn([data...])
|
||||
|
||||
#=s = 0.0im
|
||||
for d in data
|
||||
s += d
|
||||
end
|
||||
return s=#
|
||||
end
|
||||
|
||||
function compute(::ComputeTaskABC_Sum, data::AbstractArray)::Float64
|
@ -14,28 +14,34 @@ struct ABCModel <: AbstractPhysicsModel end
|
||||
|
||||
Base type for all particles in the [`ABCModel`](@ref).
|
||||
"""
|
||||
abstract type ABCParticle <: AbstractParticleType end
|
||||
abstract type ABCParticle <: AbstractParticle end
|
||||
|
||||
"""
|
||||
ParticleA <: ABCParticle
|
||||
|
||||
An 'A' particle in the ABC Model.
|
||||
"""
|
||||
struct ParticleA <: ABCParticle end
|
||||
struct ParticleA <: ABCParticle
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
"""
|
||||
ParticleB <: ABCParticle
|
||||
|
||||
A 'B' particle in the ABC Model.
|
||||
"""
|
||||
struct ParticleB <: ABCParticle end
|
||||
struct ParticleB <: ABCParticle
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
"""
|
||||
ParticleC <: ABCParticle
|
||||
|
||||
A 'C' particle in the ABC Model.
|
||||
"""
|
||||
struct ParticleC <: ABCParticle end
|
||||
struct ParticleC <: ABCParticle
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
"""
|
||||
ABCProcessDescription <: AbstractProcessDescription
|
||||
@ -66,7 +72,7 @@ struct ABCProcessInput{N1, N2, N3, N4, N5, N6} <: AbstractProcessInput
|
||||
outC::SVector{N6, ParticleC}
|
||||
end
|
||||
|
||||
ABCParticleValue{ParticleType} = ParticleValue{ParticleType, ComplexF64}
|
||||
ABCParticleValue{ParticleType <: ABCParticle} = ParticleValue{ParticleType, ComplexF64}
|
||||
|
||||
"""
|
||||
mass(t::Type{T}) where {T <: ABCParticle}
|
||||
@ -103,46 +109,39 @@ end
|
||||
Return a Vector of the possible types of particle in the [`ABCModel`](@ref).
|
||||
"""
|
||||
function types(::ABCModel)
|
||||
return [
|
||||
ParticleStateful{Incoming, ParticleA, SFourMomentum},
|
||||
ParticleStateful{Incoming, ParticleB, SFourMomentum},
|
||||
ParticleStateful{Incoming, ParticleC, SFourMomentum},
|
||||
ParticleStateful{Outgoing, ParticleA, SFourMomentum},
|
||||
ParticleStateful{Outgoing, ParticleB, SFourMomentum},
|
||||
ParticleStateful{Outgoing, ParticleC, SFourMomentum},
|
||||
]
|
||||
return [ParticleA, ParticleB, ParticleC]
|
||||
end
|
||||
|
||||
"""
|
||||
square(p::AbstractParticleStateful{Dir, ABCParticle})
|
||||
square(p::ABCParticle)
|
||||
|
||||
Return the square of the particle's momentum as a `Float` value.
|
||||
|
||||
Takes 7 effective FLOP.
|
||||
"""
|
||||
function square(p::AbstractParticleStateful{D, ABCParticle}) where {D}
|
||||
return getMass2(momentum(p))
|
||||
function square(p::ABCParticle)
|
||||
return getMass2(p.momentum)
|
||||
end
|
||||
|
||||
"""
|
||||
ABC_inner_edge(p::AbstractParticleStateful{Dir, ABCParticle})
|
||||
ABC_inner_edge(p::ABCParticle)
|
||||
|
||||
Return the factor of the inner edge with the given (virtual) particle.
|
||||
|
||||
Takes 10 effective FLOP. (3 here + 7 in square(p))
|
||||
"""
|
||||
function ABC_inner_edge(p::AbstractParticleStateful{D, ABCParticle}) where {D}
|
||||
return 1.0 / (square(p) - mass(particle(p))^2)
|
||||
function ABC_inner_edge(p::ABCParticle)
|
||||
return 1.0 / (square(p) - mass(p)^2)
|
||||
end
|
||||
|
||||
"""
|
||||
ABC_outer_edge(p::AbstractParticleStateful{Dir, ABCParticle})
|
||||
ABC_outer_edge(p::ABCParticle)
|
||||
|
||||
Return the factor of the outer edge with the given (real) particle.
|
||||
|
||||
Takes 0 effective FLOP.
|
||||
"""
|
||||
function ABC_outer_edge(::AbstractParticleStateful{D, ABCParticle}) where {D}
|
||||
function ABC_outer_edge(p::ABCParticle)
|
||||
return 1.0
|
||||
end
|
||||
|
||||
@ -180,26 +179,17 @@ model(::ABCProcessDescription) = ABCModel()
|
||||
model(::ABCProcessInput) = ABCModel()
|
||||
|
||||
function type_index_from_name(::ABCModel, name::String)
|
||||
if startswith(name, "Ai")
|
||||
return (ParticleStateful{Incoming, ParticleA, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Ao")
|
||||
return (ParticleStateful{Outgoing, ParticleA, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Bi")
|
||||
return (ParticleStateful{Incoming, ParticleB, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Bo")
|
||||
return (ParticleStateful{Outgoing, ParticleB, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Ci")
|
||||
return (ParticleStateful{Incoming, ParticleC, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "Co")
|
||||
return (ParticleStateful{Outgoing, ParticleC, SFourMomentum}, parse(Int, name[3:end]))
|
||||
if startswith(name, "A")
|
||||
return (ParticleA, parse(Int, name[2:end]))
|
||||
elseif startswith(name, "B")
|
||||
return (ParticleB, parse(Int, name[2:end]))
|
||||
elseif startswith(name, "C")
|
||||
return (ParticleC, parse(Int, name[2:end]))
|
||||
else
|
||||
throw("Invalid name for a particle in the ABC model")
|
||||
end
|
||||
end
|
||||
|
||||
function String(::Type{PS}) where {DIR, P <: ABCParticle, PS <: AbstractParticleStateful{DIR, P}}
|
||||
return String(P)
|
||||
end
|
||||
function String(::Type{ParticleA})
|
||||
return "A"
|
||||
end
|
@ -1,46 +1,120 @@
|
||||
import QEDbase.mass
|
||||
import QEDbase.AbstractParticle
|
||||
|
||||
"""
|
||||
AbstractModel
|
||||
AbstractPhysicsModel
|
||||
|
||||
Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed.
|
||||
|
||||
See also: [`problem_instance`](@ref)
|
||||
Base type for a model, e.g. ABC-Model or QED. This is used to dispatch many functions.
|
||||
"""
|
||||
abstract type AbstractModel end
|
||||
abstract type AbstractPhysicsModel end
|
||||
|
||||
"""
|
||||
problem_instance(::AbstractModel, ::Vararg)
|
||||
ParticleValue{ParticleType <: AbstractParticle}
|
||||
|
||||
Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters.
|
||||
A struct describing a particle during a calculation of a Feynman Diagram, together with the value that's being calculated. `AbstractParticle` is the type from the QEDbase package.
|
||||
|
||||
`sizeof(ParticleValue())` = 48 Byte
|
||||
"""
|
||||
function problem_instance end
|
||||
struct ParticleValue{ParticleType <: AbstractParticle, ValueType}
|
||||
p::ParticleType
|
||||
v::ValueType
|
||||
end
|
||||
|
||||
"""
|
||||
AbstractProblemInstance
|
||||
AbstractProcessDescription
|
||||
|
||||
Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model.
|
||||
Base type for process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a process in that model.
|
||||
|
||||
See also: [`parse_process`](@ref)
|
||||
"""
|
||||
abstract type AbstractProblemInstance end
|
||||
abstract type AbstractProcessDescription end
|
||||
|
||||
"""
|
||||
input_type(problem::AbstractProblemInstance)
|
||||
AbstractProcessInput
|
||||
|
||||
Return the fully specified input type for a specific [`AbstractProblemInstance`](@ref).
|
||||
Base type for process inputs. An object of this type contains the input values (e.g. momenta) of the particles in a process.
|
||||
|
||||
See also: [`gen_process_input`](@ref)
|
||||
"""
|
||||
function input_type end
|
||||
abstract type AbstractProcessInput end
|
||||
|
||||
"""
|
||||
graph(::AbstractProblemInstance)
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: AbstractParticle, T2 <: AbstractParticle}
|
||||
|
||||
Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names.
|
||||
Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the result particle type when the two given particles interact.
|
||||
"""
|
||||
function graph end
|
||||
function interaction_result end
|
||||
|
||||
"""
|
||||
input_expr(instance::AbstractProblemInstance, name::String, input_symbol::Symbol)
|
||||
types(::AbstractPhysicsModel)
|
||||
|
||||
For the given [`AbstractProblemInstance`](@ref), the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol.
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref), returning a `Vector` of the available particle types in the model.
|
||||
"""
|
||||
function input_expr end
|
||||
function types end
|
||||
|
||||
"""
|
||||
in_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of incoming particles for the process per particle type.
|
||||
|
||||
|
||||
in_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all incoming particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function in_particles end
|
||||
|
||||
"""
|
||||
out_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of outgoing particles for the process per particle type.
|
||||
|
||||
|
||||
out_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all outgoing particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function out_particles end
|
||||
|
||||
"""
|
||||
get_particle(::AbstractProcessInput, t::Type, n::Int)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns the `n`th particle of type `t`.
|
||||
"""
|
||||
function get_particle end
|
||||
|
||||
"""
|
||||
parse_process(::AbstractString, ::AbstractPhysicsModel)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref).
|
||||
Returns a `ProcessDescription` object.
|
||||
"""
|
||||
function parse_process end
|
||||
|
||||
"""
|
||||
gen_process_input(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every specific [`AbstractProcessDescription`](@ref).
|
||||
Returns a randomly generated and valid corresponding `ProcessInput`.
|
||||
"""
|
||||
function gen_process_input end
|
||||
|
||||
"""
|
||||
model(::AbstractProcessDescription)
|
||||
model(::AbstarctProcessInput)
|
||||
|
||||
Return the model of this process description or input.
|
||||
"""
|
||||
function model end
|
||||
|
||||
"""
|
||||
type_from_name(model::AbstractModel, name::String)
|
||||
|
||||
For a name of a particle in the given [`AbstractModel`](@ref), return the particle's [`Type`] and index as a tuple. The input string can be expetced to be of the form \"<name><index>\".
|
||||
"""
|
||||
function type_index_from_name end
|
||||
|
@ -1,3 +0,0 @@
|
||||
## Deprecation Warning
|
||||
|
||||
These models are deprecated and should not be used anymore. They will be dropped entirely soon.
|
@ -1,142 +0,0 @@
|
||||
import QEDbase.AbstractParticle
|
||||
|
||||
"""
|
||||
AbstractPhysicsModel
|
||||
|
||||
Base type for a model, e.g. ABC-Model or QED. This is used to dispatch many functions.
|
||||
"""
|
||||
abstract type AbstractPhysicsModel <: AbstractModel end
|
||||
|
||||
"""
|
||||
ParticleValue{ParticleType <: AbstractParticleStateful}
|
||||
|
||||
A struct describing a particle during a calculation of a Feynman Diagram, together with the value that's being calculated. `AbstractParticleStateful` is the type from the QEDbase package.
|
||||
|
||||
`sizeof(ParticleValue())` = 48 Byte
|
||||
"""
|
||||
struct ParticleValue{ParticleType <: AbstractParticleStateful, ValueType}
|
||||
p::ParticleType
|
||||
v::ValueType
|
||||
end
|
||||
|
||||
"""
|
||||
TBW
|
||||
|
||||
particle value + spin/pol info, only used on the external legs (u tasks)
|
||||
"""
|
||||
struct ParticleValueSP{ParticleType <: AbstractParticleStateful, SP <: AbstractSpinOrPolarization, ValueType}
|
||||
p::ParticleType
|
||||
v::ValueType
|
||||
sp::SP
|
||||
end
|
||||
|
||||
"""
|
||||
AbstractProcessDescription <: AbstractProblemInstance
|
||||
|
||||
Base type for particle scattering process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a scattering process in that model.
|
||||
|
||||
See also: [`parse_process`](@ref), [`AbstractProblemInstance`](@ref)
|
||||
"""
|
||||
abstract type AbstractProcessDescription end
|
||||
|
||||
|
||||
#TODO: i don't think giving this a base type is a good idea, the input type should just be returned of some function, allowing anything as an input type
|
||||
"""
|
||||
AbstractProcessInput
|
||||
|
||||
Base type for process inputs. An object of this type contains the input values (e.g. momenta) of the particles in a process.
|
||||
|
||||
See also: [`gen_process_input`](@ref)
|
||||
"""
|
||||
abstract type AbstractProcessInput end
|
||||
|
||||
"""
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: AbstractParticle, T2 <: AbstractParticle}
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the result particle type when the two given particles interact.
|
||||
"""
|
||||
function interaction_result end
|
||||
|
||||
"""
|
||||
types(::AbstractPhysicsModel)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref), returning a `Vector` of the available particle types in the model.
|
||||
"""
|
||||
function types end
|
||||
|
||||
"""
|
||||
in_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of incoming particles for the process per particle type.
|
||||
|
||||
|
||||
in_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all incoming particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function in_particles end
|
||||
|
||||
"""
|
||||
out_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of outgoing particles for the process per particle type.
|
||||
|
||||
|
||||
out_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all outgoing particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function out_particles end
|
||||
|
||||
"""
|
||||
get_particle(::AbstractProcessInput, t::Type, n::Int)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns the `n`th particle of type `t`.
|
||||
"""
|
||||
function get_particle end
|
||||
|
||||
"""
|
||||
parse_process(::AbstractString, ::AbstractPhysicsModel)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref).
|
||||
Returns a `ProcessDescription` object.
|
||||
"""
|
||||
function parse_process end
|
||||
|
||||
"""
|
||||
gen_process_input(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every specific [`AbstractProcessDescription`](@ref).
|
||||
Returns a randomly generated and valid corresponding `ProcessInput`.
|
||||
"""
|
||||
function gen_process_input end
|
||||
|
||||
"""
|
||||
model(::AbstractProcessDescription)
|
||||
model(::AbstractProcessInput)
|
||||
|
||||
Return the model of this process description or input.
|
||||
"""
|
||||
function model end
|
||||
|
||||
"""
|
||||
type_from_name(model::AbstractModel, name::String)
|
||||
|
||||
For a name of a particle in the given [`AbstractModel`](@ref), return the particle's [`Type`] and index as a tuple. The input string can be expetced to be of the form \"<name><index>\".
|
||||
"""
|
||||
function type_index_from_name end
|
||||
|
||||
"""
|
||||
part_from_x(type::Type, index::Int, x::AbstractProcessInput)
|
||||
|
||||
Return the [`ParticleValue`](@ref) of the given type of particle with the given `index` from the given process input.
|
||||
|
||||
Function is wrapped into a [`FunctionCall`](@ref) in [`gen_input_assignment_code`](@ref).
|
||||
"""
|
||||
part_from_x(type::Type, index::Int, x::AbstractProcessInput) =
|
||||
ParticleValue{type, ComplexF64}(get_particle(x, type, index), one(ComplexF64))
|
@ -1,45 +0,0 @@
|
||||
"""
|
||||
parse_process(string::AbstractString, model::QEDModel)
|
||||
|
||||
Parse a string representation of a process, such as "ke->ke" into the corresponding [`QEDProcessDescription`](@ref).
|
||||
"""
|
||||
function parse_process(
|
||||
str::AbstractString,
|
||||
model::QEDModel,
|
||||
inphpol::AbstractDefinitePolarization = PolX(),
|
||||
inelspin::AbstractDefiniteSpin = SpinUp(),
|
||||
outphpol::AbstractDefinitePolarization = PolX(),
|
||||
outelspin::AbstractDefiniteSpin = SpinUp(),
|
||||
)
|
||||
if !(contains(str, "->"))
|
||||
throw("Did not find -> while parsing process \"$str\"")
|
||||
end
|
||||
|
||||
(in_str, out_str) = split(str, "->")
|
||||
|
||||
if (isempty(in_str) || isempty(out_str))
|
||||
throw("Process (\"$str\") input or output part is empty!")
|
||||
end
|
||||
|
||||
in_particles = Vector{AbstractParticleType}()
|
||||
out_particles = Vector{AbstractParticleType}()
|
||||
|
||||
for (particle_vector, s) in ((in_particles, in_str), (out_particles, out_str))
|
||||
for c in s
|
||||
if c == 'e'
|
||||
push!(particle_vector, Electron())
|
||||
elseif c == 'p'
|
||||
push!(particle_vector, Positron())
|
||||
elseif c == 'k'
|
||||
push!(particle_vector, Photon())
|
||||
else
|
||||
throw("Encountered unknown characters in the process \"$str\"")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
in_spin_pols = tuple([is_boson(in_particles[i]) ? inphpol : inelspin for i in eachindex(in_particles)]...)
|
||||
out_spin_pols = tuple([is_boson(out_particles[i]) ? outphpol : outelspin for i in eachindex(out_particles)]...)
|
||||
|
||||
return GenericQEDProcess(tuple(in_particles...), tuple(out_particles...), in_spin_pols, out_spin_pols)
|
||||
end
|
@ -1,305 +0,0 @@
|
||||
using QEDprocesses
|
||||
using StaticArrays
|
||||
|
||||
const e = sqrt(4π / 137)
|
||||
|
||||
QEDbase.is_incoming(::Type{<:ParticleStateful{Incoming}}) = true
|
||||
QEDbase.is_outgoing(::Type{<:ParticleStateful{Outgoing}}) = true
|
||||
QEDbase.is_incoming(::Type{<:ParticleStateful{Outgoing}}) = false
|
||||
QEDbase.is_outgoing(::Type{<:ParticleStateful{Incoming}}) = false
|
||||
|
||||
QEDbase.particle_direction(::Type{<:ParticleStateful{DIR}}) where {DIR <: ParticleDirection} = DIR()
|
||||
QEDbase.particle_species(
|
||||
::Type{<:ParticleStateful{DIR, SPECIES}},
|
||||
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType} = SPECIES()
|
||||
|
||||
function spin_or_pol(
|
||||
process::GenericQEDProcess,
|
||||
type::Type{ParticleStateful{DIR, SPECIES, EL}},
|
||||
n::Int,
|
||||
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType, EL <: AbstractFourMomentum}
|
||||
i = 0
|
||||
c = n
|
||||
for p in particles(process, DIR())
|
||||
i += 1
|
||||
if p == SPECIES()
|
||||
c -= 1
|
||||
end
|
||||
if c == 0
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if c != 0 || n <= 0
|
||||
throw(InvalidInputError("could not get $n-th spin/pol of $(DIR()) $species, does not exist"))
|
||||
end
|
||||
|
||||
if DIR <: Incoming
|
||||
return process.incoming_spins_pols[i]
|
||||
elseif DIR <: Outgoing
|
||||
return process.outgoing_spins_pols[i]
|
||||
else
|
||||
throw(InvalidInputError("unknown direction $(DIR()) given"))
|
||||
end
|
||||
end
|
||||
|
||||
function input_type(p::GenericQEDProcess)
|
||||
in_t = QEDcore._assemble_tuple_type(incoming_particles(p), Incoming(), SFourMomentum)
|
||||
out_t = QEDcore._assemble_tuple_type(outgoing_particles(p), Outgoing(), SFourMomentum)
|
||||
return PhaseSpacePoint{
|
||||
typeof(p),
|
||||
PerturbativeQED,
|
||||
PhasespaceDefinition{SphericalCoordinateSystem, ElectronRestFrame},
|
||||
Tuple{in_t...},
|
||||
Tuple{out_t...},
|
||||
}
|
||||
end
|
||||
|
||||
ValueType = Union{BiSpinor, AdjointBiSpinor, DiracMatrix, SLorentzVector{Float64}, ComplexF64}
|
||||
|
||||
"""
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
|
||||
|
||||
For two given particle types that can interact, return the third.
|
||||
"""
|
||||
function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ParticleStateful, T2 <: ParticleStateful}
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
end
|
||||
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Incoming, Electron, EL}},
|
||||
::Type{ParticleStateful{Outgoing, Electron, EL}},
|
||||
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Incoming, Electron, EL}},
|
||||
::Type{ParticleStateful{Incoming, Positron, EL}},
|
||||
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Incoming, Electron, EL}},
|
||||
::Type{ParticleStateful{DIR, Photon, EL}},
|
||||
) where {EL <: AbstractFourMomentum, DIR <: ParticleDirection} = ParticleStateful{Outgoing, Electron, EL}
|
||||
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Outgoing, Electron, EL}},
|
||||
::Type{ParticleStateful{Incoming, Electron, EL}},
|
||||
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Outgoing, Electron, EL}},
|
||||
::Type{ParticleStateful{Outgoing, Positron, EL}},
|
||||
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Outgoing, Electron, EL}},
|
||||
::Type{ParticleStateful{DIR, Photon, EL}},
|
||||
) where {EL <: AbstractFourMomentum, DIR <: ParticleDirection} = ParticleStateful{Incoming, Electron, EL}
|
||||
|
||||
# antifermion mirror
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Incoming, Positron, EL}},
|
||||
t2::Type{<:ParticleStateful},
|
||||
) where {EL <: AbstractFourMomentum} = interaction_result(ParticleStateful{Outgoing, Electron, EL}, t2)
|
||||
interaction_result(
|
||||
::Type{ParticleStateful{Outgoing, Positron, EL}},
|
||||
t2::Type{<:ParticleStateful},
|
||||
) where {EL <: AbstractFourMomentum} = interaction_result(ParticleStateful{Incoming, Electron, EL}, t2)
|
||||
|
||||
# photon commutativity
|
||||
interaction_result(
|
||||
t1::Type{ParticleStateful{DIR, Photon, EL}},
|
||||
t2::Type{<:ParticleStateful},
|
||||
) where {EL <: AbstractFourMomentum, DIR <: ParticleDirection} = interaction_result(t2, t1)
|
||||
|
||||
# but prevent stack overflow
|
||||
function interaction_result(
|
||||
t1::Type{ParticleStateful{DIR1, Photon, EL}},
|
||||
t2::Type{ParticleStateful{DIR2, Photon, EL}},
|
||||
) where {DIR1 <: ParticleDirection, DIR2 <: ParticleDirection, EL <: AbstractFourMomentum}
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
end
|
||||
|
||||
"""
|
||||
propagation_result(t1::Type{T}) where {T <: QEDParticle}
|
||||
|
||||
Return the type of the inverted direction. E.g.
|
||||
"""
|
||||
propagation_result(::Type{ParticleStateful{Incoming, Electron, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Outgoing, Electron, EL}
|
||||
propagation_result(::Type{ParticleStateful{Outgoing, Electron, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Incoming, Electron, EL}
|
||||
propagation_result(::Type{ParticleStateful{Incoming, Positron, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Outgoing, Positron, EL}
|
||||
propagation_result(::Type{ParticleStateful{Outgoing, Positron, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Incoming, Positron, EL}
|
||||
propagation_result(::Type{ParticleStateful{Incoming, Photon, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Outgoing, Photon, EL}
|
||||
propagation_result(::Type{ParticleStateful{Outgoing, Photon, EL}}) where {EL <: AbstractFourMomentum} =
|
||||
ParticleStateful{Incoming, Photon, EL}
|
||||
|
||||
"""
|
||||
types(::QEDModel)
|
||||
|
||||
Return a Vector of the possible types of particle in the [`QEDModel`](@ref).
|
||||
"""
|
||||
function types(::QEDModel)
|
||||
return [
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum},
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum},
|
||||
ParticleStateful{Incoming, Positron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Positron, SFourMomentum},
|
||||
]
|
||||
end
|
||||
|
||||
# type piracy?
|
||||
String(::Type{Incoming}) = "Incoming"
|
||||
String(::Type{Outgoing}) = "Outgoing"
|
||||
|
||||
String(::Type{PolX}) = "polx"
|
||||
String(::Type{PolY}) = "poly"
|
||||
|
||||
String(::Type{SpinUp}) = "spinup"
|
||||
String(::Type{SpinDown}) = "spindown"
|
||||
|
||||
String(::Incoming) = "i"
|
||||
String(::Outgoing) = "o"
|
||||
|
||||
function String(::Type{<:ParticleStateful{DIR, Photon}}) where {DIR <: ParticleDirection}
|
||||
return "k"
|
||||
end
|
||||
function String(::Type{<:ParticleStateful{DIR, Electron}}) where {DIR <: ParticleDirection}
|
||||
return "e"
|
||||
end
|
||||
function String(::Type{<:ParticleStateful{DIR, Positron}}) where {DIR <: ParticleDirection}
|
||||
return "p"
|
||||
end
|
||||
|
||||
"""
|
||||
caninteract(T1::Type{<:ParticleStateful}, T2::Type{<:ParticleStateful})
|
||||
|
||||
For two given `ParticleStateful` types, return whether they can interact at a vertex. This is equivalent to `!issame(T1, T2)`.
|
||||
|
||||
See also: [`issame`](@ref) and [`interaction_result`](@ref)
|
||||
"""
|
||||
function caninteract(
|
||||
T1::Type{<:ParticleStateful{D1, S1}},
|
||||
T2::Type{<:ParticleStateful{D2, S2}},
|
||||
) where {D1 <: ParticleDirection, S1 <: AbstractParticleType, D2 <: ParticleDirection, S2 <: AbstractParticleType}
|
||||
if (T1 == T2)
|
||||
return false
|
||||
end
|
||||
if (S1 == Photon && S2 == Photon)
|
||||
return false
|
||||
end
|
||||
|
||||
for (P1, P2) in [(T1, T2), (T2, T1)]
|
||||
if (P1 <: ParticleStateful{Incoming, Electron} && P2 <: ParticleStateful{Outgoing, Positron})
|
||||
return false
|
||||
end
|
||||
if (P1 <: ParticleStateful{Outgoing, Electron} && P2 <: ParticleStateful{Incoming, Positron})
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function type_index_from_name(::QEDModel, name::String)
|
||||
if startswith(name, "ki")
|
||||
return (ParticleStateful{Incoming, Photon, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "ko")
|
||||
return (ParticleStateful{Outgoing, Photon, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "ei")
|
||||
return (ParticleStateful{Incoming, Electron, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "eo")
|
||||
return (ParticleStateful{Outgoing, Electron, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "pi")
|
||||
return (ParticleStateful{Incoming, Positron, SFourMomentum}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "po")
|
||||
return (ParticleStateful{Outgoing, Positron, SFourMomentum}, parse(Int, name[3:end]))
|
||||
else
|
||||
throw("Invalid name for a particle in the QED model")
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
issame(T1::Type{<:ParticleStateful}, T2::Type{<:ParticleStateful})
|
||||
|
||||
For two given `ParticleStateful` types, return whether they are equivalent for the purpose of a Feynman Diagram. That means e.g. an `Incoming` `AntiFermion` is the same as an `Outgoing` `Fermion`. This is equivalent to `!caninteract(T1, T2)`.
|
||||
|
||||
See also: [`caninteract`](@ref) and [`interaction_result`](@ref)
|
||||
"""
|
||||
function issame(T1::Type{<:ParticleStateful}, T2::Type{<:ParticleStateful})
|
||||
return !caninteract(T1, T2)
|
||||
end
|
||||
|
||||
"""
|
||||
QED_vertex()
|
||||
|
||||
Return the factor of a vertex in a QED feynman diagram.
|
||||
"""
|
||||
@inline function QED_vertex()::SLorentzVector{DiracMatrix}
|
||||
# Peskin-Schroeder notation
|
||||
return -1im * e * gamma()
|
||||
end
|
||||
|
||||
@inline function QED_inner_edge(p::ParticleStateful)
|
||||
return propagator(particle_species(p), momentum(p))
|
||||
end
|
||||
|
||||
"""
|
||||
QED_conserve_momentum(p1::ParticleStateful, p2::ParticleStateful)
|
||||
|
||||
Calculate and return a new particle from two given interacting ones at a vertex.
|
||||
"""
|
||||
function QED_conserve_momentum(
|
||||
p1::P1,
|
||||
p2::P2,
|
||||
) where {
|
||||
Dir1 <: ParticleDirection,
|
||||
Dir2 <: ParticleDirection,
|
||||
Species1 <: AbstractParticleType,
|
||||
Species2 <: AbstractParticleType,
|
||||
P1 <: ParticleStateful{Dir1, Species1},
|
||||
P2 <: ParticleStateful{Dir2, Species2},
|
||||
}
|
||||
P3 = interaction_result(P1, P2)
|
||||
p1_mom = momentum(p1)
|
||||
if (Dir1 <: Outgoing)
|
||||
p1_mom *= -1
|
||||
end
|
||||
p2_mom = momentum(p2)
|
||||
if (Dir2 <: Outgoing)
|
||||
p2_mom *= -1
|
||||
end
|
||||
|
||||
p3_mom = p1_mom + p2_mom
|
||||
if (particle_direction(P3) isa Incoming)
|
||||
return ParticleStateful(particle_direction(P3), particle_species(P3), -p3_mom)
|
||||
end
|
||||
return ParticleStateful(particle_direction(P3), particle_species(P3), p3_mom)
|
||||
end
|
||||
|
||||
"""
|
||||
model(::AbstractProcessDescription)
|
||||
|
||||
Return the model of this process description.
|
||||
"""
|
||||
model(::GenericQEDProcess) = QEDModel()
|
||||
model(::PhaseSpacePoint) = QEDModel()
|
||||
|
||||
function get_particle(
|
||||
input::PhaseSpacePoint,
|
||||
t::Type{ParticleStateful{DIR, SPECIES}},
|
||||
n::Int,
|
||||
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType}
|
||||
i = 0
|
||||
for p in particles(input, DIR())
|
||||
if p isa t
|
||||
i += 1
|
||||
if i == n
|
||||
return p
|
||||
end
|
||||
end
|
||||
end
|
||||
@assert false "Invalid type given"
|
||||
end
|
@ -1,14 +0,0 @@
|
||||
using QEDprocesses
|
||||
|
||||
# add type overload for number_particles function
|
||||
@inline function QEDprocesses.number_particles(
|
||||
proc_def::QEDbase.AbstractProcessDefinition,
|
||||
::Type{PS},
|
||||
) where {
|
||||
DIR <: QEDbase.ParticleDirection,
|
||||
PT <: QEDbase.AbstractParticleType,
|
||||
EL <: AbstractFourMomentum,
|
||||
PS <: ParticleStateful{DIR, PT, EL},
|
||||
}
|
||||
return QEDprocesses.number_particles(proc_def, DIR(), PT())
|
||||
end
|
@ -0,0 +1,10 @@
|
||||
|
||||
"""
|
||||
show(io::IO, particleValue::ParticleValue)
|
||||
|
||||
Pretty print a [`ParticleValue`](@ref), no newlines.
|
||||
"""
|
||||
function show(io::IO, particleValue::ParticleValue)
|
||||
print(io, "($(particleValue.p), value: $(particleValue.v))")
|
||||
return nothing
|
||||
end
|
@ -1,40 +1,23 @@
|
||||
using StaticArrays
|
||||
|
||||
construction_string(::Incoming) = "Incoming()"
|
||||
construction_string(::Outgoing) = "Outgoing()"
|
||||
"""
|
||||
compute(::ComputeTaskQED_P, data::QEDParticleValue)
|
||||
|
||||
construction_string(::Electron) = "Electron()"
|
||||
construction_string(::Positron) = "Positron()"
|
||||
construction_string(::Photon) = "Photon()"
|
||||
|
||||
construction_string(::PolX) = "PolX()"
|
||||
construction_string(::PolY) = "PolY()"
|
||||
construction_string(::SpinUp) = "SpinUp()"
|
||||
construction_string(::SpinDown) = "SpinDown()"
|
||||
|
||||
function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbol)
|
||||
(type, index) = type_index_from_name(QEDModel(), name)
|
||||
|
||||
return Meta.parse(
|
||||
"ParticleValueSP(
|
||||
$type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))),
|
||||
0.0im,
|
||||
$(construction_string(spin_or_pol(instance, type, index))),
|
||||
)",
|
||||
)
|
||||
Return the particle as is and initialize the Value.
|
||||
"""
|
||||
function compute(::ComputeTaskQED_P, data::QEDParticleValue{P}) where {P <: QEDParticle}
|
||||
# TODO do we actually need this for anything?
|
||||
return ParticleValue{P, DiracMatrix}(data.p, one(DiracMatrix))
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskQED_U, data::ParticleValueSP)
|
||||
compute(::ComputeTaskQED_U, data::QEDParticleValue)
|
||||
|
||||
Compute an outer edge. Return the particle value with the same particle and the value multiplied by an outer_edge factor.
|
||||
"""
|
||||
function compute(
|
||||
::ComputeTaskQED_U,
|
||||
data::ParticleValueSP{P, SP, V},
|
||||
) where {P <: ParticleStateful, V <: ValueType, SP <: AbstractSpinOrPolarization}
|
||||
function compute(::ComputeTaskQED_U, data::PV) where {P <: QEDParticle, PV <: QEDParticleValue{P}}
|
||||
part::P = data.p
|
||||
state = base_state(particle_species(part), particle_direction(part), momentum(part), SP())
|
||||
state = base_state(particle(part), direction(part), momentum(part), spin_or_pol(part))
|
||||
return ParticleValue{P, typeof(state)}(
|
||||
data.p,
|
||||
state, # will return a SLorentzVector{ComplexF64}, BiSpinor or AdjointBiSpinor
|
||||
@ -42,15 +25,15 @@ function compute(
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskQED_V, data1::ParticleValue, data2::ParticleValue)
|
||||
compute(::ComputeTaskQED_V, data1::QEDParticleValue, data2::QEDParticleValue)
|
||||
|
||||
Compute a vertex. Preserve momentum and particle types (e + gamma->p etc.) to create resulting particle, multiply values together and times a vertex factor.
|
||||
"""
|
||||
function compute(
|
||||
::ComputeTaskQED_V,
|
||||
data1::ParticleValue{P1, V1},
|
||||
data2::ParticleValue{P2, V2},
|
||||
) where {P1 <: ParticleStateful, P2 <: ParticleStateful, V1 <: ValueType, V2 <: ValueType}
|
||||
data1::PV1,
|
||||
data2::PV2,
|
||||
) where {P1 <: QEDParticle, P2 <: QEDParticle, PV1 <: QEDParticleValue{P1}, PV2 <: QEDParticleValue{P2}}
|
||||
p3 = QED_conserve_momentum(data1.p, data2.p)
|
||||
P3 = interaction_result(P1, P2)
|
||||
state = QED_vertex()
|
||||
@ -70,7 +53,7 @@ function compute(
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskQED_S2, data1::ParticleValue, data2::ParticleValue)
|
||||
compute(::ComputeTaskQED_S2, data1::QEDParticleValue, data2::QEDParticleValue)
|
||||
|
||||
Compute a final inner edge (2 input particles, no output particle).
|
||||
|
||||
@ -80,19 +63,9 @@ For valid inputs, both input particles should have the same momenta at this poin
|
||||
"""
|
||||
function compute(
|
||||
::ComputeTaskQED_S2,
|
||||
data1::ParticleValue{P1, V1},
|
||||
data2::ParticleValue{P2, V2},
|
||||
) where {
|
||||
D1 <: ParticleDirection,
|
||||
D2 <: ParticleDirection,
|
||||
S1 <: Union{Electron, Positron},
|
||||
S2 <: Union{Electron, Positron},
|
||||
V1 <: ValueType,
|
||||
V2 <: ValueType,
|
||||
EL <: AbstractFourMomentum,
|
||||
P1 <: ParticleStateful{D1, S1, EL},
|
||||
P2 <: ParticleStateful{D2, S2, EL},
|
||||
}
|
||||
data1::ParticleValue{P1},
|
||||
data2::ParticleValue{P2},
|
||||
) where {P1 <: Union{AntiFermionStateful, FermionStateful}, P2 <: Union{AntiFermionStateful, FermionStateful}}
|
||||
#@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
|
||||
|
||||
inner = QED_inner_edge(propagation_result(P1)(momentum(data1.p)))
|
||||
@ -107,9 +80,9 @@ end
|
||||
|
||||
function compute(
|
||||
::ComputeTaskQED_S2,
|
||||
data1::ParticleValue{ParticleStateful{D1, Photon}, V1},
|
||||
data2::ParticleValue{ParticleStateful{D2, Photon}, V2},
|
||||
) where {D1 <: ParticleDirection, D2 <: ParticleDirection, V1 <: ValueType, V2 <: ValueType}
|
||||
data1::ParticleValue{P1},
|
||||
data2::ParticleValue{P2},
|
||||
) where {P1 <: PhotonStateful, P2 <: PhotonStateful}
|
||||
# TODO: assert that data1 and data2 are opposites
|
||||
inner = QED_inner_edge(data1.p)
|
||||
# inner edge is just a scalar, data1 and data2 are photon states that are just Complex numbers here
|
||||
@ -117,11 +90,11 @@ function compute(
|
||||
end
|
||||
|
||||
"""
|
||||
compute(::ComputeTaskQED_S1, data::ParticleValue)
|
||||
compute(::ComputeTaskQED_S1, data::QEDParticleValue)
|
||||
|
||||
Compute inner edge (1 input particle, 1 output particle).
|
||||
"""
|
||||
function compute(::ComputeTaskQED_S1, data::ParticleValue{P, V}) where {P <: ParticleStateful, V <: ValueType}
|
||||
function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P}) where {P <: QEDParticle}
|
||||
newP = propagation_result(P)
|
||||
new_p = newP(momentum(data.p))
|
||||
# inner edge is just a scalar, can multiply from either side
|
@ -1,58 +1,83 @@
|
||||
|
||||
ComputeTaskQED_Sum() = ComputeTaskQED_Sum(0)
|
||||
|
||||
function _svector_from_type(processDescription::GenericQEDProcess, type, particles)
|
||||
function _svector_from_type(processDescription::QEDProcessDescription, type, particles)
|
||||
if haskey(in_particles(processDescription), type)
|
||||
return SVector{in_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
|
||||
end
|
||||
if haskey(out_particles(processDescription), type)
|
||||
return SVector{out_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
|
||||
end
|
||||
return SVector{0, type}()
|
||||
end
|
||||
|
||||
"""
|
||||
gen_process_input(processDescription::GenericQEDProcess)
|
||||
gen_process_input(processDescription::QEDProcessDescription)
|
||||
|
||||
Return a ProcessInput of randomly generated [`QEDParticle`](@ref)s from a [`GenericQEDProcess`](@ref). The process description can be created manually or parsed from a string using [`parse_process`](@ref).
|
||||
Return a ProcessInput of randomly generated [`QEDParticle`](@ref)s from a [`QEDProcessDescription`](@ref). The process description can be created manually or parsed from a string using [`parse_process`](@ref).
|
||||
|
||||
Note: This uses RAMBO to create a valid process with conservation of momentum and energy.
|
||||
"""
|
||||
function gen_process_input(processDescription::GenericQEDProcess)
|
||||
function gen_process_input(processDescription::QEDProcessDescription)
|
||||
massSum = 0
|
||||
inputMasses = Vector{Float64}()
|
||||
for particle in incoming_particles(processDescription)
|
||||
massSum += mass(particle)
|
||||
push!(inputMasses, mass(particle))
|
||||
for (particle, n) in processDescription.inParticles
|
||||
for _ in 1:n
|
||||
massSum += mass(particle)
|
||||
push!(inputMasses, mass(particle))
|
||||
end
|
||||
end
|
||||
outputMasses = Vector{Float64}()
|
||||
for particle in outgoing_particles(processDescription)
|
||||
massSum += mass(particle)
|
||||
push!(outputMasses, mass(particle))
|
||||
for (particle, n) in processDescription.outParticles
|
||||
for _ in 1:n
|
||||
massSum += mass(particle)
|
||||
push!(outputMasses, mass(particle))
|
||||
end
|
||||
end
|
||||
|
||||
# add some extra random mass to allow for some momentum
|
||||
massSum += rand(rng[threadid()]) * (length(inputMasses) + length(outputMasses))
|
||||
|
||||
initial_momenta = generate_initial_moms(massSum, inputMasses)
|
||||
final_momenta = generate_physical_massive_moms(rng[threadid()], massSum, outputMasses)
|
||||
|
||||
processInput = PhaseSpacePoint(
|
||||
processDescription,
|
||||
PerturbativeQED(),
|
||||
PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()),
|
||||
tuple(initial_momenta...),
|
||||
tuple(final_momenta...),
|
||||
)
|
||||
particles = Vector{QEDParticle}()
|
||||
initialMomenta = generate_initial_moms(massSum, inputMasses)
|
||||
index = 1
|
||||
for (particle, n) in processDescription.inParticles
|
||||
for _ in 1:n
|
||||
mom = initialMomenta[index]
|
||||
push!(particles, particle(mom))
|
||||
index += 1
|
||||
end
|
||||
end
|
||||
|
||||
final_momenta = generate_physical_massive_moms(rng[threadid()], massSum, outputMasses)
|
||||
index = 1
|
||||
for (particle, n) in processDescription.outParticles
|
||||
for _ in 1:n
|
||||
push!(particles, particle(final_momenta[index]))
|
||||
index += 1
|
||||
end
|
||||
end
|
||||
|
||||
inFerms = _svector_from_type(processDescription, FermionStateful{Incoming, SpinUp}, particles)
|
||||
outFerms = _svector_from_type(processDescription, FermionStateful{Outgoing, SpinUp}, particles)
|
||||
inAntiferms = _svector_from_type(processDescription, AntiFermionStateful{Incoming, SpinUp}, particles)
|
||||
outAntiferms = _svector_from_type(processDescription, AntiFermionStateful{Outgoing, SpinUp}, particles)
|
||||
inPhotons = _svector_from_type(processDescription, PhotonStateful{Incoming, PolX}, particles)
|
||||
outPhotons = _svector_from_type(processDescription, PhotonStateful{Outgoing, PolX}, particles)
|
||||
|
||||
processInput =
|
||||
QEDProcessInput(processDescription, inFerms, outFerms, inAntiferms, outAntiferms, inPhotons, outPhotons)
|
||||
|
||||
return processInput
|
||||
end
|
||||
|
||||
"""
|
||||
gen_graph(process_description::GenericQEDProcess)
|
||||
gen_graph(process_description::QEDProcessDescription)
|
||||
|
||||
For a given [`GenericQEDProcess`](@ref), return the [`DAG`](@ref) that computes it.
|
||||
For a given [`QEDProcessDescription`](@ref), return the [`DAG`](@ref) that computes it.
|
||||
"""
|
||||
function gen_graph(process_description::GenericQEDProcess)
|
||||
function gen_graph(process_description::QEDProcessDescription)
|
||||
initial_diagram = FeynmanDiagram(process_description)
|
||||
diagrams = gen_diagrams(initial_diagram)
|
||||
|
||||
@ -63,9 +88,9 @@ function gen_graph(process_description::GenericQEDProcess)
|
||||
|
||||
# TODO: Not all diagram outputs should always be summed at the end, if they differ by fermion exchange they need to be diffed
|
||||
# Should not matter for n-Photon Compton processes though
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(0)); track = false, invalidate_cache = false)
|
||||
global_data_out = insert_node!(graph, make_node(DataTask(COMPLEX_SIZE)); track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, sum_node, global_data_out; track = false, invalidate_cache = false)
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(0)), track = false, invalidate_cache = false)
|
||||
global_data_out = insert_node!(graph, make_node(DataTask(COMPLEX_SIZE)), track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, sum_node, global_data_out, track = false, invalidate_cache = false)
|
||||
|
||||
# remember the data out nodes for connection
|
||||
dataOutNodes = Dict()
|
||||
@ -74,16 +99,16 @@ function gen_graph(process_description::GenericQEDProcess)
|
||||
# generate data in and U tasks
|
||||
data_in = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE), String(particle));
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE), String(particle)),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
) # read particle data node
|
||||
compute_u = insert_node!(graph, make_node(ComputeTaskQED_U()); track = false, invalidate_cache = false) # compute U node
|
||||
compute_u = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false, invalidate_cache = false) # compute U node
|
||||
data_out =
|
||||
insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)); track = false, invalidate_cache = false) # transfer data out from u (one ABCParticleValue object)
|
||||
insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false) # transfer data out from u (one ABCParticleValue object)
|
||||
|
||||
insert_edge!(graph, data_in, compute_u; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, compute_u, data_out; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_in, compute_u, track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, compute_u, data_out, track = false, invalidate_cache = false)
|
||||
|
||||
# remember the data_out node for future edges
|
||||
dataOutNodes[String(particle)] = data_out
|
||||
@ -99,19 +124,19 @@ function gen_graph(process_description::GenericQEDProcess)
|
||||
data_in1 = dataOutNodes[String(vertex.in1)]
|
||||
data_in2 = dataOutNodes[String(vertex.in2)]
|
||||
|
||||
compute_V = insert_node!(graph, make_node(ComputeTaskQED_V()); track = false, invalidate_cache = false) # compute vertex
|
||||
compute_V = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false, invalidate_cache = false) # compute vertex
|
||||
|
||||
insert_edge!(graph, data_in1, compute_V; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_in2, compute_V; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_in1, compute_V, track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_in2, compute_V, track = false, invalidate_cache = false)
|
||||
|
||||
data_V_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE));
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE)),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, compute_V, data_V_out; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, compute_V, data_V_out, track = false, invalidate_cache = false)
|
||||
|
||||
if (vertex.out == tie.in1 || vertex.out == tie.in2)
|
||||
# out particle is part of the tie -> there will be an S2 task with it later, don't make S1 task
|
||||
@ -121,18 +146,18 @@ function gen_graph(process_description::GenericQEDProcess)
|
||||
|
||||
# otherwise, add S1 task
|
||||
compute_S1 =
|
||||
insert_node!(graph, make_node(ComputeTaskQED_S1()); track = false, invalidate_cache = false) # compute propagator
|
||||
insert_node!(graph, make_node(ComputeTaskQED_S1()), track = false, invalidate_cache = false) # compute propagator
|
||||
|
||||
insert_edge!(graph, data_V_out, compute_S1; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_V_out, compute_S1, track = false, invalidate_cache = false)
|
||||
|
||||
data_S1_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE));
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE)),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, compute_S1, data_S1_out; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, compute_S1, data_S1_out, track = false, invalidate_cache = false)
|
||||
|
||||
# overrides potentially different nodes from previous diagrams, which is intentional
|
||||
dataOutNodes[String(vertex.out)] = data_S1_out
|
||||
@ -143,16 +168,16 @@ function gen_graph(process_description::GenericQEDProcess)
|
||||
data_in1 = dataOutNodes[String(tie.in1)]
|
||||
data_in2 = dataOutNodes[String(tie.in2)]
|
||||
|
||||
compute_S2 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track = false, invalidate_cache = false)
|
||||
compute_S2 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false, invalidate_cache = false)
|
||||
|
||||
data_S2 = insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)); track = false, invalidate_cache = false)
|
||||
data_S2 = insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false)
|
||||
|
||||
insert_edge!(graph, data_in1, compute_S2; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_in2, compute_S2; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_in1, compute_S2, track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_in2, compute_S2, track = false, invalidate_cache = false)
|
||||
|
||||
insert_edge!(graph, compute_S2, data_S2; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, compute_S2, data_S2, track = false, invalidate_cache = false)
|
||||
|
||||
insert_edge!(graph, data_S2, sum_node; track = false, invalidate_cache = false)
|
||||
insert_edge!(graph, data_S2, sum_node, track = false, invalidate_cache = false)
|
||||
add_child!(task(sum_node))
|
||||
end
|
||||
|
@ -8,10 +8,10 @@ import Base.show
|
||||
"""
|
||||
FeynmanParticle
|
||||
|
||||
Representation of a particle for use in [`FeynmanDiagram`](@ref)s. Consist of the `ParticleStateful` type and an id.
|
||||
Representation of a particle for use in [`FeynmanDiagram`](@ref)s. Consist of the [`QEDParticle`](@ref) type and an id.
|
||||
"""
|
||||
struct FeynmanParticle
|
||||
particle::Type{<:ParticleStateful}
|
||||
particle::Type{<:QEDParticle}
|
||||
id::Int
|
||||
end
|
||||
|
||||
@ -51,21 +51,31 @@ struct FeynmanDiagram
|
||||
end
|
||||
|
||||
"""
|
||||
FeynmanDiagram(pd::GenericQEDProcess)
|
||||
FeynmanDiagram(pd::QEDProcessDescription)
|
||||
|
||||
Create an initial [`FeynmanDiagram`](@ref) with only its initial particles set and no vertices or ties.
|
||||
|
||||
Use [`gen_diagrams`](@ref) to generate all possible diagrams from this one.
|
||||
"""
|
||||
function FeynmanDiagram(pd::GenericQEDProcess)
|
||||
function FeynmanDiagram(pd::QEDProcessDescription)
|
||||
parts = Vector{FeynmanParticle}()
|
||||
|
||||
ids = Dict{Type, Int64}()
|
||||
for type in types(model(pd))
|
||||
for i in 1:number_particles(pd, type)
|
||||
for (type, n) in pd.inParticles
|
||||
for i in 1:n
|
||||
push!(parts, FeynmanParticle(type, i))
|
||||
end
|
||||
ids[type] = number_particles(pd, type)
|
||||
end
|
||||
for (type, n) in pd.outParticles
|
||||
for i in 1:n
|
||||
push!(parts, FeynmanParticle(type, i))
|
||||
end
|
||||
end
|
||||
ids = Dict{Type, Int64}()
|
||||
for t in types(QEDModel())
|
||||
if (isincoming(t))
|
||||
ids[t] = get(pd.inParticles, t, 0)
|
||||
else
|
||||
ids[t] = get(pd.outParticles, t, 0)
|
||||
end
|
||||
end
|
||||
|
||||
return FeynmanDiagram([], missing, parts, ids)
|
||||
@ -73,7 +83,7 @@ end
|
||||
|
||||
function particle_after_tie(p::FeynmanParticle, t::FeynmanTie)
|
||||
if p == t.in1 || p == t.in2
|
||||
return FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, -1) # placeholder particle and id for tied particles
|
||||
return FeynmanParticle(FermionStateful{Incoming, SpinUp}, -1) # placeholder particle and id for tied particles
|
||||
end
|
||||
return p
|
||||
end
|
||||
@ -104,7 +114,7 @@ end
|
||||
Return a string representation of the [`FeynmanParticle`](@ref) in a format that is readable by [`type_index_from_name`](@ref).
|
||||
"""
|
||||
function String(p::FeynmanParticle)
|
||||
return "$(String(p.particle))$(String(particle_direction(p.particle)))$(p.id)"
|
||||
return "$(String(p.particle))$(String(direction(p.particle)))$(p.id)"
|
||||
end
|
||||
|
||||
function hash(v::FeynmanVertex)
|
||||
@ -152,16 +162,15 @@ function ==(d1::FeynmanDiagram, d2::FeynmanDiagram)
|
||||
)=#
|
||||
end
|
||||
|
||||
function copy(fd::FeynmanDiagram)
|
||||
return FeynmanDiagram(deepcopy(fd.vertices), copy(fd.tie[]), deepcopy(fd.particles), copy(fd.type_ids))
|
||||
end
|
||||
copy(fd::FeynmanDiagram) =
|
||||
FeynmanDiagram(deepcopy(fd.vertices), copy(fd.tie[]), deepcopy(fd.particles), copy(fd.type_ids))
|
||||
|
||||
"""
|
||||
id_for_type(d::FeynmanDiagram, t::Type{<:ParticleStateful})
|
||||
id_for_type(d::FeynmanDiagram, t::Type{<:QEDParticle})
|
||||
|
||||
Return the highest id of any particle of the given type in the diagram + 1.
|
||||
"""
|
||||
function id_for_type(d::FeynmanDiagram, t::Type{<:ParticleStateful})
|
||||
function id_for_type(d::FeynmanDiagram, t::Type{<:QEDParticle})
|
||||
return d.type_ids[t] + 1
|
||||
end
|
||||
|
||||
@ -430,19 +439,18 @@ function remove_duplicates(compare_set::Set{FeynmanDiagram})
|
||||
return result
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
is_compton(fd::FeynmanDiagram)
|
||||
|
||||
Returns true iff the given feynman diagram is an (empty) diagram of a compton process like ke->k^ne
|
||||
"""
|
||||
function is_compton(fd::FeynmanDiagram)
|
||||
return fd.type_ids[ParticleStateful{Incoming, Electron, SFourMomentum}] == 1 &&
|
||||
fd.type_ids[ParticleStateful{Outgoing, Electron, SFourMomentum}] == 1 &&
|
||||
fd.type_ids[ParticleStateful{Incoming, Positron, SFourMomentum}] == 0 &&
|
||||
fd.type_ids[ParticleStateful{Outgoing, Positron, SFourMomentum}] == 0 &&
|
||||
fd.type_ids[ParticleStateful{Incoming, Photon, SFourMomentum}] >= 1 &&
|
||||
fd.type_ids[ParticleStateful{Outgoing, Photon, SFourMomentum}] >= 1
|
||||
return fd.type_ids[FermionStateful{Incoming, SpinUp}] == 1 &&
|
||||
fd.type_ids[FermionStateful{Outgoing, SpinUp}] == 1 &&
|
||||
fd.type_ids[AntiFermionStateful{Incoming, SpinUp}] == 0 &&
|
||||
fd.type_ids[AntiFermionStateful{Outgoing, SpinUp}] == 0 &&
|
||||
fd.type_ids[PhotonStateful{Incoming, PolX}] >= 1 &&
|
||||
fd.type_ids[PhotonStateful{Outgoing, PolX}] >= 1
|
||||
end
|
||||
|
||||
"""
|
||||
@ -452,8 +460,8 @@ Helper function for [`gen_compton_diagrams`](@Ref). Generates a single diagram f
|
||||
"""
|
||||
function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
|
||||
photons = vcat(
|
||||
[FeynmanParticle(ParticleStateful{Incoming, Photon, SFourMomentum}, i) for i in 1:n],
|
||||
[FeynmanParticle(ParticleStateful{Outgoing, Photon, SFourMomentum}, i) for i in 1:m],
|
||||
[FeynmanParticle(PhotonStateful{Incoming, PolX}, i) for i in 1:n],
|
||||
[FeynmanParticle(PhotonStateful{Outgoing, PolX}, i) for i in 1:m],
|
||||
)
|
||||
|
||||
new_diagram = FeynmanDiagram(
|
||||
@ -461,10 +469,10 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
|
||||
missing,
|
||||
[inFerm, outFerm, photons...],
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum} => 1,
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum} => 1,
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum} => n,
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum} => m,
|
||||
FermionStateful{Incoming, SpinUp} => 1,
|
||||
FermionStateful{Outgoing, SpinUp} => 1,
|
||||
PhotonStateful{Incoming, PolX} => n,
|
||||
PhotonStateful{Outgoing, PolX} => m,
|
||||
),
|
||||
)
|
||||
|
||||
@ -476,9 +484,9 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
|
||||
while left_index <= right_index
|
||||
# left side
|
||||
v_left = FeynmanVertex(
|
||||
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations),
|
||||
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations),
|
||||
photons[order[left_index]],
|
||||
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations + 1),
|
||||
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations + 1),
|
||||
)
|
||||
left_index += 1
|
||||
add_vertex!(new_diagram, v_left)
|
||||
@ -489,9 +497,9 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
|
||||
|
||||
# right side
|
||||
v_right = FeynmanVertex(
|
||||
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations),
|
||||
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations),
|
||||
photons[order[right_index]],
|
||||
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations + 1),
|
||||
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations + 1),
|
||||
)
|
||||
right_index -= 1
|
||||
add_vertex!(new_diagram, v_right)
|
||||
@ -504,6 +512,7 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
|
||||
return new_diagram
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
|
||||
|
||||
@ -511,8 +520,8 @@ Helper function for [`gen_compton_diagrams`](@Ref). Generates a single diagram f
|
||||
"""
|
||||
function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
|
||||
photons = vcat(
|
||||
[FeynmanParticle(ParticleStateful{Incoming, Photon, SFourMomentum}, i) for i in 1:n],
|
||||
[FeynmanParticle(ParticleStateful{Outgoing, Photon, SFourMomentum}, i) for i in 1:m],
|
||||
[FeynmanParticle(PhotonStateful{Incoming, PolX}, i) for i in 1:n],
|
||||
[FeynmanParticle(PhotonStateful{Outgoing, PolX}, i) for i in 1:m],
|
||||
)
|
||||
|
||||
new_diagram = FeynmanDiagram(
|
||||
@ -520,10 +529,10 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
|
||||
missing,
|
||||
[inFerm, outFerm, photons...],
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum} => 1,
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum} => 1,
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum} => n,
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum} => m,
|
||||
FermionStateful{Incoming, SpinUp} => 1,
|
||||
FermionStateful{Outgoing, SpinUp} => 1,
|
||||
PhotonStateful{Incoming, PolX} => n,
|
||||
PhotonStateful{Outgoing, PolX} => m,
|
||||
),
|
||||
)
|
||||
|
||||
@ -535,9 +544,9 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
|
||||
while left_index <= right_index
|
||||
# left side
|
||||
v_left = FeynmanVertex(
|
||||
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations),
|
||||
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations),
|
||||
photons[order[left_index]],
|
||||
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations + 1),
|
||||
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations + 1),
|
||||
)
|
||||
left_index += 1
|
||||
add_vertex!(new_diagram, v_left)
|
||||
@ -550,9 +559,9 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
|
||||
if (iterations == 1)
|
||||
# right side
|
||||
v_right = FeynmanVertex(
|
||||
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations),
|
||||
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations),
|
||||
photons[order[right_index]],
|
||||
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations + 1),
|
||||
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations + 1),
|
||||
)
|
||||
right_index -= 1
|
||||
add_vertex!(new_diagram, v_right)
|
||||
@ -567,14 +576,15 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
|
||||
return new_diagram
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
gen_compton_diagrams(n::Int, m::Int)
|
||||
|
||||
Special case diagram generation for Compton processes, i.e., processes of the form k^ne->k^me
|
||||
"""
|
||||
function gen_compton_diagrams(n::Int, m::Int)
|
||||
inFerm = FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, 1)
|
||||
outFerm = FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, 1)
|
||||
inFerm = FeynmanParticle(FermionStateful{Incoming, SpinUp}, 1)
|
||||
outFerm = FeynmanParticle(FermionStateful{Outgoing, SpinUp}, 1)
|
||||
|
||||
perms = [permutations([i for i in 1:(n + m)])...]
|
||||
|
||||
@ -586,14 +596,15 @@ function gen_compton_diagrams(n::Int, m::Int)
|
||||
return vcat(diagrams...)
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
gen_compton_diagrams_one_side(n::Int, m::Int)
|
||||
|
||||
Special case diagram generation for Compton processes, i.e., processes of the form k^ne->k^me, but generating from one end, yielding larger diagrams
|
||||
"""
|
||||
function gen_compton_diagrams_one_side(n::Int, m::Int)
|
||||
inFerm = FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, 1)
|
||||
outFerm = FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, 1)
|
||||
inFerm = FeynmanParticle(FermionStateful{Incoming, SpinUp}, 1)
|
||||
outFerm = FeynmanParticle(FermionStateful{Outgoing, SpinUp}, 1)
|
||||
|
||||
perms = [permutations([i for i in 1:(n + m)])...]
|
||||
|
||||
@ -612,15 +623,12 @@ From a given feynman diagram in its initial state, e.g. when created through the
|
||||
"""
|
||||
function gen_diagrams(fd::FeynmanDiagram)
|
||||
if is_compton(fd)
|
||||
return gen_compton_diagrams(
|
||||
fd.type_ids[ParticleStateful{Incoming, Photon, SFourMomentum}],
|
||||
fd.type_ids[ParticleStateful{Outgoing, Photon, SFourMomentum}],
|
||||
return gen_compton_diagrams_one_side(
|
||||
fd.type_ids[PhotonStateful{Incoming, PolX}],
|
||||
fd.type_ids[PhotonStateful{Outgoing, PolX}],
|
||||
)
|
||||
end
|
||||
|
||||
throw(error("Unimplemented for non-compton!"))
|
||||
|
||||
#=
|
||||
working = Set{FeynmanDiagram}()
|
||||
results = Set{FeynmanDiagram}()
|
||||
|
||||
@ -659,5 +667,4 @@ function gen_diagrams(fd::FeynmanDiagram)
|
||||
end
|
||||
|
||||
return remove_duplicates(results)
|
||||
=#
|
||||
end
|
44
src/models/qed/parse.jl
Normal file
44
src/models/qed/parse.jl
Normal file
@ -0,0 +1,44 @@
|
||||
|
||||
"""
|
||||
parse_process(string::AbstractString, model::QEDModel)
|
||||
|
||||
Parse a string representation of a process, such as "ke->ke" into the corresponding [`QEDProcessDescription`](@ref).
|
||||
"""
|
||||
function parse_process(str::AbstractString, model::QEDModel)
|
||||
inParticles = Dict{Type, Int}()
|
||||
outParticles = Dict{Type, Int}()
|
||||
|
||||
if !(contains(str, "->"))
|
||||
throw("Did not find -> while parsing process \"$str\"")
|
||||
end
|
||||
|
||||
(inStr, outStr) = split(str, "->")
|
||||
|
||||
if (isempty(inStr) || isempty(outStr))
|
||||
throw("Process (\"$str\") input or output part is empty!")
|
||||
end
|
||||
|
||||
for t in types(model)
|
||||
if (isincoming(t))
|
||||
inCount = count(x -> x == String(t)[1], inStr)
|
||||
|
||||
if inCount != 0
|
||||
inParticles[t] = inCount
|
||||
end
|
||||
end
|
||||
if (isoutgoing(t))
|
||||
outCount = count(x -> x == String(t)[1], outStr)
|
||||
if outCount != 0
|
||||
outParticles[t] = outCount
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if length(inStr) != sum(values(inParticles))
|
||||
throw("Encountered unknown characters in the input part of process \"$str\"")
|
||||
elseif length(outStr) != sum(values(outParticles))
|
||||
throw("Encountered unknown characters in the output part of process \"$str\"")
|
||||
end
|
||||
|
||||
return QEDProcessDescription(inParticles, outParticles)
|
||||
end
|
408
src/models/qed/particle.jl
Normal file
408
src/models/qed/particle.jl
Normal file
@ -0,0 +1,408 @@
|
||||
using QEDprocesses
|
||||
using StaticArrays
|
||||
import QEDbase.mass
|
||||
|
||||
# TODO check
|
||||
const e = sqrt(4π / 137)
|
||||
|
||||
"""
|
||||
QEDModel <: AbstractPhysicsModel
|
||||
|
||||
Singleton definition for identification of the QED-Model.
|
||||
"""
|
||||
struct QEDModel <: AbstractPhysicsModel end
|
||||
|
||||
"""
|
||||
QEDParticle
|
||||
|
||||
Base type for all particles in the [`QEDModel`](@ref).
|
||||
|
||||
Its template parameter specifies the particle's direction.
|
||||
|
||||
The concrete types contain singletons of the types that they are, like `Photon` and `Electron` from QEDbase, and their state descriptions.
|
||||
"""
|
||||
abstract type QEDParticle{Direction <: ParticleDirection} <: AbstractParticle end
|
||||
|
||||
"""
|
||||
QEDProcessDescription <: AbstractProcessDescription
|
||||
|
||||
A description of a process in the QED-Model. Contains the input and output particles.
|
||||
|
||||
See also: [`in_particles`](@ref), [`out_particles`](@ref), [`parse_process`](@ref)
|
||||
"""
|
||||
struct QEDProcessDescription <: AbstractProcessDescription
|
||||
inParticles::Dict{Type{<:QEDParticle{Incoming}}, Int}
|
||||
outParticles::Dict{Type{<:QEDParticle{Outgoing}}, Int}
|
||||
end
|
||||
|
||||
QEDParticleValue{ParticleType <: QEDParticle} = Union{
|
||||
ParticleValue{ParticleType, BiSpinor},
|
||||
ParticleValue{ParticleType, AdjointBiSpinor},
|
||||
ParticleValue{ParticleType, DiracMatrix},
|
||||
ParticleValue{ParticleType, SLorentzVector{Float64}},
|
||||
ParticleValue{ParticleType, ComplexF64},
|
||||
}
|
||||
|
||||
"""
|
||||
PhotonStateful <: QEDParticle
|
||||
|
||||
A photon of the [`QEDModel`](@ref) with its state.
|
||||
"""
|
||||
struct PhotonStateful{Direction <: ParticleDirection, Pol <: AbstractDefinitePolarization} <: QEDParticle{Direction}
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
PhotonStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
|
||||
PhotonStateful{Direction, PolX}(mom)
|
||||
|
||||
PhotonStateful{Dir, Pol}(ph::PhotonStateful) where {Dir, Pol} = PhotonStateful{Dir, Pol}(ph.momentum)
|
||||
|
||||
"""
|
||||
FermionStateful <: QEDParticle
|
||||
|
||||
A fermion of the [`QEDModel`](@ref) with its state.
|
||||
"""
|
||||
struct FermionStateful{Direction <: ParticleDirection, Spin <: AbstractDefiniteSpin} <: QEDParticle{Direction}
|
||||
momentum::SFourMomentum
|
||||
# TODO: mass for electron/muon/tauon representation?
|
||||
end
|
||||
|
||||
FermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
|
||||
FermionStateful{Direction, SpinUp}(mom)
|
||||
|
||||
FermionStateful{Dir, Spin}(f::FermionStateful) where {Dir, Spin} = FermionStateful{Dir, Spin}(f.momentum)
|
||||
|
||||
"""
|
||||
AntiFermionStateful <: QEDParticle
|
||||
|
||||
An anti-fermion of the [`QEDModel`](@ref) with its state.
|
||||
"""
|
||||
struct AntiFermionStateful{Direction <: ParticleDirection, Spin <: AbstractDefiniteSpin} <: QEDParticle{Direction}
|
||||
momentum::SFourMomentum
|
||||
# TODO: mass for electron/muon/tauon representation?
|
||||
end
|
||||
|
||||
AntiFermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
|
||||
AntiFermionStateful{Direction, SpinUp}(mom)
|
||||
|
||||
AntiFermionStateful{Dir, Spin}(f::AntiFermionStateful) where {Dir, Spin} = AntiFermionStateful{Dir, Spin}(f.momentum)
|
||||
|
||||
"""
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
|
||||
|
||||
For two given particle types that can interact, return the third.
|
||||
"""
|
||||
function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
end
|
||||
|
||||
interaction_result(
|
||||
::Type{FermionStateful{Incoming, Spin1}},
|
||||
::Type{FermionStateful{Outgoing, Spin2}},
|
||||
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
|
||||
interaction_result(
|
||||
::Type{FermionStateful{Incoming, Spin1}},
|
||||
::Type{AntiFermionStateful{Incoming, Spin2}},
|
||||
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
|
||||
interaction_result(::Type{FermionStateful{Incoming, Spin1}}, ::Type{<:PhotonStateful}) where {Spin1} =
|
||||
FermionStateful{Outgoing, SpinUp}
|
||||
|
||||
interaction_result(
|
||||
::Type{FermionStateful{Outgoing, Spin1}},
|
||||
::Type{FermionStateful{Incoming, Spin2}},
|
||||
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
|
||||
interaction_result(
|
||||
::Type{FermionStateful{Outgoing, Spin1}},
|
||||
::Type{AntiFermionStateful{Outgoing, Spin2}},
|
||||
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
|
||||
interaction_result(::Type{FermionStateful{Outgoing, Spin1}}, ::Type{<:PhotonStateful}) where {Spin1} =
|
||||
FermionStateful{Incoming, SpinUp}
|
||||
|
||||
# antifermion mirror
|
||||
interaction_result(::Type{AntiFermionStateful{Incoming, Spin}}, t2::Type{<:QEDParticle}) where {Spin} =
|
||||
interaction_result(FermionStateful{Outgoing, Spin}, t2)
|
||||
interaction_result(::Type{AntiFermionStateful{Outgoing, Spin}}, t2::Type{<:QEDParticle}) where {Spin} =
|
||||
interaction_result(FermionStateful{Incoming, Spin}, t2)
|
||||
|
||||
# photon commutativity
|
||||
interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:QEDParticle}) = interaction_result(t2, t1)
|
||||
|
||||
# but prevent stack overflow
|
||||
function interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:PhotonStateful})
|
||||
@assert false "Invalid interaction between particles of types $t1 and $t2"
|
||||
end
|
||||
|
||||
"""
|
||||
propagation_result(t1::Type{T}) where {T <: QEDParticle}
|
||||
|
||||
Return the type of the inverted direction. E.g.
|
||||
"""
|
||||
propagation_result(::Type{FermionStateful{Incoming, Spin}}) where {Spin <: AbstractDefiniteSpin} =
|
||||
FermionStateful{Outgoing, Spin}
|
||||
propagation_result(::Type{FermionStateful{Outgoing, Spin}}) where {Spin <: AbstractDefiniteSpin} =
|
||||
FermionStateful{Incoming, Spin}
|
||||
propagation_result(::Type{AntiFermionStateful{Incoming, Spin}}) where {Spin <: AbstractDefiniteSpin} =
|
||||
AntiFermionStateful{Outgoing, Spin}
|
||||
propagation_result(::Type{AntiFermionStateful{Outgoing, Spin}}) where {Spin <: AbstractDefiniteSpin} =
|
||||
AntiFermionStateful{Incoming, Spin}
|
||||
propagation_result(::Type{PhotonStateful{Incoming, Pol}}) where {Pol <: AbstractDefinitePolarization} =
|
||||
PhotonStateful{Outgoing, Pol}
|
||||
propagation_result(::Type{PhotonStateful{Outgoing, Pol}}) where {Pol <: AbstractDefinitePolarization} =
|
||||
PhotonStateful{Incoming, Pol}
|
||||
|
||||
"""
|
||||
types(::QEDModel)
|
||||
|
||||
Return a Vector of the possible types of particle in the [`QEDModel`](@ref).
|
||||
"""
|
||||
function types(::QEDModel)
|
||||
return [
|
||||
PhotonStateful{Incoming, PolX},
|
||||
PhotonStateful{Outgoing, PolX},
|
||||
FermionStateful{Incoming, SpinUp},
|
||||
FermionStateful{Outgoing, SpinUp},
|
||||
AntiFermionStateful{Incoming, SpinUp},
|
||||
AntiFermionStateful{Outgoing, SpinUp},
|
||||
]
|
||||
end
|
||||
|
||||
# type piracy?
|
||||
String(::Type{Incoming}) = "Incoming"
|
||||
String(::Type{Outgoing}) = "Outgoing"
|
||||
|
||||
String(::Type{PolX}) = "polx"
|
||||
String(::Type{PolY}) = "poly"
|
||||
|
||||
String(::Type{SpinUp}) = "spinup"
|
||||
String(::Type{SpinDown}) = "spindown"
|
||||
|
||||
String(::Incoming) = "i"
|
||||
String(::Outgoing) = "o"
|
||||
|
||||
function String(::Type{<:PhotonStateful})
|
||||
return "k"
|
||||
end
|
||||
function String(::Type{<:FermionStateful})
|
||||
return "e"
|
||||
end
|
||||
function String(::Type{<:AntiFermionStateful})
|
||||
return "p"
|
||||
end
|
||||
|
||||
function unique_name(::Type{PhotonStateful{Dir, Pol}}) where {Dir, Pol}
|
||||
return String(PhotonStateful) * String(Dir) * String(Pol)
|
||||
end
|
||||
function unique_name(::Type{FermionStateful{Dir, Spin}}) where {Dir, Spin}
|
||||
return String(FermionStateful) * String(Dir) * String(Spin)
|
||||
end
|
||||
function unique_name(::Type{AntiFermionStateful{Dir, Spin}}) where {Dir, Spin}
|
||||
return String(AntiFermionStateful) * String(Dir) * String(Spin)
|
||||
end
|
||||
|
||||
@inline particle(::PhotonStateful) = Photon()
|
||||
@inline particle(::FermionStateful) = Electron()
|
||||
@inline particle(::AntiFermionStateful) = Positron()
|
||||
|
||||
@inline momentum(p::PhotonStateful)::SFourMomentum = p.momentum
|
||||
@inline momentum(p::FermionStateful)::SFourMomentum = p.momentum
|
||||
@inline momentum(p::AntiFermionStateful)::SFourMomentum = p.momentum
|
||||
|
||||
@inline spin_or_pol(p::PhotonStateful{Dir, Pol}) where {Dir, Pol <: AbstractDefinitePolarization} = Pol()
|
||||
@inline spin_or_pol(p::FermionStateful{Dir, Spin}) where {Dir, Spin <: AbstractDefiniteSpin} = Spin()
|
||||
@inline spin_or_pol(p::AntiFermionStateful{Dir, Spin}) where {Dir, Spin <: AbstractDefiniteSpin} = Spin()
|
||||
|
||||
@inline direction(
|
||||
::Type{P},
|
||||
) where {P <: Union{FermionStateful{Incoming}, AntiFermionStateful{Incoming}, PhotonStateful{Incoming}}} = Incoming()
|
||||
@inline direction(
|
||||
::Type{P},
|
||||
) where {P <: Union{FermionStateful{Outgoing}, AntiFermionStateful{Outgoing}, PhotonStateful{Outgoing}}} = Outgoing()
|
||||
|
||||
@inline direction(
|
||||
::P,
|
||||
) where {P <: Union{FermionStateful{Incoming}, AntiFermionStateful{Incoming}, PhotonStateful{Incoming}}} = Incoming()
|
||||
@inline direction(
|
||||
::P,
|
||||
) where {P <: Union{FermionStateful{Outgoing}, AntiFermionStateful{Outgoing}, PhotonStateful{Outgoing}}} = Outgoing()
|
||||
|
||||
@inline isincoming(::QEDParticle{Incoming}) = true
|
||||
@inline isincoming(::QEDParticle{Outgoing}) = false
|
||||
@inline isoutgoing(::QEDParticle{Incoming}) = false
|
||||
@inline isoutgoing(::QEDParticle{Outgoing}) = true
|
||||
|
||||
@inline isincoming(::Type{<:QEDParticle{Incoming}}) = true
|
||||
@inline isincoming(::Type{<:QEDParticle{Outgoing}}) = false
|
||||
@inline isoutgoing(::Type{<:QEDParticle{Incoming}}) = false
|
||||
@inline isoutgoing(::Type{<:QEDParticle{Outgoing}}) = true
|
||||
|
||||
@inline mass(::Type{<:FermionStateful}) = 1.0
|
||||
@inline mass(::Type{<:AntiFermionStateful}) = 1.0
|
||||
@inline mass(::Type{<:PhotonStateful}) = 0.0
|
||||
|
||||
@inline invert_momentum(p::FermionStateful{Dir, Spin}) where {Dir, Spin} =
|
||||
FermionStateful{Dir, Spin}(-p.momentum, p.spin)
|
||||
@inline invert_momentum(p::AntiFermionStateful{Dir, Spin}) where {Dir, Spin} =
|
||||
AntiFermionStateful{Dir, Spin}(-p.momentum, p.spin)
|
||||
@inline invert_momentum(k::PhotonStateful{Dir, Spin}) where {Dir, Spin} =
|
||||
PhotonStateful{Dir, Spin}(-k.momentum, k.polarization)
|
||||
|
||||
|
||||
"""
|
||||
caninteract(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
|
||||
|
||||
For two given [`QEDParticle`](@ref) types, return whether they can interact at a vertex. This is equivalent to `!issame(T1, T2)`.
|
||||
|
||||
See also: [`issame`](@ref) and [`interaction_result`](@ref)
|
||||
"""
|
||||
function caninteract(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
|
||||
if (T1 == T2)
|
||||
return false
|
||||
end
|
||||
if (T1 <: PhotonStateful && T2 <: PhotonStateful)
|
||||
return false
|
||||
end
|
||||
|
||||
for (P1, P2) in [(T1, T2), (T2, T1)]
|
||||
if (P1 <: FermionStateful{Incoming} && P2 <: AntiFermionStateful{Outgoing})
|
||||
return false
|
||||
end
|
||||
if (P1 <: FermionStateful{Outgoing} && P2 <: AntiFermionStateful{Incoming})
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function type_index_from_name(::QEDModel, name::String)
|
||||
if startswith(name, "ki")
|
||||
return (PhotonStateful{Incoming, PolX}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "ko")
|
||||
return (PhotonStateful{Outgoing, PolX}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "ei")
|
||||
return (FermionStateful{Incoming, SpinUp}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "eo")
|
||||
return (FermionStateful{Outgoing, SpinUp}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "pi")
|
||||
return (AntiFermionStateful{Incoming, SpinUp}, parse(Int, name[3:end]))
|
||||
elseif startswith(name, "po")
|
||||
return (AntiFermionStateful{Outgoing, SpinUp}, parse(Int, name[3:end]))
|
||||
else
|
||||
throw("Invalid name for a particle in the QED model")
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
issame(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
|
||||
|
||||
For two given [`QEDParticle`](@ref) types, return whether they are equivalent for the purpose of a Feynman Diagram. That means e.g. an `Incoming` `AntiFermion` is the same as an `Outgoing` `Fermion`. This is equivalent to `!caninteract(T1, T2)`.
|
||||
|
||||
See also: [`caninteract`](@ref) and [`interaction_result`](@ref)
|
||||
"""
|
||||
function issame(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
|
||||
return !caninteract(T1, T2)
|
||||
end
|
||||
|
||||
"""
|
||||
QED_vertex()
|
||||
|
||||
Return the factor of a vertex in a QED feynman diagram.
|
||||
"""
|
||||
@inline function QED_vertex()::SLorentzVector{DiracMatrix}
|
||||
# Peskin-Schroeder notation
|
||||
return -1im * e * gamma()
|
||||
end
|
||||
|
||||
@inline function QED_inner_edge(p::QEDParticle)
|
||||
return propagator(particle(p), p.momentum)
|
||||
end
|
||||
|
||||
"""
|
||||
QED_conserve_momentum(p1::QEDParticle, p2::QEDParticle)
|
||||
|
||||
Calculate and return a new particle from two given interacting ones at a vertex.
|
||||
"""
|
||||
function QED_conserve_momentum(
|
||||
p1::P1,
|
||||
p2::P2,
|
||||
) where {
|
||||
Dir1 <: ParticleDirection,
|
||||
Dir2 <: ParticleDirection,
|
||||
SpinPol1 <: AbstractSpinOrPolarization,
|
||||
SpinPol2 <: AbstractSpinOrPolarization,
|
||||
P1 <: Union{FermionStateful{Dir1, SpinPol1}, AntiFermionStateful{Dir1, SpinPol1}, PhotonStateful{Dir1, SpinPol1}},
|
||||
P2 <: Union{FermionStateful{Dir2, SpinPol2}, AntiFermionStateful{Dir2, SpinPol2}, PhotonStateful{Dir2, SpinPol2}},
|
||||
}
|
||||
P3 = interaction_result(P1, P2)
|
||||
p1_mom = p1.momentum
|
||||
if (Dir1 <: Outgoing)
|
||||
p1_mom *= -1
|
||||
end
|
||||
p2_mom = p2.momentum
|
||||
if (Dir2 <: Outgoing)
|
||||
p2_mom *= -1
|
||||
end
|
||||
|
||||
p3_mom = p1_mom + p2_mom
|
||||
if (typeof(direction(P3)) <: Incoming)
|
||||
return P3(-p3_mom)
|
||||
end
|
||||
return P3(p3_mom)
|
||||
end
|
||||
|
||||
"""
|
||||
QEDProcessInput <: AbstractProcessInput
|
||||
|
||||
Input for a QED Process. Contains the [`QEDProcessDescription`](@ref) of the process it is an input for, and the values of the in and out particles.
|
||||
|
||||
See also: [`gen_process_input`](@ref)
|
||||
"""
|
||||
struct QEDProcessInput{N1, N2, N3, N4, N5, N6} <: AbstractProcessInput
|
||||
process::QEDProcessDescription
|
||||
inFerms::SVector{N1, FermionStateful{Incoming, SpinUp}}
|
||||
outFerms::SVector{N2, FermionStateful{Outgoing, SpinUp}}
|
||||
inAntiferms::SVector{N3, AntiFermionStateful{Incoming, SpinUp}}
|
||||
outAntiferms::SVector{N4, AntiFermionStateful{Outgoing, SpinUp}}
|
||||
inPhotons::SVector{N5, PhotonStateful{Incoming, PolX}}
|
||||
outPhotons::SVector{N6, PhotonStateful{Outgoing, PolX}}
|
||||
end
|
||||
|
||||
"""
|
||||
model(::AbstractProcessDescription)
|
||||
|
||||
Return the model of this process description.
|
||||
"""
|
||||
model(::QEDProcessDescription) = QEDModel()
|
||||
model(::QEDProcessInput) = QEDModel()
|
||||
|
||||
function copy(process::QEDProcessDescription)
|
||||
return QEDProcessDescription(copy(process.inParticles), copy(process.outParticles))
|
||||
end
|
||||
|
||||
==(p1::QEDProcessDescription, p2::QEDProcessDescription) =
|
||||
p1.inParticles == p2.inParticles && p1.outParticles == p2.outParticles
|
||||
|
||||
function in_particles(process::QEDProcessDescription)
|
||||
return process.inParticles
|
||||
end
|
||||
|
||||
function out_particles(process::QEDProcessDescription)
|
||||
return process.outParticles
|
||||
end
|
||||
|
||||
function get_particle(input::QEDProcessInput, t::Type{Particle}, n::Int)::Particle where {Particle}
|
||||
if (t <: FermionStateful{Incoming})
|
||||
return input.inFerms[n]
|
||||
elseif (t <: FermionStateful{Outgoing})
|
||||
return input.outFerms[n]
|
||||
elseif (t <: AntiFermionStateful{Incoming})
|
||||
return input.inAntiferms[n]
|
||||
elseif (t <: AntiFermionStateful{Outgoing})
|
||||
return input.outAntiferms[n]
|
||||
elseif (t <: PhotonStateful{Incoming})
|
||||
return input.inPhotons[n]
|
||||
elseif (t <: PhotonStateful{Outgoing})
|
||||
return input.outPhotons[n]
|
||||
end
|
||||
@assert false "Invalid type given"
|
||||
end
|
@ -1,9 +1,8 @@
|
||||
|
||||
#=
|
||||
"""
|
||||
show(io::IO, process::GenericQEDProcess)
|
||||
show(io::IO, process::QEDProcessDescription)
|
||||
|
||||
Pretty print an [`GenericQEDProcess`](@ref) (no newlines).
|
||||
Pretty print an [`QEDProcessDescription`](@ref) (no newlines).
|
||||
|
||||
```jldoctest
|
||||
julia> using MetagraphOptimization
|
||||
@ -15,7 +14,7 @@ julia> print(parse_process("kk->ep", QEDModel()))
|
||||
QED Process: 'kk->ep'
|
||||
```
|
||||
"""
|
||||
function show(io::IO, process::GenericQEDProcess)
|
||||
function show(io::IO, process::QEDProcessDescription)
|
||||
# types() gives the types in order (QED) instead of random like keys() would
|
||||
print(io, "QED Process: \'")
|
||||
for type in types(QEDModel())
|
||||
@ -35,7 +34,7 @@ end
|
||||
|
||||
|
||||
"""
|
||||
String(process::GenericQEDProcess)
|
||||
String(process::QEDProcessDescription)
|
||||
|
||||
Create a short string suitable as a filename or similar, describing the given process.
|
||||
|
||||
@ -65,9 +64,7 @@ function String(process::QEDProcessDescription)
|
||||
end
|
||||
return str
|
||||
end
|
||||
=#
|
||||
|
||||
#=
|
||||
"""
|
||||
show(io::IO, processInput::QEDProcessInput)
|
||||
|
||||
@ -95,9 +92,7 @@ function show(io::IO, processInput::QEDProcessInput)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
=#
|
||||
|
||||
#=
|
||||
"""
|
||||
show(io::IO, particle::T) where {T <: QEDParticle}
|
||||
|
||||
@ -107,14 +102,13 @@ function show(io::IO, particle::T) where {T <: QEDParticle}
|
||||
print(io, "$(String(typeof(particle))): $(particle.momentum)")
|
||||
return nothing
|
||||
end
|
||||
=#
|
||||
|
||||
"""
|
||||
show(io::IO, particle::FeynmanParticle)
|
||||
|
||||
Pretty print a [`FeynmanParticle`](@ref) (no newlines).
|
||||
"""
|
||||
show(io::IO, p::FeynmanParticle) = print(io, "$(String(p.particle))_$(String(particle_direction(p.particle)))_$(p.id)")
|
||||
show(io::IO, p::FeynmanParticle) = print(io, "$(String(p.particle))_$(String(direction(p.particle)))_$(p.id)")
|
||||
|
||||
"""
|
||||
show(io::IO, particle::FeynmanVertex)
|
@ -1,10 +1,3 @@
|
||||
"""
|
||||
QEDModel <: AbstractPhysicsModel
|
||||
|
||||
Singleton definition for identification of the QED-Model.
|
||||
"""
|
||||
struct QEDModel <: AbstractPhysicsModel end
|
||||
|
||||
"""
|
||||
ComputeTaskQED_S1 <: AbstractComputeTask
|
||||
|
@ -1,6 +1,6 @@
|
||||
|
||||
DataTaskNode(t::AbstractDataTask, name = "") =
|
||||
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, name)
|
||||
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
|
||||
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
t, # task
|
||||
Vector{Node}(), # parents
|
||||
@ -8,6 +8,7 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
UUIDs.uuid1(rng[threadid()]), # id
|
||||
missing, # node reduction
|
||||
missing, # node split
|
||||
Vector{NodeFusion}(), # node fusions
|
||||
missing, # device
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,7 @@ Any node that transfers data and does no computation.
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
|
||||
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
|
||||
"""
|
||||
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
|
||||
@ -50,6 +51,9 @@ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
|
||||
# the NodeSplit involving this node, if it exists
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# the node fusion involving this node, if it exists
|
||||
nodeFusion::Union{Operation, Missing}
|
||||
|
||||
# for input nodes we need a name for the node to distinguish between them
|
||||
name::String
|
||||
end
|
||||
@ -66,6 +70,7 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
|
||||
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
|
||||
"""
|
||||
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
|
||||
@ -77,6 +82,9 @@ mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
|
||||
nodeReduction::Union{Operation, Missing}
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
|
||||
nodeFusions::Vector{<:Operation}
|
||||
|
||||
# the device this node is assigned to execute on
|
||||
device::Union{AbstractDevice, Missing}
|
||||
end
|
||||
|
@ -29,6 +29,17 @@ function is_valid_node(graph::DAG, node::Node)
|
||||
@assert is_valid(graph, node.nodeSplit)
|
||||
end=#
|
||||
|
||||
if !(typeof(task(node)) <: FusedComputeTask)
|
||||
# the remaining checks are only necessary for fused compute tasks
|
||||
return true
|
||||
end
|
||||
|
||||
# every child must be in some input of the task
|
||||
for child in node.children
|
||||
str = Symbol(to_var_name(child.id))
|
||||
@assert (str in task(node).t1_inputs) || (str in task(node).t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(task(node).t1_inputs)\nt2_inputs: $(task(node).t2_inputs)"
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
@ -42,6 +53,9 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
|
||||
function is_valid(graph::DAG, node::ComputeTaskNode)
|
||||
@assert is_valid_node(graph, node)
|
||||
|
||||
#=for nf in node.nodeFusions
|
||||
@assert is_valid(graph, nf)
|
||||
end=#
|
||||
return true
|
||||
end
|
||||
|
||||
@ -55,5 +69,8 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
|
||||
function is_valid(graph::DAG, node::DataTaskNode)
|
||||
@assert is_valid_node(graph, node)
|
||||
|
||||
#=if !ismissing(node.nodeFusion)
|
||||
@assert is_valid(graph, node.nodeFusion)
|
||||
end=#
|
||||
return true
|
||||
end
|
||||
|
@ -26,6 +26,21 @@ function apply_operation!(graph::DAG, operation::Operation)
|
||||
return error("Unknown operation type!")
|
||||
end
|
||||
|
||||
"""
|
||||
apply_operation!(graph::DAG, operation::NodeFusion)
|
||||
|
||||
Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node_fusion!`](@ref).
|
||||
|
||||
Return an [`AppliedNodeFusion`](@ref) object generated from the graph's [`Diff`](@ref).
|
||||
"""
|
||||
function apply_operation!(graph::DAG, operation::NodeFusion)
|
||||
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
|
||||
|
||||
graph.properties += GraphProperties(diff)
|
||||
|
||||
return AppliedNodeFusion(operation, diff)
|
||||
end
|
||||
|
||||
"""
|
||||
apply_operation!(graph::DAG, operation::NodeReduction)
|
||||
|
||||
@ -65,10 +80,20 @@ function revert_operation!(graph::DAG, operation::AppliedOperation)
|
||||
return error("Unknown operation type!")
|
||||
end
|
||||
|
||||
"""
|
||||
revert_operation!(graph::DAG, operation::AppliedNodeFusion)
|
||||
|
||||
Revert the applied node fusion on the graph. Return the original [`NodeFusion`](@ref) operation.
|
||||
"""
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
|
||||
revert_diff!(graph, operation.diff)
|
||||
return operation.operation
|
||||
end
|
||||
|
||||
"""
|
||||
revert_operation!(graph::DAG, operation::AppliedNodeReduction)
|
||||
|
||||
Revert the applied node reduction on the graph. Return the original [`NodeReduction`](@ref) operation.
|
||||
Revert the applied node fusion on the graph. Return the original [`NodeReduction`](@ref) operation.
|
||||
"""
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
|
||||
revert_diff!(graph, operation.diff)
|
||||
@ -78,7 +103,7 @@ end
|
||||
"""
|
||||
revert_operation!(graph::DAG, operation::AppliedNodeSplit)
|
||||
|
||||
Revert the applied node split on the graph. Return the original [`NodeSplit`](@ref) operation.
|
||||
Revert the applied node fusion on the graph. Return the original [`NodeSplit`](@ref) operation.
|
||||
"""
|
||||
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
|
||||
revert_diff!(graph, operation.diff)
|
||||
@ -107,11 +132,88 @@ function revert_diff!(graph::DAG, diff::Diff)
|
||||
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
|
||||
end
|
||||
|
||||
for (node, t) in diff.updatedChildren
|
||||
# node must be fused compute task at this point
|
||||
@assert typeof(task(node)) <: FusedComputeTask
|
||||
|
||||
node.task = t
|
||||
end
|
||||
|
||||
graph.properties -= GraphProperties(diff)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
|
||||
Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph.
|
||||
|
||||
For details see [`NodeFusion`](@ref).
|
||||
"""
|
||||
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
@assert is_valid_node_fusion_input(graph, n1, n2, n3)
|
||||
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
# save children and parents
|
||||
n1Children = copy(children(n1))
|
||||
n3Parents = copy(parents(n3))
|
||||
|
||||
n1Task = copy(task(n1))
|
||||
n3Task = copy(task(n3))
|
||||
|
||||
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
|
||||
n1Inputs = Vector{Symbol}()
|
||||
for child in n1Children
|
||||
push!(n1Inputs, Symbol(to_var_name(child.id)))
|
||||
end
|
||||
|
||||
# remove the edges and nodes that will be replaced by the fused node
|
||||
remove_edge!(graph, n1, n2)
|
||||
remove_edge!(graph, n2, n3)
|
||||
remove_node!(graph, n1)
|
||||
remove_node!(graph, n2)
|
||||
|
||||
# get n3's children now so it automatically excludes n2
|
||||
n3Children = copy(children(n3))
|
||||
|
||||
n3Inputs = Vector{Symbol}()
|
||||
for child in n3Children
|
||||
push!(n3Inputs, Symbol(to_var_name(child.id)))
|
||||
end
|
||||
|
||||
remove_node!(graph, n3)
|
||||
|
||||
# create new node with the fused compute task
|
||||
newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
|
||||
insert_node!(graph, newNode)
|
||||
|
||||
for child in n1Children
|
||||
remove_edge!(graph, child, n1)
|
||||
insert_edge!(graph, child, newNode)
|
||||
end
|
||||
|
||||
for child in n3Children
|
||||
remove_edge!(graph, child, n3)
|
||||
if !(child in n1Children)
|
||||
insert_edge!(graph, child, newNode)
|
||||
end
|
||||
end
|
||||
|
||||
for parent in n3Parents
|
||||
remove_edge!(graph, n3, parent)
|
||||
insert_edge!(graph, newNode, parent)
|
||||
|
||||
# important! update the parent node's child names in case they are fused compute tasks
|
||||
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
|
||||
update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(newNode.id)))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
end
|
||||
|
||||
"""
|
||||
node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
@ -163,6 +265,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
# this has to be done for all parents, even the ones of n1 because they can be duplicate
|
||||
prevChild = newParentsChildNames[parent]
|
||||
update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
@ -204,6 +307,8 @@ function node_split!(
|
||||
for child in n1Children
|
||||
insert_edge!(graph, child, nCopy)
|
||||
end
|
||||
|
||||
update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(nCopy.id)))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
|
@ -1,5 +1,60 @@
|
||||
# These are functions for "cleaning" nodes, i.e. regenerating the possible operations for a node
|
||||
|
||||
"""
|
||||
find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
|
||||
Find node fusions involving the given data node. The function pushes the found [`NodeFusion`](@ref) (if any) everywhere it needs to be and returns nothing.
|
||||
|
||||
Does nothing if the node already has a node fusion set. Since it's a data node, only one node fusion can be possible with it.
|
||||
"""
|
||||
function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
# if there is already a fusion here, skip to avoid duplicates
|
||||
if !ismissing(node.nodeFusion)
|
||||
return nothing
|
||||
end
|
||||
|
||||
if length(parents(node)) != 1 || length(children(node)) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
child_node = first(children(node))
|
||||
parent_node = first(parents(node))
|
||||
|
||||
if !(child_node in graph) || !(parent_node in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end
|
||||
|
||||
if length(parents(child_node)) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
nf = NodeFusion((child_node, node, parent_node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(child_node.nodeFusions, nf)
|
||||
node.nodeFusion = nf
|
||||
push!(parent_node.nodeFusions, nf)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
|
||||
Find node fusions involving the given compute node. The function pushes the found [`NodeFusion`](@ref)s (if any) everywhere they need to be and returns nothing.
|
||||
"""
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
# just find fusions in neighbouring DataTaskNodes
|
||||
for child in children(node)
|
||||
find_fusions!(graph, child)
|
||||
end
|
||||
|
||||
for parent in parents(node)
|
||||
find_fusions!(graph, parent)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
find_reductions!(graph::DAG, node::Node)
|
||||
|
||||
@ -66,7 +121,7 @@ end
|
||||
"""
|
||||
clean_node!(graph::DAG, node::Node)
|
||||
|
||||
Sort this node's parent and child sets, then find reductions and splits involving it. Needs to be called after the node was changed in some way.
|
||||
Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way.
|
||||
"""
|
||||
function clean_node!(
|
||||
graph::DAG,
|
||||
@ -74,6 +129,7 @@ function clean_node!(
|
||||
) where {TaskType <: AbstractTask}
|
||||
sort_node!(node)
|
||||
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
find_splits!(graph, node)
|
||||
|
||||
|
@ -2,6 +2,26 @@
|
||||
|
||||
using Base.Threads
|
||||
|
||||
"""
|
||||
insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
|
||||
|
||||
Insert the given node fusion into its input nodes' operation caches. For the compute nodes, locking via the given `locks` is employed to have safe multi-threading. For a large set of nodes, contention on the locks should be very small.
|
||||
"""
|
||||
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
|
||||
n1 = nf.input[1]
|
||||
n2 = nf.input[2]
|
||||
n3 = nf.input[3]
|
||||
|
||||
lock(locks[n1]) do
|
||||
return push!(nf.input[1].nodeFusions, nf)
|
||||
end
|
||||
n2.nodeFusion = nf
|
||||
lock(locks[n3]) do
|
||||
return push!(nf.input[3].nodeFusions, nf)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
insert_operation!(nf::NodeReduction)
|
||||
|
||||
@ -52,6 +72,41 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
|
||||
|
||||
Insert the node fusions into the graph and the nodes' caches. Employs multithreading for speedup.
|
||||
"""
|
||||
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
|
||||
total_len = 0
|
||||
for vec in nodeFusions
|
||||
total_len += length(vec)
|
||||
end
|
||||
sizehint!(operations.nodeFusions, total_len)
|
||||
|
||||
t = @task for vec in nodeFusions
|
||||
union!(operations.nodeFusions, Set(vec))
|
||||
end
|
||||
schedule(t)
|
||||
|
||||
locks = Dict{ComputeTaskNode, SpinLock}()
|
||||
for n in graph.nodes
|
||||
if (typeof(n) <: ComputeTaskNode)
|
||||
locks[n] = SpinLock()
|
||||
end
|
||||
end
|
||||
|
||||
@threads for vec in nodeFusions
|
||||
for op in vec
|
||||
insert_operation!(op, locks)
|
||||
end
|
||||
end
|
||||
|
||||
wait(t)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplits}})
|
||||
|
||||
@ -88,6 +143,7 @@ Generate all possible operations on the graph. Used initially when the graph is
|
||||
Safely inserts all the found operations into the graph and its nodes.
|
||||
"""
|
||||
function generate_operations(graph::DAG)
|
||||
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
|
||||
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
|
||||
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
|
||||
|
||||
@ -143,6 +199,31 @@ function generate_operations(graph::DAG)
|
||||
# remove duplicates
|
||||
nr_task = @spawn nr_insertion!(graph.possibleOperations, generatedReductions)
|
||||
|
||||
# --- find possible node fusions ---
|
||||
@threads for node in nodeArray
|
||||
if (typeof(node) <: DataTaskNode)
|
||||
if length(parents(node)) != 1
|
||||
# data node can only have a single parent
|
||||
continue
|
||||
end
|
||||
parent_node = first(parents(node))
|
||||
|
||||
if length(children(node)) != 1
|
||||
# this node is an entry node or has multiple children which should not be possible
|
||||
continue
|
||||
end
|
||||
child_node = first(children(node))
|
||||
if (length(parents(child_node)) != 1)
|
||||
continue
|
||||
end
|
||||
|
||||
push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
|
||||
end
|
||||
end
|
||||
|
||||
# launch thread for node fusion insertion
|
||||
nf_task = @spawn nf_insertion!(graph, graph.possibleOperations, generatedFusions)
|
||||
|
||||
# find possible node splits
|
||||
@threads for node in nodeArray
|
||||
if (can_split(node))
|
||||
@ -156,6 +237,7 @@ function generate_operations(graph::DAG)
|
||||
empty!(graph.dirtyNodes)
|
||||
|
||||
wait(nr_task)
|
||||
wait(nf_task)
|
||||
wait(ns_task)
|
||||
|
||||
return nothing
|
||||
|
@ -2,7 +2,8 @@ import Base.iterate
|
||||
|
||||
const _POSSIBLE_OPERATIONS_FIELDS = fieldnames(PossibleOperations)
|
||||
|
||||
_POIteratorStateType = NamedTuple{(:result, :state), Tuple{Union{NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
|
||||
_POIteratorStateType =
|
||||
NamedTuple{(:result, :state), Tuple{Union{NodeFusion, NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
|
||||
|
||||
@inline function iterate(possibleOperations::PossibleOperations)::Union{Nothing, _POIteratorStateType}
|
||||
for fieldname in _POSSIBLE_OPERATIONS_FIELDS
|
||||
|
@ -4,6 +4,11 @@
|
||||
Print a string representation of the set of possible operations to io.
|
||||
"""
|
||||
function show(io::IO, ops::PossibleOperations)
|
||||
print(io, length(ops.nodeFusions))
|
||||
println(io, " Node Fusions: ")
|
||||
for nf in ops.nodeFusions
|
||||
println(io, " - ", nf)
|
||||
end
|
||||
print(io, length(ops.nodeReductions))
|
||||
println(io, " Node Reductions: ")
|
||||
for nr in ops.nodeReductions
|
||||
@ -37,3 +42,17 @@ function show(io::IO, op::NodeSplit)
|
||||
print(io, "NS: ")
|
||||
return print(io, task(op.input))
|
||||
end
|
||||
|
||||
"""
|
||||
show(io::IO, op::NodeFusion)
|
||||
|
||||
Print a string representation of the node fusion to io.
|
||||
"""
|
||||
function show(io::IO, op::NodeFusion)
|
||||
print(io, "NF: ")
|
||||
print(io, task(op.input[1]))
|
||||
print(io, "->")
|
||||
print(io, task(op.input[2]))
|
||||
print(io, "->")
|
||||
return print(io, task(op.input[3]))
|
||||
end
|
||||
|
@ -20,6 +20,45 @@ See also: [`revert_operation!`](@ref).
|
||||
"""
|
||||
abstract type AppliedOperation end
|
||||
|
||||
"""
|
||||
NodeFusion <: Operation
|
||||
|
||||
The NodeFusion operation. Represents the fusing of a chain of compute node -> data node -> compute node.
|
||||
|
||||
After the node fusion is applied, the graph has 2 fewer nodes and edges, and a new [`FusedComputeTask`](@ref) with the two input compute nodes as parts.
|
||||
|
||||
# Requirements for successful application
|
||||
|
||||
A chain of (n1, n2, n3) can be fused if:
|
||||
- All nodes are in the graph.
|
||||
- (n1, n2) is an edge in the graph.
|
||||
- (n2, n3) is an edge in the graph.
|
||||
- n2 has exactly one parent (n3) and exactly one child (n1).
|
||||
- n1 has exactly one parent (n2).
|
||||
|
||||
[`is_valid_node_fusion_input`](@ref) can be used to `@assert` these requirements.
|
||||
|
||||
See also: [`can_fuse`](@ref)
|
||||
"""
|
||||
struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
|
||||
Operation
|
||||
input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
|
||||
end
|
||||
|
||||
"""
|
||||
AppliedNodeFusion <: AppliedOperation
|
||||
|
||||
The applied version of the [`NodeFusion`](@ref).
|
||||
"""
|
||||
struct AppliedNodeFusion{
|
||||
TaskType1 <: AbstractComputeTask,
|
||||
TaskType2 <: AbstractDataTask,
|
||||
TaskType3 <: AbstractComputeTask,
|
||||
} <: AppliedOperation
|
||||
operation::NodeFusion{TaskType1, TaskType2, TaskType3}
|
||||
diff::Diff
|
||||
end
|
||||
|
||||
"""
|
||||
NodeReduction <: Operation
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
Return whether `operations` is empty, i.e. all of its fields are empty.
|
||||
"""
|
||||
function isempty(operations::PossibleOperations)
|
||||
return isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
|
||||
return isempty(operations.nodeFusions) && isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -13,7 +13,21 @@ end
|
||||
Return a named tuple with the number of each of the operation types as a named tuple. The fields are named the same as the [`PossibleOperations`](@ref)'.
|
||||
"""
|
||||
function length(operations::PossibleOperations)
|
||||
return (nodeReductions = length(operations.nodeReductions), nodeSplits = length(operations.nodeSplits))
|
||||
return (
|
||||
nodeFusions = length(operations.nodeFusions),
|
||||
nodeReductions = length(operations.nodeReductions),
|
||||
nodeSplits = length(operations.nodeSplits),
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
delete!(operations::PossibleOperations, op::NodeFusion)
|
||||
|
||||
Delete the given node fusion from the possible operations.
|
||||
"""
|
||||
function delete!(operations::PossibleOperations, op::NodeFusion)
|
||||
delete!(operations.nodeFusions, op)
|
||||
return operations
|
||||
end
|
||||
|
||||
"""
|
||||
@ -36,6 +50,24 @@ function delete!(operations::PossibleOperations, op::NodeSplit)
|
||||
return operations
|
||||
end
|
||||
|
||||
"""
|
||||
can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
|
||||
Return whether the given nodes can be fused. See [`NodeFusion`](@ref) for the requirements.
|
||||
"""
|
||||
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
if !is_child(n1, n2) || !is_child(n2, n3)
|
||||
# the checks are redundant but maybe a good sanity check
|
||||
return false
|
||||
end
|
||||
|
||||
if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
can_reduce(n1::Node, n2::Node)
|
||||
|
||||
@ -104,6 +136,23 @@ function ==(op1::Operation, op2::Operation)
|
||||
return false
|
||||
end
|
||||
|
||||
"""
|
||||
==(op1::NodeFusion, op2::NodeFusion)
|
||||
|
||||
Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
|
||||
"""
|
||||
function ==(
|
||||
op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
|
||||
op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
|
||||
) where {
|
||||
ComputeTaskType1 <: AbstractComputeTask,
|
||||
DataTaskType <: AbstractDataTask,
|
||||
ComputeTaskType2 <: AbstractComputeTask,
|
||||
}
|
||||
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
|
||||
return op1.input[2] == op2.input[2]
|
||||
end
|
||||
|
||||
"""
|
||||
==(op1::NodeReduction, op2::NodeReduction)
|
||||
|
||||
|
@ -2,6 +2,43 @@
|
||||
# should be called with @assert
|
||||
# the functions throw their own errors though, to still have helpful error messages
|
||||
|
||||
"""
|
||||
is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
|
||||
Assert for a gven node fusion input whether the nodes can be fused. For the requirements of a node fusion see [`NodeFusion`](@ref).
|
||||
|
||||
Intended for use with `@assert` or `@test`.
|
||||
"""
|
||||
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
|
||||
throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
|
||||
end
|
||||
|
||||
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
if length(n2.parents) > 1
|
||||
throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
|
||||
end
|
||||
if length(n2.children) > 1
|
||||
throw(AssertionError("[Node Fusion] The given data node has more than one child"))
|
||||
end
|
||||
if length(n1.parents) > 1
|
||||
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
|
||||
end
|
||||
|
||||
@assert is_valid(graph, n1)
|
||||
@assert is_valid(graph, n2)
|
||||
@assert is_valid(graph, n3)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
@ -94,3 +131,16 @@ function is_valid(graph::DAG, ns::NodeSplit)
|
||||
#@assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!"
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
is_valid(graph::DAG, nr::NodeFusion)
|
||||
|
||||
Assert for a given [`NodeFusion`](@ref) whether it is a valid operation in the graph.
|
||||
|
||||
Intended for use with `@assert` or `@test`.
|
||||
"""
|
||||
function is_valid(graph::DAG, nf::NodeFusion)
|
||||
@assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
|
||||
#@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
|
||||
return true
|
||||
end
|
||||
|
36
src/optimization/fuse.jl
Normal file
36
src/optimization/fuse.jl
Normal file
@ -0,0 +1,36 @@
|
||||
"""
|
||||
FusionOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeFusion`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeFusion`](@ref)s in the graph.
|
||||
|
||||
See also: [`SplitOptimizer`](@ref), [`ReductionOptimizer`](@ref)
|
||||
"""
|
||||
struct FusionOptimizer <: AbstractOptimizer end
|
||||
|
||||
function optimize_step!(optimizer::FusionOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if fixpoint_reached(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, first(operations.nodeFusions))
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::FusionOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
return isempty(operations.nodeFusions)
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::FusionOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function String(::FusionOptimizer)
|
||||
return "fusion_optimizer"
|
||||
end
|
@ -26,12 +26,16 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG)
|
||||
if rand(r, Bool)
|
||||
# push
|
||||
|
||||
# choose one of split/reduce
|
||||
option = rand(r, 1:2)
|
||||
if option == 1 && !isempty(operations.nodeReductions)
|
||||
# choose one of fuse/split/reduce
|
||||
# TODO refactor fusions so they actually work
|
||||
option = rand(r, 2:3)
|
||||
if option == 1 && !isempty(operations.nodeFusions)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeFusions)))
|
||||
return true
|
||||
elseif option == 2 && !isempty(operations.nodeReductions)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeReductions)))
|
||||
return true
|
||||
elseif option == 2 && !isempty(operations.nodeSplits)
|
||||
elseif option == 3 && !isempty(operations.nodeSplits)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeSplits)))
|
||||
return true
|
||||
end
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
|
||||
|
||||
See also: [`SplitOptimizer`](@ref)
|
||||
See also: [`FusionOptimizer`](@ref), [`SplitOptimizer`](@ref)
|
||||
"""
|
||||
struct ReductionOptimizer <: AbstractOptimizer end
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
An optimizer that simply applies an available [`NodeSplit`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeSplit`](@ref)s in the graph.
|
||||
|
||||
See also: [`ReductionOptimizer`](@ref)
|
||||
See also: [`FusionOptimizer`](@ref), [`ReductionOptimizer`](@ref)
|
||||
"""
|
||||
struct SplitOptimizer <: AbstractOptimizer end
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
A greedy implementation of a scheduler, creating a topological ordering of nodes and naively balancing them onto the different devices.
|
||||
"""
|
||||
struct GreedyScheduler <: AbstractScheduler end
|
||||
struct GreedyScheduler end
|
||||
|
||||
function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
|
||||
nodeQueue = PriorityQueue{Node, Int}()
|
||||
|
@ -1,10 +1,10 @@
|
||||
|
||||
"""
|
||||
AbstractScheduler
|
||||
Scheduler
|
||||
|
||||
Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
|
||||
"""
|
||||
abstract type AbstractScheduler end
|
||||
abstract type Scheduler end
|
||||
|
||||
"""
|
||||
schedule_dag(::Scheduler, ::DAG, ::Machine)
|
||||
|
@ -5,11 +5,10 @@ using StaticArrays
|
||||
|
||||
Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on.
|
||||
"""
|
||||
struct FunctionCall{VectorType <: AbstractVector, N}
|
||||
struct FunctionCall{VectorType <: AbstractVector, M}
|
||||
func::Function
|
||||
# TODO: this should be a tuple
|
||||
value_arguments::SVector{N, Any} # value arguments for the function call, will be prepended to the other arguments
|
||||
arguments::VectorType # symbols of the inputs to the function call
|
||||
arguments::VectorType
|
||||
additional_arguments::SVector{M, Any} # additional arguments (as values) for the function call, will be prepended to the other arguments
|
||||
return_symbol::Symbol
|
||||
device::AbstractDevice
|
||||
end
|
||||
|
@ -1,20 +1,38 @@
|
||||
using StaticArrays
|
||||
|
||||
"""
|
||||
compute(t::FusedComputeTask, data)
|
||||
|
||||
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
|
||||
"""
|
||||
function compute(t::FusedComputeTask, data...)
|
||||
inter = compute(t.first_task)
|
||||
return compute(t.second_task, inter, data2...)
|
||||
end
|
||||
|
||||
"""
|
||||
get_function_call(n::Node)
|
||||
get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
|
||||
|
||||
For a node or a task together with necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task.
|
||||
|
||||
For ordinary compute or data tasks the vector will contain exactly one element.
|
||||
For ordinary compute or data tasks the vector will contain exactly one element, for a [`FusedComputeTask`](@ref) there can be any number of tasks greater 1.
|
||||
"""
|
||||
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
|
||||
# sort out the symbols to the correct tasks
|
||||
return [
|
||||
get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
|
||||
get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
|
||||
]
|
||||
end
|
||||
|
||||
function get_function_call(
|
||||
t::CompTask,
|
||||
device::AbstractDevice,
|
||||
inSymbols::AbstractVector,
|
||||
outSymbol::Symbol,
|
||||
) where {CompTask <: AbstractComputeTask}
|
||||
return [FunctionCall(compute, SVector{1, Any}(t), inSymbols, outSymbol, device)]
|
||||
return [FunctionCall(compute, inSymbols, SVector{1, Any}(t), outSymbol, device)]
|
||||
end
|
||||
|
||||
function get_function_call(node::ComputeTaskNode)
|
||||
@ -46,8 +64,8 @@ function get_function_call(node::DataTaskNode)
|
||||
return [
|
||||
FunctionCall(
|
||||
unpack_identity,
|
||||
SVector{0, Any}(),
|
||||
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
|
||||
SVector{0, Any}(),
|
||||
Symbol(to_var_name(node.id)),
|
||||
first(children(node)).device,
|
||||
),
|
||||
@ -59,8 +77,8 @@ function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
|
||||
|
||||
return FunctionCall(
|
||||
unpack_identity,
|
||||
SVector{0, Any}(),
|
||||
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
|
||||
SVector{0, Any}(),
|
||||
Symbol(to_var_name(node.id)),
|
||||
device,
|
||||
)
|
||||
|
@ -11,3 +11,22 @@ copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks
|
||||
Return a copy of the given compute task.
|
||||
"""
|
||||
copy(t::AbstractComputeTask) = typeof(t)()
|
||||
|
||||
"""
|
||||
copy(t::FusedComputeTask)
|
||||
|
||||
Return a copy of th egiven [`FusedComputeTask`](@ref).
|
||||
"""
|
||||
function copy(t::FusedComputeTask)
|
||||
return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
|
||||
end
|
||||
|
||||
function FusedComputeTask(
|
||||
T1::Type{<:AbstractComputeTask},
|
||||
T2::Type{<:AbstractComputeTask},
|
||||
t1_inputs::Vector{String},
|
||||
t1_output::String,
|
||||
t2_inputs::Vector{String},
|
||||
)
|
||||
return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
|
||||
end
|
||||
|
@ -3,21 +3,28 @@
|
||||
|
||||
Fallback implementation of the compute function of a compute task, throwing an error.
|
||||
"""
|
||||
function compute end
|
||||
function compute(t::AbstractTask, data...)
|
||||
return error("Need to implement compute()")
|
||||
end
|
||||
|
||||
"""
|
||||
compute_effort(t::AbstractTask)
|
||||
|
||||
Fallback implementation of the compute effort of a task, throwing an error.
|
||||
"""
|
||||
function compute_effort end
|
||||
function compute_effort(t::AbstractTask)::Float64
|
||||
# default implementation using compute
|
||||
return error("Need to implement compute_effort()")
|
||||
end
|
||||
|
||||
"""
|
||||
data(t::AbstractTask)
|
||||
|
||||
Fallback implementation of the data of a task, throwing an error.
|
||||
"""
|
||||
function data end
|
||||
function data(t::AbstractTask)::Float64
|
||||
return error("Need to implement data()")
|
||||
end
|
||||
|
||||
"""
|
||||
compute_effort(t::AbstractDataTask)
|
||||
@ -47,9 +54,34 @@ Return the number of children of a data task (always 1).
|
||||
"""
|
||||
children(::DataTask) = 1
|
||||
|
||||
"""
|
||||
children(t::FusedComputeTask)
|
||||
|
||||
Return the number of children of a FusedComputeTask.
|
||||
"""
|
||||
function children(t::FusedComputeTask)
|
||||
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
|
||||
end
|
||||
|
||||
"""
|
||||
data(t::AbstractComputeTask)
|
||||
|
||||
Return the data of a compute task, always zero, regardless of the specific task.
|
||||
"""
|
||||
data(t::AbstractComputeTask)::Float64 = 0.0
|
||||
|
||||
"""
|
||||
compute_effort(t::FusedComputeTask)
|
||||
|
||||
Return the compute effort of a fused compute task.
|
||||
"""
|
||||
function compute_effort(t::FusedComputeTask)::Float64
|
||||
return compute_effort(t.first_task) + compute_effort(t.second_task)
|
||||
end
|
||||
|
||||
"""
|
||||
get_types(::FusedComputeTask{T1, T2})
|
||||
|
||||
Return a tuple of a the fused compute task's components' types.
|
||||
"""
|
||||
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))
|
||||
|
@ -27,3 +27,21 @@ Task representing a specific data transfer.
|
||||
struct DataTask <: AbstractDataTask
|
||||
data::Float64
|
||||
end
|
||||
|
||||
"""
|
||||
FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
|
||||
|
||||
A fused compute task made up of the computation of first `T1` and then `T2`.
|
||||
|
||||
Also see: [`get_types`](@ref).
|
||||
"""
|
||||
struct FusedComputeTask <: AbstractComputeTask
|
||||
first_task::AbstractComputeTask
|
||||
second_task::AbstractComputeTask
|
||||
# the names of the inputs for T1
|
||||
t1_inputs::Vector{Symbol}
|
||||
# output name of T1
|
||||
t1_output::Symbol
|
||||
# t2_inputs doesn't include the output of t1, that's implicit
|
||||
t2_inputs::Vector{Symbol}
|
||||
end
|
||||
|
@ -1,6 +1,3 @@
|
||||
using Roots
|
||||
using ForwardDiff
|
||||
|
||||
"""
|
||||
noop()
|
||||
|
||||
@ -74,6 +71,9 @@ function mem(graph::DAG)
|
||||
size += sizeof(graph.operationsToApply)
|
||||
|
||||
size += sizeof(graph.possibleOperations)
|
||||
for op in graph.possibleOperations.nodeFusions
|
||||
size += mem(op)
|
||||
end
|
||||
for op in graph.possibleOperations.nodeReductions
|
||||
size += mem(op)
|
||||
end
|
||||
@ -249,6 +249,7 @@ end
|
||||
|
||||
first_derivative(func) = x -> ForwardDiff.derivative(func, float(x))
|
||||
|
||||
|
||||
function generate_physical_massive_moms(rng, ss, masses; x0 = 0.1)
|
||||
n = length(masses)
|
||||
massless_moms = generate_physical_massless_moms(rng, ss, n)
|
||||
|
10
test/Project.toml
Normal file
10
test/Project.toml
Normal file
@ -0,0 +1,10 @@
|
||||
[deps]
|
||||
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
|
||||
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
|
||||
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
|
||||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
@ -3,15 +3,48 @@ using Random
|
||||
|
||||
RNG = Random.MersenneTwister(321)
|
||||
|
||||
function test_known_graph(name::String, n)
|
||||
function test_known_graph(name::String, n, fusion_test = true)
|
||||
@testset "Test $name Graph ($n)" begin
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "$name.txt"), ABCModel())
|
||||
props = get_properties(graph)
|
||||
|
||||
if (fusion_test)
|
||||
test_node_fusion(graph)
|
||||
end
|
||||
test_random_walk(RNG, graph, n)
|
||||
end
|
||||
end
|
||||
|
||||
function test_node_fusion(g::DAG)
|
||||
@testset "Test Node Fusion" begin
|
||||
props = get_properties(g)
|
||||
|
||||
options = get_operations(g)
|
||||
|
||||
nodes_number = length(g.nodes)
|
||||
data = props.data
|
||||
compute_effort = props.computeEffort
|
||||
|
||||
while !isempty(options.nodeFusions)
|
||||
fusion = first(options.nodeFusions)
|
||||
|
||||
@test typeof(fusion) <: NodeFusion
|
||||
|
||||
push_operation!(g, fusion)
|
||||
|
||||
props = get_properties(g)
|
||||
@test props.data < data
|
||||
@test props.computeEffort == compute_effort
|
||||
|
||||
nodes_number = length(g.nodes)
|
||||
data = props.data
|
||||
compute_effort = props.computeEffort
|
||||
|
||||
options = get_operations(g)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function test_random_walk(RNG, g::DAG, n::Int64)
|
||||
@testset "Test Random Walk ($n)" begin
|
||||
# the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
|
||||
@ -27,11 +60,13 @@ function test_random_walk(RNG, g::DAG, n::Int64)
|
||||
# push
|
||||
opt = get_operations(g)
|
||||
|
||||
# choose one of split/reduce
|
||||
option = rand(RNG, 1:2)
|
||||
if option == 1 && !isempty(opt.nodeReductions)
|
||||
# choose one of fuse/split/reduce
|
||||
option = rand(RNG, 1:3)
|
||||
if option == 1 && !isempty(opt.nodeFusions)
|
||||
push_operation!(g, rand(RNG, collect(opt.nodeFusions)))
|
||||
elseif option == 2 && !isempty(opt.nodeReductions)
|
||||
push_operation!(g, rand(RNG, collect(opt.nodeReductions)))
|
||||
elseif option == 2 && !isempty(opt.nodeSplits)
|
||||
elseif option == 3 && !isempty(opt.nodeSplits)
|
||||
push_operation!(g, rand(RNG, collect(opt.nodeSplits)))
|
||||
else
|
||||
i = i - 1
|
||||
@ -56,4 +91,4 @@ end
|
||||
|
||||
test_known_graph("AB->AB", 10000)
|
||||
test_known_graph("AB->ABBB", 10000)
|
||||
test_known_graph("AB->ABBBBB", 1000)
|
||||
test_known_graph("AB->ABBBBB", 1000, false)
|
||||
|
@ -61,14 +61,17 @@ insert_edge!(graph, CD, C1C, track = false)
|
||||
|
||||
opt = get_operations(graph)
|
||||
|
||||
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
|
||||
|
||||
#println("Initial State:\n", opt)
|
||||
|
||||
nr = first(opt.nodeReductions)
|
||||
@test Set(nr.input) == Set([B1C_1, B1C_2])
|
||||
push_operation!(graph, nr)
|
||||
opt = get_operations(graph)
|
||||
|
||||
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
|
||||
#println("After 1 Node Reduction:\n", opt)
|
||||
|
||||
nr = first(opt.nodeReductions)
|
||||
@test Set(nr.input) == Set([B1D_1, B1D_2])
|
||||
@ -77,16 +80,19 @@ opt = get_operations(graph)
|
||||
|
||||
@test is_valid(graph)
|
||||
|
||||
@test length(opt) == (nodeReductions = 0, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 4, nodeReductions = 0, nodeSplits = 1)
|
||||
#println("After 2 Node Reductions:\n", opt)
|
||||
|
||||
pop_operation!(graph)
|
||||
|
||||
opt = get_operations(graph)
|
||||
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
|
||||
#println("After reverting the second Node Reduction:\n", opt)
|
||||
|
||||
reset_graph!(graph)
|
||||
|
||||
opt = get_operations(graph)
|
||||
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
|
||||
@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
|
||||
#println("After reverting to the initial state:\n", opt)
|
||||
|
||||
@test is_valid(graph)
|
||||
|
@ -1,5 +1,5 @@
|
||||
using SafeTestsets
|
||||
#=
|
||||
|
||||
@safetestset "Utility Unit Tests " begin
|
||||
include("unit_tests_utility.jl")
|
||||
end
|
||||
@ -30,7 +30,6 @@ end
|
||||
@safetestset "Graph Unit Tests " begin
|
||||
include("unit_tests_graph.jl")
|
||||
end
|
||||
=#
|
||||
@safetestset "Execution Unit Tests " begin
|
||||
include("unit_tests_execution.jl")
|
||||
end
|
||||
|
@ -1,5 +1,5 @@
|
||||
using MetagraphOptimization
|
||||
using QEDcore
|
||||
using QEDbase
|
||||
|
||||
import MetagraphOptimization.interaction_result
|
||||
|
||||
|
@ -1,5 +1,16 @@
|
||||
using MetagraphOptimization
|
||||
|
||||
function test_op_specific(estimator, graph, nf::NodeFusion)
|
||||
estimate = operation_effect(estimator, graph, nf)
|
||||
data_reduce = data(nf.input[2].task)
|
||||
|
||||
@test isapprox(estimate.data, -data_reduce)
|
||||
@test isapprox(estimate.computeEffort, 0; atol = eps(Float64))
|
||||
@test isapprox(estimate.computeIntensity, 0; atol = eps(Float64))
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function test_op_specific(estimator, graph, nr::NodeReduction)
|
||||
estimate = operation_effect(estimator, graph, nr)
|
||||
|
||||
@ -63,9 +74,13 @@ end
|
||||
|
||||
@testset "Operation Cost" begin
|
||||
ops = get_operations(graph)
|
||||
nfs = copy(ops.nodeFusions)
|
||||
nrs = copy(ops.nodeReductions)
|
||||
nss = copy(ops.nodeSplits)
|
||||
|
||||
for nf in nfs
|
||||
test_op(estimator, graph, nf)
|
||||
end
|
||||
for nr in nrs
|
||||
test_op(estimator, graph, nr)
|
||||
end
|
||||
|
@ -1,5 +1,5 @@
|
||||
using MetagraphOptimization
|
||||
using QEDcore
|
||||
using QEDbase
|
||||
using AccurateArithmetic
|
||||
using Random
|
||||
using UUIDs
|
||||
@ -63,14 +63,8 @@ machine = Machine(
|
||||
)
|
||||
|
||||
process_2_2 = ABCProcessDescription(
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Incoming, ParticleA, SFourMomentum} => 1,
|
||||
ParticleStateful{Incoming, ParticleB, SFourMomentum} => 1,
|
||||
),
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Outgoing, ParticleA, SFourMomentum} => 1,
|
||||
ParticleStateful{Outgoing, ParticleB, SFourMomentum} => 1,
|
||||
),
|
||||
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
|
||||
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
|
||||
)
|
||||
|
||||
particles_2_2 = ABCProcessInput(
|
||||
@ -112,14 +106,8 @@ end
|
||||
end
|
||||
|
||||
process_2_4 = ABCProcessDescription(
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Incoming, ParticleA, SFourMomentum} => 1,
|
||||
ParticleStateful{Incoming, ParticleB, SFourMomentum} => 1,
|
||||
),
|
||||
Dict{Type, Int64}(
|
||||
ParticleStateful{Outgoing, ParticleA, SFourMomentum} => 1,
|
||||
ParticleStateful{Outgoing, ParticleB, SFourMomentum} => 3,
|
||||
),
|
||||
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
|
||||
Dict{Type, Int64}(ParticleA => 1, ParticleB => 3),
|
||||
)
|
||||
particles_2_4 = gen_process_input(process_2_4)
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
|
||||
@ -148,6 +136,105 @@ TODO: fix precision(?) issues
|
||||
end
|
||||
=#
|
||||
|
||||
@testset "AB->AB large sum fusion" begin
|
||||
for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@testset "AB->AB large sum fusion" begin
|
||||
for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push a fusion with the sum node
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
# push two more fusions with the fused node
|
||||
for _ in 1:15
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[3].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "AB->AB fusion edge case" begin
|
||||
for _ in 1:20
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
|
||||
# push two fusions with ComputeTaskABC_V
|
||||
for _ in 1:2
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, ComputeTaskABC_V)
|
||||
push_operation!(graph, fusion)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# push fusions until the end
|
||||
cont = true
|
||||
while cont
|
||||
cont = false
|
||||
ops = get_operations(graph)
|
||||
for fusion in ops.nodeFusions
|
||||
if isa(fusion.input[1].task, FusedComputeTask)
|
||||
push_operation!(graph, fusion)
|
||||
cont = true
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "$(process) after random walk" for process in ["ke->ke", "ke->kke", "ke->kkke"]
|
||||
process = parse_process("ke->kkke", QEDModel())
|
||||
inputs = [gen_process_input(process) for _ in 1:100]
|
||||
|
@ -13,7 +13,7 @@ graph = MetagraphOptimization.DAG()
|
||||
@test length(graph.operationsToApply) == 0
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
|
||||
@test length(get_operations(graph)) == (nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(get_operations(graph)) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
|
||||
|
||||
# s to output (exit node)
|
||||
d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
|
||||
@ -133,10 +133,13 @@ insert_edge!(graph, s0, d_exit, track = false)
|
||||
@test length(siblings(s0)) == 1
|
||||
|
||||
operations = get_operations(graph)
|
||||
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
|
||||
@test sum(length(operations)) == 10
|
||||
|
||||
@test operations == get_operations(graph)
|
||||
nf = first(operations.nodeFusions)
|
||||
|
||||
properties = get_properties(graph)
|
||||
@test properties.computeEffort == 28
|
||||
@ -145,19 +148,54 @@ properties = get_properties(graph)
|
||||
@test properties.noNodes == 26
|
||||
@test properties.noEdges == 25
|
||||
|
||||
push_operation!(graph, nf)
|
||||
# **does not immediately apply the operation**
|
||||
|
||||
@test length(graph.nodes) == 26
|
||||
@test length(graph.appliedOperations) == 0
|
||||
@test length(graph.operationsToApply) == 1
|
||||
@test first(graph.operationsToApply) == nf
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
|
||||
|
||||
# this applies pending operations
|
||||
properties = get_properties(graph)
|
||||
|
||||
@test length(graph.nodes) == 24
|
||||
@test length(graph.appliedOperations) == 1
|
||||
@test length(graph.operationsToApply) == 0
|
||||
@test length(graph.dirtyNodes) != 0
|
||||
@test properties.noNodes == 24
|
||||
@test properties.noEdges == 23
|
||||
@test properties.computeEffort == 28
|
||||
@test properties.data < 62
|
||||
@test properties.computeIntensity > 28 / 62
|
||||
|
||||
operations = get_operations(graph)
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
|
||||
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(operations) == (nodeFusions = 9, nodeReductions = 0, nodeSplits = 0)
|
||||
@test !isempty(operations)
|
||||
|
||||
possibleNF = 9
|
||||
while !isempty(operations.nodeFusions)
|
||||
push_operation!(graph, first(operations.nodeFusions))
|
||||
global operations = get_operations(graph)
|
||||
global possibleNF = possibleNF - 1
|
||||
@test length(operations) == (nodeFusions = possibleNF, nodeReductions = 0, nodeSplits = 0)
|
||||
end
|
||||
|
||||
@test isempty(operations)
|
||||
|
||||
@test length(operations) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
@test length(graph.nodes) == 26
|
||||
@test length(graph.appliedOperations) == 0
|
||||
@test length(graph.nodes) == 6
|
||||
@test length(graph.appliedOperations) == 10
|
||||
@test length(graph.operationsToApply) == 0
|
||||
|
||||
reset_graph!(graph)
|
||||
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
@test length(graph.dirtyNodes) == 26
|
||||
@test length(graph.nodes) == 26
|
||||
@test length(graph.appliedOperations) == 0
|
||||
@test length(graph.operationsToApply) == 0
|
||||
@ -170,6 +208,6 @@ properties = get_properties(graph)
|
||||
@test properties.computeIntensity ≈ 28 / 62
|
||||
|
||||
operations = get_operations(graph)
|
||||
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
|
||||
|
||||
@test is_valid(graph)
|
||||
|
@ -6,7 +6,8 @@ RNG = Random.MersenneTwister(0)
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
|
||||
|
||||
# create the optimizers
|
||||
FIXPOINT_OPTIMIZERS = [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer()]
|
||||
FIXPOINT_OPTIMIZERS =
|
||||
[GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer(), FusionOptimizer()]
|
||||
NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
|
||||
|
||||
@testset "Optimizer $optimizer" for optimizer in vcat(NO_FIXPOINT_OPTIMIZERS, FIXPOINT_OPTIMIZERS)
|
||||
|
@ -1,8 +1,7 @@
|
||||
using MetagraphOptimization
|
||||
|
||||
using QEDcore
|
||||
|
||||
import MetagraphOptimization.gen_diagrams
|
||||
import MetagraphOptimization.isincoming
|
||||
import MetagraphOptimization.types
|
||||
|
||||
|
||||
@ -10,12 +9,25 @@ model = QEDModel()
|
||||
compton = ("Compton Scattering", parse_process("ke->ke", model), 2)
|
||||
compton_3 = ("3-Photon Compton Scattering", parse_process("kkke->ke", QEDModel()), 24)
|
||||
compton_4 = ("4-Photon Compton Scattering", parse_process("kkkke->ke", QEDModel()), 120)
|
||||
bhabha = ("Bhabha Scattering", parse_process("ep->ep", model), 2)
|
||||
moller = ("Møller Scattering", parse_process("ee->ee", model), 2)
|
||||
pair_production = ("Pair production", parse_process("kk->ep", model), 2)
|
||||
pair_annihilation = ("Pair annihilation", parse_process("ep->kk", model), 2)
|
||||
trident = ("Trident", parse_process("ke->epe", model), 8)
|
||||
|
||||
@testset "Known Processes" begin
|
||||
@testset "$name" for (name, process, n) in [compton, compton_3, compton_4]
|
||||
@testset "$name" for (name, process, n) in
|
||||
[compton, bhabha, moller, pair_production, pair_annihilation, trident, compton_3, compton_4]
|
||||
initial_diagram = FeynmanDiagram(process)
|
||||
n_particles = number_incoming_particles(process) + number_outgoing_particles(process)
|
||||
|
||||
n_particles = 0
|
||||
for type in types(model)
|
||||
if (isincoming(type))
|
||||
n_particles += get(process.inParticles, type, 0)
|
||||
else
|
||||
n_particles += get(process.outParticles, type, 0)
|
||||
end
|
||||
end
|
||||
@test n_particles == length(initial_diagram.particles)
|
||||
@test ismissing(initial_diagram.tie[])
|
||||
@test isempty(initial_diagram.vertices)
|
||||
|
@ -1,6 +1,5 @@
|
||||
using MetagraphOptimization
|
||||
using QEDbase
|
||||
using QEDcore
|
||||
using QEDprocesses
|
||||
using StatsBase # for countmap
|
||||
using Random
|
||||
@ -10,6 +9,8 @@ import MetagraphOptimization.caninteract
|
||||
import MetagraphOptimization.issame
|
||||
import MetagraphOptimization.interaction_result
|
||||
import MetagraphOptimization.propagation_result
|
||||
import MetagraphOptimization.direction
|
||||
import MetagraphOptimization.spin_or_pol
|
||||
import MetagraphOptimization.QED_vertex
|
||||
|
||||
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
|
||||
@ -17,32 +18,32 @@ def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
|
||||
RNG = Random.MersenneTwister(0)
|
||||
|
||||
testparticleTypes = [
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum},
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum},
|
||||
ParticleStateful{Incoming, Positron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Positron, SFourMomentum},
|
||||
PhotonStateful{Incoming, PolX},
|
||||
PhotonStateful{Outgoing, PolX},
|
||||
FermionStateful{Incoming, SpinUp},
|
||||
FermionStateful{Outgoing, SpinUp},
|
||||
AntiFermionStateful{Incoming, SpinUp},
|
||||
AntiFermionStateful{Outgoing, SpinUp},
|
||||
]
|
||||
|
||||
testparticleTypesPropagated = [
|
||||
ParticleStateful{Outgoing, Photon, SFourMomentum},
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Electron, SFourMomentum},
|
||||
ParticleStateful{Incoming, Electron, SFourMomentum},
|
||||
ParticleStateful{Outgoing, Positron, SFourMomentum},
|
||||
ParticleStateful{Incoming, Positron, SFourMomentum},
|
||||
PhotonStateful{Outgoing, PolX},
|
||||
PhotonStateful{Incoming, PolX},
|
||||
FermionStateful{Outgoing, SpinUp},
|
||||
FermionStateful{Incoming, SpinUp},
|
||||
AntiFermionStateful{Outgoing, SpinUp},
|
||||
AntiFermionStateful{Incoming, SpinUp},
|
||||
]
|
||||
|
||||
function compton_groundtruth(input::PhaseSpacePoint)
|
||||
function compton_groundtruth(input::QEDProcessInput)
|
||||
# p1k1 -> p2k2
|
||||
# formula: −(ie)^2 (u(p2) slashed(ε1) S(p2 − k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
|
||||
|
||||
p1 = momentum(psp, Incoming(), 2)
|
||||
p2 = momentum(psp, Outgoing(), 2)
|
||||
p1 = input.inFerms[1]
|
||||
p2 = input.outFerms[1]
|
||||
|
||||
k1 = momentum(psp, Incoming(), 1)
|
||||
k2 = momentum(psp, Outgoing(), 1)
|
||||
k1 = input.inPhotons[1]
|
||||
k2 = input.outPhotons[1]
|
||||
|
||||
u_p1 = base_state(Electron(), Incoming(), p1.momentum, spin_or_pol(p1))
|
||||
u_p2 = base_state(Electron(), Outgoing(), p2.momentum, spin_or_pol(p2))
|
||||
@ -56,8 +57,8 @@ function compton_groundtruth(input::PhaseSpacePoint)
|
||||
virt2_mom = p1.momentum + k1.momentum
|
||||
@test isapprox(p2.momentum + k2.momentum, virt2_mom)
|
||||
|
||||
s_p2_k1 = QEDbase.propagator(Electron(), virt1_mom)
|
||||
s_p1_k1 = QEDbase.propagator(Electron(), virt2_mom)
|
||||
s_p2_k1 = propagator(Electron(), virt1_mom)
|
||||
s_p1_k1 = propagator(Electron(), virt2_mom)
|
||||
|
||||
diagram1 = u_p2 * (eps_1 * QED_vertex()) * s_p2_k1 * (eps_2 * QED_vertex()) * u_p1
|
||||
diagram2 = u_p2 * (eps_2 * QED_vertex()) * s_p1_k1 * (eps_1 * QED_vertex()) * u_p1
|
||||
@ -65,6 +66,7 @@ function compton_groundtruth(input::PhaseSpacePoint)
|
||||
return diagram1 + diagram2
|
||||
end
|
||||
|
||||
|
||||
@testset "Interaction Result" begin
|
||||
import MetagraphOptimization.QED_conserve_momentum
|
||||
|
||||
@ -86,8 +88,8 @@ end
|
||||
@test issame(typeof(resultParticle), interaction_result(p1, p2))
|
||||
|
||||
totalMom = zero(SFourMomentum)
|
||||
for (p, mom) in [(p1, momentum(testParticle1)), (p2, momentum(testParticle2)), (p3, momentum(resultParticle))]
|
||||
if (typeof(particle_direction(p)) <: Incoming)
|
||||
for (p, mom) in [(p1, testParticle1.momentum), (p2, testParticle2.momentum), (p3, resultParticle.momentum)]
|
||||
if (typeof(direction(p)) <: Incoming)
|
||||
totalMom += mom
|
||||
else
|
||||
totalMom -= mom
|
||||
@ -101,31 +103,54 @@ end
|
||||
@testset "Propagation Result" begin
|
||||
for (p, propResult) in zip(testparticleTypes, testparticleTypesPropagated)
|
||||
@test issame(propagation_result(p), propResult)
|
||||
@test particle_direction(propagation_result(p)(def_momentum)) != particle_direction(p(def_momentum))
|
||||
@test direction(propagation_result(p)(def_momentum)) != direction(p(def_momentum))
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Parse Process" begin
|
||||
@testset "Order invariance" begin
|
||||
@test parse_process("ke->ke", QEDModel()) == parse_process("ek->ke", QEDModel())
|
||||
@test parse_process("ke->ke", QEDModel()) == parse_process("ek->ek", QEDModel())
|
||||
@test parse_process("ke->ke", QEDModel()) == parse_process("ke->ek", QEDModel())
|
||||
|
||||
@test parse_process("kkke->eep", QEDModel()) == parse_process("kkek->epe", QEDModel())
|
||||
end
|
||||
|
||||
@testset "Known processes" begin
|
||||
proc = parse_process("ke->ke", QEDModel())
|
||||
@test incoming_particles(proc) == (Photon(), Electron())
|
||||
@test outgoing_particles(proc) == (Photon(), Electron())
|
||||
compton_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, FermionStateful{Incoming, SpinUp} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 1, FermionStateful{Outgoing, SpinUp} => 1),
|
||||
)
|
||||
|
||||
proc = parse_process("kp->kp", QEDModel())
|
||||
@test incoming_particles(proc) == (Photon(), Positron())
|
||||
@test outgoing_particles(proc) == (Photon(), Positron())
|
||||
@test parse_process("ke->ke", QEDModel()) == compton_process
|
||||
|
||||
proc = parse_process("ke->eep", QEDModel())
|
||||
@test incoming_particles(proc) == (Photon(), Electron())
|
||||
@test outgoing_particles(proc) == (Electron(), Electron(), Positron())
|
||||
positron_compton_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, AntiFermionStateful{Incoming, SpinUp} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 1, AntiFermionStateful{Outgoing, SpinUp} => 1),
|
||||
)
|
||||
|
||||
proc = parse_process("kk->pe", QEDModel())
|
||||
@test incoming_particles(proc) == (Photon(), Photon())
|
||||
@test outgoing_particles(proc) == (Positron(), Electron())
|
||||
@test parse_process("kp->kp", QEDModel()) == positron_compton_process
|
||||
|
||||
proc = parse_process("pe->kk", QEDModel())
|
||||
@test incoming_particles(proc) == (Positron(), Electron())
|
||||
@test outgoing_particles(proc) == (Photon(), Photon())
|
||||
trident_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, FermionStateful{Incoming, SpinUp} => 1),
|
||||
Dict{Type, Int}(FermionStateful{Outgoing, SpinUp} => 2, AntiFermionStateful{Outgoing, SpinUp} => 1),
|
||||
)
|
||||
|
||||
@test parse_process("ke->eep", QEDModel()) == trident_process
|
||||
|
||||
pair_production_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 2),
|
||||
Dict{Type, Int}(FermionStateful{Outgoing, SpinUp} => 1, AntiFermionStateful{Outgoing, SpinUp} => 1),
|
||||
)
|
||||
|
||||
@test parse_process("kk->pe", QEDModel()) == pair_production_process
|
||||
|
||||
pair_annihilation_process = QEDProcessDescription(
|
||||
Dict{Type, Int}(FermionStateful{Incoming, SpinUp} => 1, AntiFermionStateful{Incoming, SpinUp} => 1),
|
||||
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 2),
|
||||
)
|
||||
|
||||
@test parse_process("pe->kk", QEDModel()) == pair_annihilation_process
|
||||
end
|
||||
end
|
||||
|
||||
@ -136,12 +161,30 @@ end
|
||||
|
||||
for i in 1:100
|
||||
input = gen_process_input(process)
|
||||
@test isapprox(sum(momenta(input, Incoming())), sum(momenta(input, Outgoing())); atol = sqrt(eps()))
|
||||
@test length(input.inFerms) == get(process.inParticles, FermionStateful{Incoming, SpinUp}, 0)
|
||||
@test length(input.inAntiferms) == get(process.inParticles, AntiFermionStateful{Incoming, SpinUp}, 0)
|
||||
@test length(input.inPhotons) == get(process.inParticles, PhotonStateful{Incoming, PolX}, 0)
|
||||
@test length(input.outFerms) == get(process.outParticles, FermionStateful{Outgoing, SpinUp}, 0)
|
||||
@test length(input.outAntiferms) == get(process.outParticles, AntiFermionStateful{Outgoing, SpinUp}, 0)
|
||||
@test length(input.outPhotons) == get(process.outParticles, PhotonStateful{Outgoing, PolX}, 0)
|
||||
|
||||
@test isapprox(
|
||||
sum([
|
||||
getfield.(input.inFerms, :momentum)...,
|
||||
getfield.(input.inAntiferms, :momentum)...,
|
||||
getfield.(input.inPhotons, :momentum)...,
|
||||
]),
|
||||
sum([
|
||||
getfield.(input.outFerms, :momentum)...,
|
||||
getfield.(input.outAntiferms, :momentum)...,
|
||||
getfield.(input.outPhotons, :momentum)...,
|
||||
]);
|
||||
atol = sqrt(eps()),
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
#=
|
||||
@testset "Compton" begin
|
||||
import MetagraphOptimization.insert_node!
|
||||
import MetagraphOptimization.insert_edge!
|
||||
@ -168,97 +211,97 @@ end
|
||||
graph = DAG()
|
||||
|
||||
# s to output (exit node)
|
||||
d_exit = insert_node!(graph, make_node(DataTask(16)); track=false)
|
||||
d_exit = insert_node!(graph, make_node(DataTask(16)), track = false)
|
||||
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(2)); track=false)
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(2)), track = false)
|
||||
|
||||
d_s0_sum = insert_node!(graph, make_node(DataTask(16)); track=false)
|
||||
d_s1_sum = insert_node!(graph, make_node(DataTask(16)); track=false)
|
||||
d_s0_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
|
||||
d_s1_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
|
||||
|
||||
# final s compute
|
||||
s0 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track=false)
|
||||
s1 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track=false)
|
||||
s0 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
|
||||
s1 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
|
||||
|
||||
# data from v0 and v1 to s0
|
||||
d_v0_s0 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_v1_s0 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_v2_s1 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_v3_s1 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_v0_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_v1_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_v2_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_v3_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
|
||||
# v0 and v1 compute
|
||||
v0 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
|
||||
v1 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
|
||||
v2 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
|
||||
v3 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
|
||||
v0 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
|
||||
v1 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
|
||||
v2 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
|
||||
v3 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
|
||||
|
||||
# data from uPhIn, uPhOut, uElIn, uElOut to v0 and v1
|
||||
d_uPhIn_v0 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_uElIn_v0 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_uPhOut_v1 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_uElOut_v1 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_uPhIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uElIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uPhOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uElOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
|
||||
# data from uPhIn, uPhOut, uElIn, uElOut to v2 and v3
|
||||
d_uPhOut_v2 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_uElIn_v2 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_uPhIn_v3 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_uElOut_v3 = insert_node!(graph, make_node(DataTask(96)); track=false)
|
||||
d_uPhOut_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uElIn_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uPhIn_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
d_uElOut_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
|
||||
|
||||
# uPhIn, uPhOut, uElIn and uElOut computes
|
||||
uPhIn = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
|
||||
uPhOut = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
|
||||
uElIn = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
|
||||
uElOut = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
|
||||
uPhIn = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
|
||||
uPhOut = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
|
||||
uElIn = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
|
||||
uElOut = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
|
||||
|
||||
# data into U
|
||||
d_uPhIn = insert_node!(graph, make_node(DataTask(16), "ki1"); track=false)
|
||||
d_uPhOut = insert_node!(graph, make_node(DataTask(16), "ko1"); track=false)
|
||||
d_uElIn = insert_node!(graph, make_node(DataTask(16), "ei1"); track=false)
|
||||
d_uElOut = insert_node!(graph, make_node(DataTask(16), "eo1"); track=false)
|
||||
d_uPhIn = insert_node!(graph, make_node(DataTask(16), "ki1"), track = false)
|
||||
d_uPhOut = insert_node!(graph, make_node(DataTask(16), "ko1"), track = false)
|
||||
d_uElIn = insert_node!(graph, make_node(DataTask(16), "ei1"), track = false)
|
||||
d_uElOut = insert_node!(graph, make_node(DataTask(16), "eo1"), track = false)
|
||||
|
||||
# now for all the edges
|
||||
insert_edge!(graph, d_uPhIn, uPhIn; track=false)
|
||||
insert_edge!(graph, d_uPhOut, uPhOut; track=false)
|
||||
insert_edge!(graph, d_uElIn, uElIn; track=false)
|
||||
insert_edge!(graph, d_uElOut, uElOut; track=false)
|
||||
insert_edge!(graph, d_uPhIn, uPhIn, track = false)
|
||||
insert_edge!(graph, d_uPhOut, uPhOut, track = false)
|
||||
insert_edge!(graph, d_uElIn, uElIn, track = false)
|
||||
insert_edge!(graph, d_uElOut, uElOut, track = false)
|
||||
|
||||
insert_edge!(graph, uPhIn, d_uPhIn_v0; track=false)
|
||||
insert_edge!(graph, uPhOut, d_uPhOut_v1; track=false)
|
||||
insert_edge!(graph, uElIn, d_uElIn_v0; track=false)
|
||||
insert_edge!(graph, uElOut, d_uElOut_v1; track=false)
|
||||
insert_edge!(graph, uPhIn, d_uPhIn_v0, track = false)
|
||||
insert_edge!(graph, uPhOut, d_uPhOut_v1, track = false)
|
||||
insert_edge!(graph, uElIn, d_uElIn_v0, track = false)
|
||||
insert_edge!(graph, uElOut, d_uElOut_v1, track = false)
|
||||
|
||||
insert_edge!(graph, uPhIn, d_uPhIn_v3; track=false)
|
||||
insert_edge!(graph, uPhOut, d_uPhOut_v2; track=false)
|
||||
insert_edge!(graph, uElIn, d_uElIn_v2; track=false)
|
||||
insert_edge!(graph, uElOut, d_uElOut_v3; track=false)
|
||||
insert_edge!(graph, uPhIn, d_uPhIn_v3, track = false)
|
||||
insert_edge!(graph, uPhOut, d_uPhOut_v2, track = false)
|
||||
insert_edge!(graph, uElIn, d_uElIn_v2, track = false)
|
||||
insert_edge!(graph, uElOut, d_uElOut_v3, track = false)
|
||||
|
||||
insert_edge!(graph, d_uPhIn_v0, v0; track=false)
|
||||
insert_edge!(graph, d_uPhOut_v1, v1; track=false)
|
||||
insert_edge!(graph, d_uElIn_v0, v0; track=false)
|
||||
insert_edge!(graph, d_uElOut_v1, v1; track=false)
|
||||
insert_edge!(graph, d_uPhIn_v0, v0, track = false)
|
||||
insert_edge!(graph, d_uPhOut_v1, v1, track = false)
|
||||
insert_edge!(graph, d_uElIn_v0, v0, track = false)
|
||||
insert_edge!(graph, d_uElOut_v1, v1, track = false)
|
||||
|
||||
insert_edge!(graph, d_uPhIn_v3, v3; track=false)
|
||||
insert_edge!(graph, d_uPhOut_v2, v2; track=false)
|
||||
insert_edge!(graph, d_uElIn_v2, v2; track=false)
|
||||
insert_edge!(graph, d_uElOut_v3, v3; track=false)
|
||||
insert_edge!(graph, d_uPhIn_v3, v3, track = false)
|
||||
insert_edge!(graph, d_uPhOut_v2, v2, track = false)
|
||||
insert_edge!(graph, d_uElIn_v2, v2, track = false)
|
||||
insert_edge!(graph, d_uElOut_v3, v3, track = false)
|
||||
|
||||
insert_edge!(graph, v0, d_v0_s0; track=false)
|
||||
insert_edge!(graph, v1, d_v1_s0; track=false)
|
||||
insert_edge!(graph, v2, d_v2_s1; track=false)
|
||||
insert_edge!(graph, v3, d_v3_s1; track=false)
|
||||
insert_edge!(graph, v0, d_v0_s0, track = false)
|
||||
insert_edge!(graph, v1, d_v1_s0, track = false)
|
||||
insert_edge!(graph, v2, d_v2_s1, track = false)
|
||||
insert_edge!(graph, v3, d_v3_s1, track = false)
|
||||
|
||||
insert_edge!(graph, d_v0_s0, s0; track=false)
|
||||
insert_edge!(graph, d_v1_s0, s0; track=false)
|
||||
insert_edge!(graph, d_v0_s0, s0, track = false)
|
||||
insert_edge!(graph, d_v1_s0, s0, track = false)
|
||||
|
||||
insert_edge!(graph, d_v2_s1, s1; track=false)
|
||||
insert_edge!(graph, d_v3_s1, s1; track=false)
|
||||
insert_edge!(graph, d_v2_s1, s1, track = false)
|
||||
insert_edge!(graph, d_v3_s1, s1, track = false)
|
||||
|
||||
insert_edge!(graph, s0, d_s0_sum; track=false)
|
||||
insert_edge!(graph, s1, d_s1_sum; track=false)
|
||||
insert_edge!(graph, s0, d_s0_sum, track = false)
|
||||
insert_edge!(graph, s1, d_s1_sum, track = false)
|
||||
|
||||
insert_edge!(graph, d_s0_sum, sum_node; track=false)
|
||||
insert_edge!(graph, d_s1_sum, sum_node; track=false)
|
||||
insert_edge!(graph, d_s0_sum, sum_node, track = false)
|
||||
insert_edge!(graph, d_s1_sum, sum_node, track = false)
|
||||
|
||||
insert_edge!(graph, sum_node, d_exit; track=false)
|
||||
insert_edge!(graph, sum_node, d_exit, track = false)
|
||||
|
||||
input = [gen_process_input(process) for _ in 1:1000]
|
||||
|
||||
@ -271,12 +314,9 @@ end
|
||||
@test isapprox(compton_function.(input), compton_groundtruth.(input))
|
||||
end
|
||||
|
||||
@testset "Equal results after optimization" for optimizer in [
|
||||
ReductionOptimizer(), RandomWalkOptimizer(MersenneTwister(0))
|
||||
]
|
||||
@testset "Process $proc_str" for proc_str in [
|
||||
"ke->ke", "kp->kp", "kk->ep", "ep->kk", "ke->kke", "ke->kkke"
|
||||
]
|
||||
@testset "Equal results after optimization" for optimizer in
|
||||
[ReductionOptimizer(), RandomWalkOptimizer(MersenneTwister(0))]
|
||||
@testset "Process $proc_str" for proc_str in ["ke->ke", "kp->kp", "kk->ep", "ep->kk", "ke->kke", "ke->kkke"]
|
||||
model = QEDModel()
|
||||
process = parse_process(proc_str, model)
|
||||
machine = Machine(
|
||||
@ -307,4 +347,3 @@ end
|
||||
@test isapprox(compute_function.(input), reduced_compute_function.(input))
|
||||
end
|
||||
end
|
||||
=#
|
Loading…
x
Reference in New Issue
Block a user