Compare commits

...

13 Commits

Author SHA1 Message Date
994d4d7cee
Start adapting ABCModel implementation to new interfaces
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 7m6s
MetagraphOptimization_CI / docs (push) Failing after 7m15s
2024-08-19 18:37:55 +02:00
0ce98e29ef
Reenable tests and fix a lot
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 6m58s
MetagraphOptimization_CI / docs (push) Failing after 7m6s
2024-08-19 17:45:40 +02:00
d553fe8ffc
Reenable ABC Model
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 6m46s
MetagraphOptimization_CI / docs (push) Successful in 7m37s
2024-08-19 14:54:34 +02:00
e9bd1f2939
Still remove NodeFusion
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 7m23s
MetagraphOptimization_CI / docs (push) Successful in 7m57s
2024-08-19 14:02:46 +02:00
97ccb3f3fb
Remove occurrences of Fusion/Fuse
Some checks failed
MetagraphOptimization_CI / docs (push) Failing after 6m56s
MetagraphOptimization_CI / test (push) Failing after 7m25s
2024-08-13 17:57:16 +02:00
8a5e49429b Don't eval in generated function return
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 6m58s
MetagraphOptimization_CI / docs (push) Successful in 7m27s
2024-08-09 12:21:49 +02:00
5be7ca99e7 Add results and evaluation
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 1m33s
MetagraphOptimization_CI / docs (push) Failing after 1m33s
2024-07-10 14:21:26 +02:00
1ae39a8caa Congruent in photons example (#12)
Some checks failed
MetagraphOptimization_CI / docs (push) Failing after 1m50s
MetagraphOptimization_CI / test (push) Failing after 1m53s
Now targeting the correct branches

Co-authored-by: Rubydragon <anton.reinhard@proton.me>
Reviewed-on: #12
2024-07-10 14:17:39 +02:00
b5d92b729c Get compute function working
Some checks failed
MetagraphOptimization_CI / docs (push) Failing after 1m32s
MetagraphOptimization_CI / test (push) Failing after 1m49s
2024-07-04 15:31:22 +02:00
6a9a7b41f1 rework a lot of the QED model to use QEDcore/base/processes 2024-07-03 20:24:53 +02:00
a1581182ca WIP 2024-07-02 10:50:30 +02:00
1b4ba285c3 WIP refactor
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 7m59s
MetagraphOptimization_CI / docs (push) Failing after 8m6s
2024-06-24 23:31:30 +02:00
2921882fd4 EOD
Some checks failed
MetagraphOptimization_CI / docs (push) Failing after 8m32s
MetagraphOptimization_CI / test (push) Failing after 11m24s
2024-05-24 19:20:59 +02:00
80 changed files with 1386 additions and 2114 deletions

4
.gitattributes vendored
View File

@ -1,3 +1,5 @@
input/AB->ABBBBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
input/AB->ABBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs
*.gif filter=lfs diff=lfs merge=lfs
*.jld2 filter=lfs diff=lfs merge=lfs

View File

@ -14,16 +14,22 @@ JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
NumaAllocators = "21436f30-1b4a-4f08-87af-e26101bb5379"
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
QEDcore = "35dc0263-cb5f-4c33-a114-1d7f54ab753e"
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
QEDcore = "35dc0263-cb5f-4c33-a114-1d7f54ab753e"
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
test = ["Test"]
test = ["SafeTestsets", "Test", "QEDbase", "QEDcore", "QEDprocesses"]

View File

@ -5,7 +5,7 @@
## Package Features
- Read a DAG from a file
- Analyze its properties
- Mutate the graph using the operations NodeFusion, NodeReduction and NodeSplit
- Mutate the graph using the operations NodeReduction and NodeSplit
## Coming Soon:
- Add Code Generation from finished DAG

194
examples/congruent_in_ph.jl Normal file
View File

@ -0,0 +1,194 @@
ENV["UCX_ERROR_SIGNALS"] = "SIGILL,SIGBUS,SIGFPE"
using MetagraphOptimization
using QEDbase
using QEDcore
using QEDprocesses
using Random
using UUIDs
using CUDA
using NamedDims
using CSV
using JLD2
using FlexiMaps
RNG = Random.MersenneTwister(123)
function mock_machine()
return Machine(
[
MetagraphOptimization.NumaNode(
0,
1,
MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),
-1.0,
UUIDs.uuid1(),
),
],
[-1.0;;],
)
end
function congruent_input_momenta(processDescription::GenericQEDProcess, omega::Number)
# generate an input sample for given e + nk -> e' + k' process, where the nk are equal
inputMasses = Vector{Float64}()
for particle in incoming_particles(processDescription)
push!(inputMasses, mass(particle))
end
outputMasses = Vector{Float64}()
for particle in outgoing_particles(processDescription)
push!(outputMasses, mass(particle))
end
initial_momenta = [
i == length(inputMasses) ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
i in 1:length(inputMasses)
]
ss = sqrt(sum(initial_momenta) * sum(initial_momenta))
final_momenta = MetagraphOptimization.generate_physical_massive_moms(RNG, ss, outputMasses)
return (tuple(initial_momenta...), tuple(final_momenta...))
end
# theta ∈ [0, 2π] and phi ∈ [0, 2π]
function congruent_input_momenta_scenario_2(
processDescription::GenericQEDProcess,
omega::Number,
theta::Number,
phi::Number,
)
# -------------
# same as above
# generate an input sample for given e + nk -> e' + k' process, where the nk are equal
inputMasses = Vector{Float64}()
for particle in incoming_particles(processDescription)
push!(inputMasses, mass(particle))
end
outputMasses = Vector{Float64}()
for particle in outgoing_particles(processDescription)
push!(outputMasses, mass(particle))
end
initial_momenta = [
i == length(inputMasses) ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
i in 1:length(inputMasses)
]
ss = sqrt(sum(initial_momenta) * sum(initial_momenta))
# up to here
# ----------
# now calculate the final_momenta from omega, cos_theta and phi
n = number_particles(processDescription, Incoming(), Photon())
cos_theta = cos(theta)
omega_prime = (n * omega) / (1 + n * omega * (1 - cos_theta))
k_prime =
omega_prime * SFourMomentum(1, sqrt(1 - cos_theta^2) * cos(phi), sqrt(1 - cos_theta^2) * sin(phi), cos_theta)
p_prime = sum(initial_momenta) - k_prime
final_momenta = (k_prime, p_prime)
return (tuple(initial_momenta...), tuple(final_momenta...))
end
function build_psp(processDescription::GenericQEDProcess, momenta)
return PhaseSpacePoint(
processDescription,
PerturbativeQED(),
PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()),
momenta[1],
momenta[2],
)
end
# hack to fix stacksize for threading
with_stacksize(f, n) = fetch(schedule(Task(f, n)))
# scenario 2
N = 1024 # thetas
M = 1024 # phis
K = 64 # omegas
thetas = collect(LinRange(0, 2π, N))
phis = collect(LinRange(0, 2π, M))
omegas = collect(maprange(log, 2e-2, 2e-7, K))
for photons in 1:5
# temp process to generate momenta
println("Generating $(K*N*M) inputs for $photons photons (Scenario 2 grid walk)...")
temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
input_momenta =
Array{typeof(congruent_input_momenta_scenario_2(temp_process, omegas[1], thetas[1], phis[1]))}(undef, (K, N, M))
Threads.@threads for k in 1:K
Threads.@threads for i in 1:N
Threads.@threads for j in 1:M
input_momenta[k, i, j] = congruent_input_momenta_scenario_2(temp_process, omegas[k], thetas[i], phis[j])
end
end
end
cu_results = CuArray{Float64}(undef, size(input_momenta))
fill!(cu_results, 0.0)
i = 1
for (in_pol, in_spin, out_pol, out_spin) in
Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
print(
"[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
)
process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
inputs = Array{typeof(build_psp(process, input_momenta[1, 1, 1]))}(undef, (K, N, M))
#println("input_momenta: $input_momenta")
Threads.@threads for k in 1:K
Threads.@threads for i in 1:N
Threads.@threads for j in 1:M
inputs[k, i, j] = build_psp(process, input_momenta[k, i, j])
end
end
end
cu_inputs = CuArray(inputs)
print("Preparing graph... ")
graph = gen_graph(process)
optimize_to_fixpoint!(ReductionOptimizer(), graph)
print("Preparing function... ")
kernel! = get_cuda_kernel(graph, process, mock_machine())
#func = get_compute_function(graph, process, mock_machine())
print("Calculating... ")
ts = 32
bs = Int64(length(cu_inputs) / 32)
outputs = CuArray{ComplexF64}(undef, size(cu_inputs))
@cuda threads = ts blocks = bs always_inline = true kernel!(cu_inputs, outputs, length(cu_inputs))
CUDA.device_synchronize()
cu_results += abs2.(outputs)
println("Done.")
i += 1
end
println("Writing results")
out_ph_moms = getindex.(getindex.(input_momenta, 2), 1)
out_el_moms = getindex.(getindex.(input_momenta, 2), 2)
results = NamedDimsArray{(:omegas, :thetas, :phis)}(Array(cu_results))
println("Named results array: $(typeof(results))")
@save "$(photons)_congruent_photons_grid.jld2" omegas thetas phis results
end
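For reference, the closing kinematics computed in `congruent_input_momenta_scenario_2` above amount to an n-photon Compton relation, assuming the electron starts at rest with energy 1 (`SFourMomentum(1, 0, 0, 0)`) and the n photons are collinear along z with energy ω each:

ω' = n·ω / (1 + n·ω·(1 − cos θ))
k' = ω' · (1, √(1 − cos²θ)·cos φ, √(1 − cos²θ)·sin φ, cos θ)
p' = (sum of initial momenta) − k'   (the outgoing electron momentum follows from momentum conservation)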

View File

@ -1,60 +0,0 @@
using MetagraphOptimization
using Plots
using Random
function gen_plot(filepath)
name = basename(filepath)
name, _ = splitext(name)
filepath = joinpath(@__DIR__, "../input/", filepath)
if !isfile(filepath)
println("File ", filepath, " does not exist, skipping")
return
end
g = parse_dag(filepath, ABCModel())
Random.seed!(1)
println("Random Walking... ")
x = Vector{Float64}()
y = Vector{Float64}()
for i in 1:30
print("\r", i)
# push
opt = get_operations(g)
# choose one of fuse/split/reduce
option = rand(1:3)
if option == 1 && !isempty(opt.nodeFusions)
push_operation!(g, rand(collect(opt.nodeFusions)))
println("NF")
elseif option == 2 && !isempty(opt.nodeReductions)
push_operation!(g, rand(collect(opt.nodeReductions)))
println("NR")
elseif option == 3 && !isempty(opt.nodeSplits)
push_operation!(g, rand(collect(opt.nodeSplits)))
println("NS")
else
i = i - 1
end
props = get_properties(g)
push!(x, props.data)
push!(y, props.computeEffort)
end
println("\rDone.")
plot([x[1], x[2]], [y[1], y[2]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
# Create lines connecting the reference point to each data point
for i in 3:length(x)
plot!([x[i - 1], x[i]], [y[i - 1], y[i]], linestyle = :solid, linewidth = 1, color = :red)
end
return gui()
end
gen_plot("AB->ABBB.txt")

View File

@ -1,96 +0,0 @@
using MetagraphOptimization
using Plots
using Random
function gen_plot(filepath)
name = basename(filepath)
name, _ = splitext(name)
filepath = joinpath(@__DIR__, "../input/", filepath)
if !isfile(filepath)
println("File ", filepath, " does not exist, skipping")
return
end
g = parse_dag(filepath, ABCModel())
Random.seed!(1)
println("Random Walking... ")
for i in 1:30
print("\r", i)
# push
opt = get_operations(g)
# choose one of fuse/split/reduce
option = rand(1:3)
if option == 1 && !isempty(opt.nodeFusions)
push_operation!(g, rand(collect(opt.nodeFusions)))
println("NF")
elseif option == 2 && !isempty(opt.nodeReductions)
push_operation!(g, rand(collect(opt.nodeReductions)))
println("NR")
elseif option == 3 && !isempty(opt.nodeSplits)
push_operation!(g, rand(collect(opt.nodeSplits)))
println("NS")
else
i = i - 1
end
end
println("\rDone.")
props = get_properties(g)
x0 = props.data
y0 = props.computeEffort
x = Vector{Float64}()
y = Vector{Float64}()
names = Vector{String}()
opt = get_operations(g)
for op in opt.nodeFusions
push_operation!(g, op)
props = get_properties(g)
push!(x, props.data)
push!(y, props.computeEffort)
pop_operation!(g)
push!(names, "NF: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
end
for op in opt.nodeReductions
push_operation!(g, op)
props = get_properties(g)
push!(x, props.data)
push!(y, props.computeEffort)
pop_operation!(g)
push!(names, "NR: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
end
for op in opt.nodeSplits
push_operation!(g, op)
props = get_properties(g)
push!(x, props.data)
push!(y, props.computeEffort)
pop_operation!(g)
push!(names, "NS: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
end
plot([x0, x[1]], [y0, y[1]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
# Create lines connecting the reference point to each data point
for i in 2:length(x)
plot!([x0, x[i]], [y0, y[i]], linestyle = :solid, linewidth = 1, color = :red)
end
#scatter!(x, y, label=names)
print(names)
return gui()
end
gen_plot("AB->ABBB.txt")

BIN
images/contour_plot_congruent_in_photons.gif (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -1,5 +1,5 @@
# Optimizer Plots
Plots of FusionOptimizer, ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.
Plots of FusionOptimizer (deprecated), ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.
Benchmarked using `notebooks/optimizers.ipynb`.

View File

@ -413,7 +413,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.10.2",
"display_name": "Julia 1.10.4",
"language": "julia",
"name": "julia-1.10"
},
@ -421,7 +421,7 @@
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.10.2"
"version": "1.10.4"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@ -54,8 +54,6 @@
"cu_inputs = CuArray(inputs)\n",
"optimizer = RandomWalkOptimizer(MersenneTwister(0))# GreedyOptimizer(GlobalMetricEstimator())\n",
"\n",
"#done: split, reduce, fuse, greedy\n",
"\n",
"process_str_short = \"qed_k3\"\n",
"optim_str = \"Random Walk Optimization\"\n",
"optim_str_short=\"random\"\n",

BIN
results/1_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
results/2_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
results/3_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
results/4_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
results/5_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -6,6 +6,8 @@ A module containing tools to work on DAGs.
module MetagraphOptimization
using QEDbase
using QEDcore
using QEDprocesses
# graph types
export DAG
@ -17,7 +19,6 @@ export AbstractTask
export AbstractComputeTask
export AbstractDataTask
export DataTask
export FusedComputeTask
export PossibleOperations
export GraphProperties
@ -42,7 +43,6 @@ export is_valid, is_scheduled
# graph operation related
export Operation
export AppliedOperation
export NodeFusion
export NodeReduction
export NodeSplit
export push_operation!
@ -64,8 +64,7 @@ export ComputeTaskABC_Sum
# QED model
export FeynmanDiagram, FeynmanVertex, FeynmanTie, FeynmanParticle
export PhotonStateful, FermionStateful, AntiFermionStateful
export QEDParticle, QEDProcessDescription, QEDProcessInput, QEDModel
export GenericQEDProcess, QEDModel
export ComputeTaskQED_P
export ComputeTaskQED_S1
export ComputeTaskQED_S2
@ -87,7 +86,7 @@ export GlobalMetricEstimator, CDCost
# optimization
export AbstractOptimizer, GreedyOptimizer, RandomWalkOptimizer
export ReductionOptimizer, SplitOptimizer, FusionOptimizer
export ReductionOptimizer, SplitOptimizer
export optimize_step!, optimize!
export fixpoint_reached, optimize_to_fixpoint!
@ -99,9 +98,6 @@ export ==, in, show, isempty, delete!, length
export bytes_to_human_readable
# TODO: this is probably not good
import QEDprocesses.compute
import Base.length
import Base.show
import Base.==
@ -114,7 +110,6 @@ import Base.delete!
import Base.insert!
import Base.collect
include("devices/interface.jl")
include("task/type.jl")
include("node/type.jl")
@ -167,28 +162,30 @@ include("optimization/interface.jl")
include("optimization/greedy.jl")
include("optimization/random_walk.jl")
include("optimization/reduce.jl")
include("optimization/fuse.jl")
include("optimization/split.jl")
include("models/interface.jl")
include("models/print.jl")
include("models/abc/types.jl")
include("models/abc/particle.jl")
include("models/abc/compute.jl")
include("models/abc/create.jl")
include("models/abc/properties.jl")
include("models/abc/parse.jl")
include("models/abc/print.jl")
include("models/physics_models/interface.jl")
include("models/qed/types.jl")
include("models/qed/particle.jl")
include("models/qed/diagrams.jl")
include("models/qed/compute.jl")
include("models/qed/create.jl")
include("models/qed/properties.jl")
include("models/qed/parse.jl")
include("models/qed/print.jl")
include("models/physics_models/abc/types.jl")
include("models/physics_models/abc/particle.jl")
include("models/physics_models/abc/compute.jl")
include("models/physics_models/abc/create.jl")
include("models/physics_models/abc/properties.jl")
include("models/physics_models/abc/parse.jl")
include("models/physics_models/abc/print.jl")
include("models/physics_models/qed/utility.jl")
include("models/physics_models/qed/types.jl")
include("models/physics_models/qed/particle.jl")
include("models/physics_models/qed/diagrams.jl")
include("models/physics_models/qed/compute.jl")
include("models/physics_models/qed/create.jl")
include("models/physics_models/qed/properties.jl")
include("models/physics_models/qed/parse.jl")
include("models/physics_models/qed/print.jl")
include("devices/measure.jl")
include("devices/detect.jl")
@ -197,7 +194,7 @@ include("devices/impl.jl")
include("devices/numa/impl.jl")
include("devices/cuda/impl.jl")
include("devices/rocm/impl.jl")
include("devices/oneapi/impl.jl")
#include("devices/oneapi/impl.jl")
include("scheduler/interface.jl")
include("scheduler/greedy.jl")

View File

@ -1,72 +1,70 @@
"""
get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
get_compute_function(graph::DAG, instance, machine::Machine)
Return a function of signature `compute_<id>(input::AbstractProcessInput)`, which will return the result of the DAG computation on the given input.
Return a function of signature `compute_<id>(input::input_type(instance))`, which will return the result of the DAG computation on the given input.
"""
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
tape = gen_tape(graph, process, machine)
function get_compute_function(graph::DAG, instance, machine::Machine)
tape = gen_tape(graph, instance, machine)
initCaches = Expr(:block, tape.initCachesCode...)
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
assignInputs = Expr(:block, tape.inputAssignCode...)
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(data_input::AbstractProcessInput) $(initCaches); $(assignInputs); $code; return $resSym; end",
"function compute_$(functionId)(data_input::$(input_type(instance))) $(initCaches); $(assignInputs); $code; return $resSym; end",
)
func = eval(expr)
return func
return expr
end
"""
get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
get_cuda_kernel(graph::DAG, instance, machine::Machine)
Return a function of signature `compute_<id>(input::CuVector, output::CuVector, n::Int64)`, which will write the result of the DAG computation for each input element into the given output vector.
"""
function get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
tape = gen_tape(graph, process, machine)
function get_cuda_kernel(graph::DAG, instance, machine::Machine)
tape = gen_tape(graph, instance, machine)
initCaches = Expr(:block, tape.initCachesCode...)
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
assignInputs = Expr(:block, tape.inputAssignCode...)
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
expr = Meta.parse("function compute_$(functionId)(input_vector, output_vector, n::Int64)
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
if (id > n)
return
end
@inline data_input = input_vector[id]
$(initCaches)
$(assignInputs)
$code
@inline output_vector[id] = $resSym
return nothing
end")
expr = Meta.parse(
"function compute_$(functionId)(input_vector, output_vector, n::Int64)
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
if (id > n)
return
end
@inline data_input = input_vector[id]
$(initCaches)
$(assignInputs)
$code
@inline output_vector[id] = $resSym
return nothing
end"
)
func = eval(expr)
return func
return expr
end
"""
execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
execute(graph::DAG, instance, machine::Machine, input)
Execute the code of the given `graph` on the given input particles.
Execute the code of the given `graph` on the given input values.
This is essentially shorthand for
```julia
tape = gen_tape(graph, process, machine)
tape = gen_tape(graph, instance, machine)
return execute_tape(tape, input)
```
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
tape = gen_tape(graph, process, machine)
function execute(graph::DAG, instance, machine::Machine, input)
tape = gen_tape(graph, instance, machine)
return execute_tape(tape, input)
end
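A minimal usage sketch of the instance-based API from this hunk. It assumes the `mock_machine` helper defined in `examples/congruent_in_ph.jl` above; the exact call signatures are taken from the docstrings here and may differ in detail from the final package API.

```julia
# Sketch only: assumes the definitions in this diff and the helpers from
# examples/congruent_in_ph.jl (mock_machine).
using MetagraphOptimization
using QEDbase, QEDcore, QEDprocesses

process = parse_process("ke->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
graph   = gen_graph(process)                        # DAG for the scattering process
optimize_to_fixpoint!(ReductionOptimizer(), graph)  # optional, as in the example above
machine = mock_machine()                            # stand-in Machine from the example
input   = gen_process_input(process)                # random momentum-conserving input

# One-shot execution: generates the tape for (graph, instance, machine) and runs it.
result = execute(graph, process, machine, input)

# Alternatively, build the function expression once and reuse it; note that
# get_compute_function now returns the Expr, so the caller evaluates it.
compute_func = eval(get_compute_function(graph, process, machine))
result2 = Base.invokelatest(compute_func, input)    # invokelatest avoids world-age issues in scripts
```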

View File

@ -1,10 +1,11 @@
# TODO: do this with macros
function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]])
cache[fc.return_symbol] = fc.func(fc.value_arguments[1], cache[fc.arguments[1]])
return nothing
end
@ -14,12 +15,12 @@ function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {
end
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{2}}
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]], cache[fc.arguments[2]])
cache[fc.return_symbol] = fc.func(fc.value_arguments[1], cache[fc.arguments[1]], cache[fc.arguments[2]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT}
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], getindex.(Ref(cache), fc.arguments)...)
cache[fc.return_symbol] = fc.func(fc.value_arguments[1], getindex.(Ref(cache), fc.arguments)...)
return nothing
end
@ -31,7 +32,7 @@ Execute the given [`FunctionCall`](@ref) on the dictionary.
Several more specialized versions of this function exist to reduce vector unrolling work for common cases.
"""
function call_fc(fc::FunctionCall{VectorT, M}, cache::Dict{Symbol, Any}) where {VectorT, M}
cache[fc.return_symbol] = fc.func(fc.additional_arguments..., getindex.(Ref(cache), fc.arguments)...)
cache[fc.return_symbol] = fc.func(fc.value_arguments..., getindex.(Ref(cache), fc.arguments)...)
return nothing
end
@ -47,12 +48,8 @@ end
For a given function call, return an expression evaluating it.
"""
function expr_from_fc(fc::FunctionCall{VectorT, M}) where {VectorT, M}
func_call = Expr(
:call,
Symbol(fc.func),
fc.additional_arguments...,
eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...,
)
func_call =
Expr(:call, Symbol(fc.func), fc.value_arguments..., eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...)
expr = :($(eval(gen_access_expr(fc.device, fc.return_symbol))) = $func_call)
return expr
@ -73,51 +70,32 @@ function gen_cache_init_code(machine::Machine)
return initializeCaches
end
"""
part_from_x(type::Type, index::Int, x::AbstractProcessInput)
Return the [`ParticleValue`](@ref) of the given type of particle with the given `index` from the given process input.
Function is wrapped into a [`FunctionCall`](@ref) in [`gen_input_assignment_code`](@ref).
"""
part_from_x(type::Type, index::Int, x::AbstractProcessInput) =
ParticleValue{type, ComplexF64}(get_particle(x, type, index), one(ComplexF64))
"""
gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
processDescription::AbstractProcessDescription,
instance::AbstractProblemInstance,
machine::Machine,
processInputSymbol::Symbol = :input,
problemInputSymbol::Symbol = :data_input,
)
Return a `Vector{Expr}` doing the input assignments from the given `processInputSymbol` onto the `inputSymbols`.
Return a `Vector{Expr}` doing the input assignments from the given `problemInputSymbol` onto the `inputSymbols`.
"""
function gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
processDescription::AbstractProcessDescription,
instance,
machine::Machine,
processInputSymbol::Symbol = :input,
problemInputSymbol::Symbol = :data_input,
)
@assert length(inputSymbols) >=
sum(values(in_particles(processDescription))) + sum(values(out_particles(processDescription))) "Number of input Symbols is smaller than the number of particles in the process description"
assignInputs = Vector{FunctionCall}()
assignInputs = Vector{Expr}()
for (name, symbols) in inputSymbols
(type, index) = type_index_from_name(model(processDescription), name)
# make a function for this, since we can't use anonymous functions in the FunctionCall
for symbol in symbols
device = entry_device(machine)
push!(
assignInputs,
FunctionCall(
# x is the process input
part_from_x,
SVector{1, Symbol}(processInputSymbol),
SVector{2, Any}(type, index),
symbol,
device,
Meta.parse(
"$(eval(gen_access_expr(device, symbol))) = $(input_expr(instance, name, problemInputSymbol))",
),
)
end
@ -127,14 +105,14 @@ function gen_input_assignment_code(
end
"""
gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine, scheduler::AbstractScheduler = GreedyScheduler())
Generate the code for a given graph. The return value is a [`Tape`](@ref).
See also: [`execute`](@ref), [`execute_tape`](@ref)
"""
function gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
schedule = schedule_dag(GreedyScheduler(), graph, machine)
function gen_tape(graph::DAG, instance, machine::Machine, scheduler::AbstractScheduler = GreedyScheduler())
schedule = schedule_dag(scheduler, graph, machine)
# get inSymbols
inputSyms = Dict{String, Vector{Symbol}}()
@ -150,23 +128,24 @@ function gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Mach
outSym = Symbol(to_var_name(get_exit_node(graph).id))
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSyms, process, machine, :input)
assignInputs = gen_input_assignment_code(inputSyms, instance, machine, :data_input)
return Tape(initCaches, assignInputs, schedule, inputSyms, outSym, Dict(), process, machine)
return Tape{input_type(instance)}(initCaches, assignInputs, schedule, inputSyms, outSym, Dict(), instance, machine)
end
"""
execute_tape(tape::Tape, input::AbstractProcessInput)
execute_tape(tape::Tape, input::Input) where {Input}
Execute the given tape with the given input.
For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.
"""
function execute_tape(tape::Tape, input::AbstractProcessInput)
function execute_tape(tape::Tape, input)
cache = Dict{Symbol, Any}()
cache[:input] = input
cache[:data_input] = input
# simply execute all the code snippets here
# TODO: `@assert` that process input fits the tape.process
@assert typeof(input) == input_type(tape.instance)
# TODO: `@assert` that input fits the tape.instance
for expr in tape.initCachesCode
@eval $expr
end
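A hypothetical miniature of the generic `call_fc` path above, for orientation only; the names are illustrative, and a real `FunctionCall` additionally carries a device and `SVector`-typed argument lists.

```julia
# Illustrative stand-ins for the fields of a FunctionCall: value arguments are
# passed first, then the cached values of the argument symbols, and the result
# is stored under the return symbol.
cache = Dict{Symbol, Any}(:x1 => 1.5, :x2 => 2.5)

value_arguments = (10.0,)      # stands in for fc.value_arguments
arguments       = (:x1, :x2)   # stands in for fc.arguments
return_symbol   = :y1          # stands in for fc.return_symbol
f(a, x, y)      = a + x * y    # stands in for fc.func

cache[return_symbol] = f(value_arguments..., getindex.(Ref(cache), arguments)...)
# cache[:y1] == 10.0 + 1.5 * 2.5 == 13.75
```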

View File

@ -1,19 +1,21 @@
"""
Tape
Tape{INPUT}
TODO: update docs
- `INPUT` the input type of the problem instance
- `code::Vector{Expr}`: The julia expression containing the code for the whole graph.
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
- `outputSymbol::Symbol`: The symbol of the final calculated value
"""
struct Tape
struct Tape{INPUT}
initCachesCode::Vector{Expr}
inputAssignCode::Vector{FunctionCall}
inputAssignCode::Vector{Expr}
computeCode::Vector{FunctionCall}
inputSymbols::Dict{String, Vector{Symbol}}
outputSymbol::Symbol
cache::Dict{Symbol, Any}
process::AbstractProcessDescription
instance::Any
machine::Machine
end

View File

@ -10,8 +10,7 @@ Representation of a [`DAG`](@ref)'s cost as estimated by the [`GlobalMetricEstim
!!! note
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
For example, for node fusions this will always be 0, since the computeEffort is zero.
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
It will still work as intended when adding/subtracting to/from a `graph_cost` estimate.
"""
const CDCost = NamedTuple{(:data, :computeEffort, :computeIntensity), Tuple{Float64, Float64, Float64}}
@ -55,10 +54,6 @@ function graph_cost(estimator::GlobalMetricEstimator, graph::DAG)
)::CDCost
end
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeFusion)
return (data = -data(operation.input[2].task), computeEffort = 0.0, computeIntensity = 0.0)::CDCost
end
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeReduction)
s = length(operation.input) - 1
return (

View File

@ -169,66 +169,6 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
return nothing
end
function replace_children!(task::FusedComputeTask, before, after)
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
replace!(task.t1_inputs, before => after)
replace!(task.t2_inputs, before => after)
# recursively descend down the tree, but only in the tasks where we're replacing things
if replacedIn1 > 0
replace_children!(task.first_task, before, after)
end
if replacedIn2 > 0
replace_children!(task.second_task, before, after)
end
return nothing
end
function replace_children!(task::AbstractTask, before, after)
return nothing
end
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
# only need to update fused compute tasks
if !(typeof(task(n)) <: FusedComputeTask)
return nothing
end
taskBefore = copy(task(n))
#=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
println("------------------ Nothing to replace!! ------------------")
child_ids = Vector{String}()
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end=#
replace_children!(task(n), child_before, child_after)
#=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
println("------------------ Did not replace anything!! ------------------")
child_ids = Vector{String}()
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end=#
# keep track
if (track)
push!(graph.diff.updatedChildren, (n, taskBefore))
end
end
"""
get_snapshot_diff(graph::DAG)
@ -240,31 +180,6 @@ function get_snapshot_diff(graph::DAG)
return swapfield!(graph, :diff, Diff())
end
"""
invalidate_caches!(graph::DAG, operation::NodeFusion)
Invalidate the operation caches for a given [`NodeFusion`](@ref).
This deletes the operation from the graph's possible operations and from the involved nodes' own operation caches.
"""
function invalidate_caches!(graph::DAG, operation::NodeFusion)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
for n in [1, 3]
for i in eachindex(operation.input[n].nodeFusions)
if operation == operation.input[n].nodeFusions[i]
splice!(operation.input[n].nodeFusions, i)
break
end
end
end
operation.input[2].nodeFusion = missing
return nothing
end
"""
invalidate_caches!(graph::DAG, operation::NodeReduction)
@ -311,9 +226,6 @@ function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
if !ismissing(node.nodeSplit)
invalidate_caches!(graph, node.nodeSplit)
end
while !isempty(node.nodeFusions)
invalidate_caches!(graph, pop!(node.nodeFusions))
end
return nothing
end
@ -329,8 +241,5 @@ function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
if !ismissing(node.nodeSplit)
invalidate_caches!(graph, node.nodeSplit)
end
if !ismissing(node.nodeFusion)
invalidate_caches!(graph, node.nodeFusion)
end
return nothing
end

View File

@ -7,7 +7,6 @@ A struct storing all possible operations on a [`DAG`](@ref).
To get the [`PossibleOperations`](@ref) on a [`DAG`](@ref), use [`get_operations`](@ref).
"""
mutable struct PossibleOperations
nodeFusions::Set{NodeFusion}
nodeReductions::Set{NodeReduction}
nodeSplits::Set{NodeSplit}
end
@ -52,7 +51,7 @@ end
Construct and return an empty [`PossibleOperations`](@ref) object.
"""
function PossibleOperations()
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
return PossibleOperations(Set{NodeReduction}(), Set{NodeSplit}())
end
"""

View File

@ -40,19 +40,11 @@ function is_valid(graph::DAG)
for ns in graph.possibleOperations.nodeSplits
@assert is_valid(graph, ns)
end
for nf in graph.possibleOperations.nodeFusions
@assert is_valid(graph, nf)
end
for node in graph.dirtyNodes
@assert node in graph "Dirty Node is not part of the graph!"
@assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!"
@assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!"
if (typeof(node) <: DataTaskNode)
@assert ismissing(node.nodeFusion) "Dirty DataTaskNode has a Node Fusion!"
elseif (typeof(node) <: ComputeTaskNode)
@assert isempty(node.nodeFusions) "Dirty ComputeTaskNode has Node Fusions!"
end
end
@assert is_connected(graph) "Graph is not connected!"

View File

@ -1,120 +1,46 @@
import QEDbase.mass
import QEDbase.AbstractParticle
"""
AbstractPhysicsModel
AbstractModel
Base type for a model, e.g. ABC-Model or QED. This is used to dispatch many functions.
Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed.
See also: [`problem_instance`](@ref)
"""
abstract type AbstractPhysicsModel end
abstract type AbstractModel end
"""
ParticleValue{ParticleType <: AbstractParticle}
problem_instance(::AbstractModel, ::Vararg)
A struct describing a particle during a calculation of a Feynman Diagram, together with the value that's being calculated. `AbstractParticle` is the type from the QEDbase package.
`sizeof(ParticleValue())` = 48 Byte
Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters.
"""
struct ParticleValue{ParticleType <: AbstractParticle, ValueType}
p::ParticleType
v::ValueType
end
function problem_instance end
"""
AbstractProcessDescription
AbstractProblemInstance
Base type for process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a process in that model.
Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model.
See also: [`parse_process`](@ref)
"""
abstract type AbstractProcessDescription end
abstract type AbstractProblemInstance end
"""
AbstractProcessInput
input_type(problem::AbstractProblemInstance)
Base type for process inputs. An object of this type contains the input values (e.g. momenta) of the particles in a process.
See also: [`gen_process_input`](@ref)
Return the fully specified input type for a specific [`AbstractProblemInstance`](@ref).
"""
abstract type AbstractProcessInput end
function input_type end
"""
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: AbstractParticle, T2 <: AbstractParticle}
graph(::AbstractProblemInstance)
Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the result particle type when the two given particles interact.
Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) of the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names.
"""
function interaction_result end
function graph end
"""
types(::AbstractPhysicsModel)
input_expr(instance::AbstractProblemInstance, name::String, input_symbol::Symbol)
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref), returning a `Vector` of the available particle types in the model.
For the given [`AbstractProblemInstance`](@ref), the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol.
"""
function types end
"""
in_particles(::AbstractProcessDescription)
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of incoming particles for the process per particle type.
in_particles(::AbstractProcessInput)
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
Returns a `<: Vector{AbstractParticle}` object with the values of all incoming particles for the corresponding `ProcessDescription`.
"""
function in_particles end
"""
out_particles(::AbstractProcessDescription)
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of outgoing particles for the process per particle type.
out_particles(::AbstractProcessInput)
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
Returns a `<: Vector{AbstractParticle}` object with the values of all outgoing particles for the corresponding `ProcessDescription`.
"""
function out_particles end
"""
get_particle(::AbstractProcessInput, t::Type, n::Int)
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
Returns the `n`th particle of type `t`.
"""
function get_particle end
"""
parse_process(::AbstractString, ::AbstractPhysicsModel)
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref).
Returns a `ProcessDescription` object.
"""
function parse_process end
"""
gen_process_input(::AbstractProcessDescription)
Interface function that must be implemented for every specific [`AbstractProcessDescription`](@ref).
Returns a randomly generated and valid corresponding `ProcessInput`.
"""
function gen_process_input end
"""
model(::AbstractProcessDescription)
model(::AbstractProcessInput)
Return the model of this process description or input.
"""
function model end
"""
type_index_from_name(model::AbstractModel, name::String)
For a name of a particle in the given [`AbstractModel`](@ref), return the particle's [`Type`] and index as a tuple. The input string can be expected to be of the form \"<name><index>\".
"""
function type_index_from_name end
function input_expr end
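To make the slimmed-down interface concrete, here is a hypothetical toy implementation; `MySumInstance` and its details are illustrative only and not part of the package, and only the interface functions documented above are extended.

```julia
# Hypothetical toy problem instance for the new interface (illustrative only).
using MetagraphOptimization

struct MySumInstance <: MetagraphOptimization.AbstractProblemInstance
    n::Int   # number of entry nodes, named "x1" ... "xn"
end

# The fully specified input type the generated compute function will accept.
MetagraphOptimization.input_type(::MySumInstance) = Vector{Float64}

# For an entry node named e.g. "x3", return an Expr reading that value from the
# problem input symbol (e.g. :data_input).
function MetagraphOptimization.input_expr(::MySumInstance, name::String, input_symbol::Symbol)
    index = parse(Int, name[2:end])
    return Meta.parse("$(input_symbol)[$(index)]")
end

# MetagraphOptimization.graph(::MySumInstance) would build and return the DAG whose
# entry nodes carry the names "x1" ... "xn"; its construction is omitted in this sketch.
```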

View File

@ -0,0 +1,3 @@
## Deprecation Warning
These models are deprecated and should not be used anymore. They will be dropped entirely soon.

View File

@ -1,6 +1,17 @@
using AccurateArithmetic
using StaticArrays
function input_expr(instance::ABCProcessDescription, name::String, psp_symbol::Symbol)
(type, index) = type_index_from_name(ABCModel(), name)
return Meta.parse(
"ABCParticleValue(
$type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))),
0.0im,
)",
)
end
"""
compute(::ComputeTaskABC_P, data::ABCParticleValue)
@ -8,7 +19,7 @@ Return the particle and value as is.
0 FLOP.
"""
function compute(::ComputeTaskABC_P, data::ABCParticleValue{P})::ABCParticleValue{P} where {P <: ABCParticle}
function compute(::ComputeTaskABC_P, data::ABCParticleValue{P})::ABCParticleValue{P} where {P}
return data
end
@ -19,7 +30,7 @@ Compute an outer edge. Return the particle value with the same particle and the
1 FLOP.
"""
function compute(::ComputeTaskABC_U, data::ABCParticleValue{P})::ABCParticleValue{P} where {P <: ABCParticle}
function compute(::ComputeTaskABC_U, data::ABCParticleValue{P})::ABCParticleValue{P} where {P}
return ABCParticleValue{P}(data.p, data.v * ABC_outer_edge(data.p))
end
@ -34,7 +45,7 @@ function compute(
::ComputeTaskABC_V,
data1::ABCParticleValue{P1},
data2::ABCParticleValue{P2},
)::ABCParticleValue where {P1 <: ABCParticle, P2 <: ABCParticle}
)::ABCParticleValue where {P1, P2}
p3 = ABC_conserve_momentum(data1.p, data2.p)
dataOut = ABCParticleValue{typeof(p3)}(p3, data1.v * ABC_vertex() * data2.v)
return dataOut
@ -49,11 +60,7 @@ For valid inputs, both input particles should have the same momenta at this poin
12 FLOP.
"""
function compute(
::ComputeTaskABC_S2,
data1::ParticleValue{P},
data2::ParticleValue{P},
)::Float64 where {P <: ABCParticle}
function compute(::ComputeTaskABC_S2, data1::ParticleValue{P}, data2::ParticleValue{P})::Float64 where {P}
#=
@assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
@ -85,12 +92,6 @@ Linearly many FLOP with growing data.
"""
function compute(::ComputeTaskABC_Sum, data...)::Float64
return sum_kbn([data...])
#=s = 0.0im
for d in data
s += d
end
return s=#
end
function compute(::ComputeTaskABC_Sum, data::AbstractArray)::Float64

View File

@ -14,34 +14,28 @@ struct ABCModel <: AbstractPhysicsModel end
Base type for all particles in the [`ABCModel`](@ref).
"""
abstract type ABCParticle <: AbstractParticle end
abstract type ABCParticle <: AbstractParticleType end
"""
ParticleA <: ABCParticle
An 'A' particle in the ABC Model.
"""
struct ParticleA <: ABCParticle
momentum::SFourMomentum
end
struct ParticleA <: ABCParticle end
"""
ParticleB <: ABCParticle
A 'B' particle in the ABC Model.
"""
struct ParticleB <: ABCParticle
momentum::SFourMomentum
end
struct ParticleB <: ABCParticle end
"""
ParticleC <: ABCParticle
A 'C' particle in the ABC Model.
"""
struct ParticleC <: ABCParticle
momentum::SFourMomentum
end
struct ParticleC <: ABCParticle end
"""
ABCProcessDescription <: AbstractProcessDescription
@ -72,7 +66,7 @@ struct ABCProcessInput{N1, N2, N3, N4, N5, N6} <: AbstractProcessInput
outC::SVector{N6, ParticleC}
end
ABCParticleValue{ParticleType <: ABCParticle} = ParticleValue{ParticleType, ComplexF64}
ABCParticleValue{ParticleType} = ParticleValue{ParticleType, ComplexF64}
"""
mass(t::Type{T}) where {T <: ABCParticle}
@ -109,39 +103,46 @@ end
Return a Vector of the possible types of particle in the [`ABCModel`](@ref).
"""
function types(::ABCModel)
return [ParticleA, ParticleB, ParticleC]
return [
ParticleStateful{Incoming, ParticleA, SFourMomentum},
ParticleStateful{Incoming, ParticleB, SFourMomentum},
ParticleStateful{Incoming, ParticleC, SFourMomentum},
ParticleStateful{Outgoing, ParticleA, SFourMomentum},
ParticleStateful{Outgoing, ParticleB, SFourMomentum},
ParticleStateful{Outgoing, ParticleC, SFourMomentum},
]
end
"""
square(p::ABCParticle)
square(p::AbstractParticleStateful{Dir, ABCParticle})
Return the square of the particle's momentum as a `Float` value.
Takes 7 effective FLOP.
"""
function square(p::ABCParticle)
return getMass2(p.momentum)
function square(p::AbstractParticleStateful{D, ABCParticle}) where {D}
return getMass2(momentum(p))
end
"""
ABC_inner_edge(p::ABCParticle)
ABC_inner_edge(p::AbstractParticleStateful{Dir, ABCParticle})
Return the factor of the inner edge with the given (virtual) particle.
Takes 10 effective FLOP. (3 here + 7 in square(p))
"""
function ABC_inner_edge(p::ABCParticle)
return 1.0 / (square(p) - mass(p)^2)
function ABC_inner_edge(p::AbstractParticleStateful{D, ABCParticle}) where {D}
return 1.0 / (square(p) - mass(particle(p))^2)
end
"""
ABC_outer_edge(p::ABCParticle)
ABC_outer_edge(p::AbstractParticleStateful{Dir, ABCParticle})
Return the factor of the outer edge with the given (real) particle.
Takes 0 effective FLOP.
"""
function ABC_outer_edge(p::ABCParticle)
function ABC_outer_edge(::AbstractParticleStateful{D, ABCParticle}) where {D}
return 1.0
end
@ -179,17 +180,26 @@ model(::ABCProcessDescription) = ABCModel()
model(::ABCProcessInput) = ABCModel()
function type_index_from_name(::ABCModel, name::String)
if startswith(name, "A")
return (ParticleA, parse(Int, name[2:end]))
elseif startswith(name, "B")
return (ParticleB, parse(Int, name[2:end]))
elseif startswith(name, "C")
return (ParticleC, parse(Int, name[2:end]))
if startswith(name, "Ai")
return (ParticleStateful{Incoming, ParticleA, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Ao")
return (ParticleStateful{Outgoing, ParticleA, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Bi")
return (ParticleStateful{Incoming, ParticleB, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Bo")
return (ParticleStateful{Outgoing, ParticleB, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Ci")
return (ParticleStateful{Incoming, ParticleC, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Co")
return (ParticleStateful{Outgoing, ParticleC, SFourMomentum}, parse(Int, name[3:end]))
else
throw("Invalid name for a particle in the ABC model")
end
end
function String(::Type{PS}) where {DIR, P <: ABCParticle, PS <: AbstractParticleStateful{DIR, P}}
return String(P)
end
function String(::Type{ParticleA})
return "A"
end
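A quick illustration of the direction-aware entry-node names introduced in this hunk ("Ai"/"Ao" for incoming/outgoing `A` particles, and so on), assuming the `type_index_from_name` methods above:

```julia
# Illustrative only, assuming the definitions from this diff.
type_index_from_name(ABCModel(), "Ai1")  # -> (ParticleStateful{Incoming, ParticleA, SFourMomentum}, 1)
type_index_from_name(ABCModel(), "Bo2")  # -> (ParticleStateful{Outgoing, ParticleB, SFourMomentum}, 2)
```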

View File

@ -0,0 +1,142 @@
import QEDbase.AbstractParticle
"""
AbstractPhysicsModel
Base type for a model, e.g. ABC-Model or QED. This is used to dispatch many functions.
"""
abstract type AbstractPhysicsModel <: AbstractModel end
"""
ParticleValue{ParticleType <: AbstractParticleStateful}
A struct describing a particle during a calculation of a Feynman Diagram, together with the value that's being calculated. `AbstractParticleStateful` is the type from the QEDbase package.
`sizeof(ParticleValue())` = 48 Byte
"""
struct ParticleValue{ParticleType <: AbstractParticleStateful, ValueType}
p::ParticleType
v::ValueType
end
"""
ParticleValueSP{ParticleType <: AbstractParticleStateful, SP <: AbstractSpinOrPolarization, ValueType}
A [`ParticleValue`](@ref) together with its spin or polarization information, used only on the external legs (`U` tasks).
"""
struct ParticleValueSP{ParticleType <: AbstractParticleStateful, SP <: AbstractSpinOrPolarization, ValueType}
p::ParticleType
v::ValueType
sp::SP
end
"""
AbstractProcessDescription <: AbstractProblemInstance
Base type for particle scattering process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a scattering process in that model.
See also: [`parse_process`](@ref), [`AbstractProblemInstance`](@ref)
"""
abstract type AbstractProcessDescription end
#TODO: i don't think giving this a base type is a good idea, the input type should just be returned of some function, allowing anything as an input type
"""
AbstractProcessInput
Base type for process inputs. An object of this type contains the input values (e.g. momenta) of the particles in a process.
See also: [`gen_process_input`](@ref)
"""
abstract type AbstractProcessInput end
"""
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: AbstractParticle, T2 <: AbstractParticle}
Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the result particle type when the two given particles interact.
"""
function interaction_result end
"""
types(::AbstractPhysicsModel)
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref), returning a `Vector` of the available particle types in the model.
"""
function types end
"""
in_particles(::AbstractProcessDescription)
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of incoming particles for the process per particle type.
in_particles(::AbstractProcessInput)
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
Returns a `<: Vector{AbstractParticle}` object with the values of all incoming particles for the corresponding `ProcessDescription`.
"""
function in_particles end
"""
out_particles(::AbstractProcessDescription)
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of outgoing particles for the process per particle type.
out_particles(::AbstractProcessInput)
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
Returns a `<: Vector{AbstractParticle}` object with the values of all outgoing particles for the corresponding `ProcessDescription`.
"""
function out_particles end
"""
get_particle(::AbstractProcessInput, t::Type, n::Int)
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
Returns the `n`th particle of type `t`.
"""
function get_particle end
"""
parse_process(::AbstractString, ::AbstractPhysicsModel)
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref).
Returns a `ProcessDescription` object.
"""
function parse_process end
"""
gen_process_input(::AbstractProcessDescription)
Interface function that must be implemented for every specific [`AbstractProcessDescription`](@ref).
Returns a randomly generated and valid corresponding `ProcessInput`.
"""
function gen_process_input end
"""
model(::AbstractProcessDescription)
model(::AbstractProcessInput)
Return the model of this process description or input.
"""
function model end
"""
type_index_from_name(model::AbstractModel, name::String)
For a name of a particle in the given [`AbstractModel`](@ref), return the particle's [`Type`] and index as a tuple. The input string can be expected to be of the form \"<name><index>\".
"""
function type_index_from_name end
"""
part_from_x(type::Type, index::Int, x::AbstractProcessInput)
Return the [`ParticleValue`](@ref) of the given type of particle with the given `index` from the given process input.
Function is wrapped into a [`FunctionCall`](@ref) in [`gen_input_assignment_code`](@ref).
"""
part_from_x(type::Type, index::Int, x::AbstractProcessInput) =
ParticleValue{type, ComplexF64}(get_particle(x, type, index), one(ComplexF64))

View File

@ -1,23 +1,40 @@
using StaticArrays
"""
compute(::ComputeTaskQED_P, data::QEDParticleValue)
construction_string(::Incoming) = "Incoming()"
construction_string(::Outgoing) = "Outgoing()"
Return the particle as is and initialize the Value.
"""
function compute(::ComputeTaskQED_P, data::QEDParticleValue{P}) where {P <: QEDParticle}
# TODO do we actually need this for anything?
return ParticleValue{P, DiracMatrix}(data.p, one(DiracMatrix))
construction_string(::Electron) = "Electron()"
construction_string(::Positron) = "Positron()"
construction_string(::Photon) = "Photon()"
construction_string(::PolX) = "PolX()"
construction_string(::PolY) = "PolY()"
construction_string(::SpinUp) = "SpinUp()"
construction_string(::SpinDown) = "SpinDown()"
function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbol)
(type, index) = type_index_from_name(QEDModel(), name)
return Meta.parse(
"ParticleValueSP(
$type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))),
0.0im,
$(construction_string(spin_or_pol(instance, type, index))),
)",
)
end
"""
compute(::ComputeTaskQED_U, data::QEDParticleValue)
compute(::ComputeTaskQED_U, data::ParticleValueSP)
Compute an outer edge. Return the particle value with the same particle and the value multiplied by an outer_edge factor.
"""
function compute(::ComputeTaskQED_U, data::PV) where {P <: QEDParticle, PV <: QEDParticleValue{P}}
function compute(
::ComputeTaskQED_U,
data::ParticleValueSP{P, SP, V},
) where {P <: ParticleStateful, V <: ValueType, SP <: AbstractSpinOrPolarization}
part::P = data.p
state = base_state(particle(part), direction(part), momentum(part), spin_or_pol(part))
state = base_state(particle_species(part), particle_direction(part), momentum(part), SP())
return ParticleValue{P, typeof(state)}(
data.p,
state, # will return a SLorentzVector{ComplexF64}, BiSpinor or AdjointBiSpinor
@ -25,15 +42,15 @@ function compute(::ComputeTaskQED_U, data::PV) where {P <: QEDParticle, PV <: QE
end
"""
compute(::ComputeTaskQED_V, data1::QEDParticleValue, data2::QEDParticleValue)
compute(::ComputeTaskQED_V, data1::ParticleValue, data2::ParticleValue)
Compute a vertex. Preserve momentum and particle types (e + gamma->p etc.) to create resulting particle, multiply values together and times a vertex factor.
"""
function compute(
::ComputeTaskQED_V,
data1::PV1,
data2::PV2,
) where {P1 <: QEDParticle, P2 <: QEDParticle, PV1 <: QEDParticleValue{P1}, PV2 <: QEDParticleValue{P2}}
data1::ParticleValue{P1, V1},
data2::ParticleValue{P2, V2},
) where {P1 <: ParticleStateful, P2 <: ParticleStateful, V1 <: ValueType, V2 <: ValueType}
p3 = QED_conserve_momentum(data1.p, data2.p)
P3 = interaction_result(P1, P2)
state = QED_vertex()
@ -53,7 +70,7 @@ function compute(
end
"""
compute(::ComputeTaskQED_S2, data1::QEDParticleValue, data2::QEDParticleValue)
compute(::ComputeTaskQED_S2, data1::ParticleValue, data2::ParticleValue)
Compute a final inner edge (2 input particles, no output particle).
@ -63,9 +80,19 @@ For valid inputs, both input particles should have the same momenta at this poin
"""
function compute(
::ComputeTaskQED_S2,
data1::ParticleValue{P1},
data2::ParticleValue{P2},
) where {P1 <: Union{AntiFermionStateful, FermionStateful}, P2 <: Union{AntiFermionStateful, FermionStateful}}
data1::ParticleValue{P1, V1},
data2::ParticleValue{P2, V2},
) where {
D1 <: ParticleDirection,
D2 <: ParticleDirection,
S1 <: Union{Electron, Positron},
S2 <: Union{Electron, Positron},
V1 <: ValueType,
V2 <: ValueType,
EL <: AbstractFourMomentum,
P1 <: ParticleStateful{D1, S1, EL},
P2 <: ParticleStateful{D2, S2, EL},
}
#@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
inner = QED_inner_edge(propagation_result(P1)(momentum(data1.p)))
@ -80,9 +107,9 @@ end
function compute(
::ComputeTaskQED_S2,
data1::ParticleValue{P1},
data2::ParticleValue{P2},
) where {P1 <: PhotonStateful, P2 <: PhotonStateful}
data1::ParticleValue{ParticleStateful{D1, Photon}, V1},
data2::ParticleValue{ParticleStateful{D2, Photon}, V2},
) where {D1 <: ParticleDirection, D2 <: ParticleDirection, V1 <: ValueType, V2 <: ValueType}
# TODO: assert that data1 and data2 are opposites
inner = QED_inner_edge(data1.p)
# inner edge is just a scalar, data1 and data2 are photon states that are just Complex numbers here
@ -90,11 +117,11 @@ function compute(
end
"""
compute(::ComputeTaskQED_S1, data::QEDParticleValue)
compute(::ComputeTaskQED_S1, data::ParticleValue)
Compute inner edge (1 input particle, 1 output particle).
"""
function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P}) where {P <: QEDParticle}
function compute(::ComputeTaskQED_S1, data::ParticleValue{P, V}) where {P <: ParticleStateful, V <: ValueType}
newP = propagation_result(P)
new_p = newP(momentum(data.p))
# inner edge is just a scalar, can multiply from either side

View File

@ -1,83 +1,58 @@
ComputeTaskQED_Sum() = ComputeTaskQED_Sum(0)
function _svector_from_type(processDescription::QEDProcessDescription, type, particles)
function _svector_from_type(processDescription::GenericQEDProcess, type, particles)
if haskey(in_particles(processDescription), type)
return SVector{in_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
end
if haskey(out_particles(processDescription), type)
return SVector{out_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
end
return SVector{0, type}()
end
"""
gen_process_input(processDescription::QEDProcessDescription)
gen_process_input(processDescription::GenericQEDProcess)
Return a ProcessInput of randomly generated [`QEDParticle`](@ref)s from a [`QEDProcessDescription`](@ref). The process description can be created manually or parsed from a string using [`parse_process`](@ref).
Return a `PhaseSpacePoint` with randomly generated momenta for the given [`GenericQEDProcess`](@ref). The process description can be created manually or parsed from a string using [`parse_process`](@ref).
Note: This uses RAMBO to create a valid process with conservation of momentum and energy.
"""
function gen_process_input(processDescription::QEDProcessDescription)
function gen_process_input(processDescription::GenericQEDProcess)
massSum = 0
inputMasses = Vector{Float64}()
for (particle, n) in processDescription.inParticles
for _ in 1:n
massSum += mass(particle)
push!(inputMasses, mass(particle))
end
for particle in incoming_particles(processDescription)
massSum += mass(particle)
push!(inputMasses, mass(particle))
end
outputMasses = Vector{Float64}()
for (particle, n) in processDescription.outParticles
for _ in 1:n
massSum += mass(particle)
push!(outputMasses, mass(particle))
end
for particle in outgoing_particles(processDescription)
massSum += mass(particle)
push!(outputMasses, mass(particle))
end
# add some extra random mass to allow for some momentum
massSum += rand(rng[threadid()]) * (length(inputMasses) + length(outputMasses))
particles = Vector{QEDParticle}()
initialMomenta = generate_initial_moms(massSum, inputMasses)
index = 1
for (particle, n) in processDescription.inParticles
for _ in 1:n
mom = initialMomenta[index]
push!(particles, particle(mom))
index += 1
end
end
initial_momenta = generate_initial_moms(massSum, inputMasses)
final_momenta = generate_physical_massive_moms(rng[threadid()], massSum, outputMasses)
index = 1
for (particle, n) in processDescription.outParticles
for _ in 1:n
push!(particles, particle(final_momenta[index]))
index += 1
end
end
inFerms = _svector_from_type(processDescription, FermionStateful{Incoming, SpinUp}, particles)
outFerms = _svector_from_type(processDescription, FermionStateful{Outgoing, SpinUp}, particles)
inAntiferms = _svector_from_type(processDescription, AntiFermionStateful{Incoming, SpinUp}, particles)
outAntiferms = _svector_from_type(processDescription, AntiFermionStateful{Outgoing, SpinUp}, particles)
inPhotons = _svector_from_type(processDescription, PhotonStateful{Incoming, PolX}, particles)
outPhotons = _svector_from_type(processDescription, PhotonStateful{Outgoing, PolX}, particles)
processInput =
QEDProcessInput(processDescription, inFerms, outFerms, inAntiferms, outAntiferms, inPhotons, outPhotons)
processInput = PhaseSpacePoint(
processDescription,
PerturbativeQED(),
PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()),
tuple(initial_momenta...),
tuple(final_momenta...),
)
return processInput
end
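A minimal usage sketch (assuming `parse_process` from this changeset; the process string is only an example):

proc = parse_process("ke->ke", QEDModel())
psp = gen_process_input(proc)   # PhaseSpacePoint with RAMBO-generated, momentum-conserving momenta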
"""
gen_graph(process_description::QEDProcessDescription)
gen_graph(process_description::GenericQEDProcess)
For a given [`QEDProcessDescription`](@ref), return the [`DAG`](@ref) that computes it.
For a given [`GenericQEDProcess`](@ref), return the [`DAG`](@ref) that computes it.
"""
function gen_graph(process_description::QEDProcessDescription)
function gen_graph(process_description::GenericQEDProcess)
initial_diagram = FeynmanDiagram(process_description)
diagrams = gen_diagrams(initial_diagram)
@ -88,9 +63,9 @@ function gen_graph(process_description::QEDProcessDescription)
# TODO: Not all diagram outputs should always be summed at the end, if they differ by fermion exchange they need to be diffed
# Should not matter for n-Photon Compton processes though
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(0)), track = false, invalidate_cache = false)
global_data_out = insert_node!(graph, make_node(DataTask(COMPLEX_SIZE)), track = false, invalidate_cache = false)
insert_edge!(graph, sum_node, global_data_out, track = false, invalidate_cache = false)
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(0)); track = false, invalidate_cache = false)
global_data_out = insert_node!(graph, make_node(DataTask(COMPLEX_SIZE)); track = false, invalidate_cache = false)
insert_edge!(graph, sum_node, global_data_out; track = false, invalidate_cache = false)
# remember the data out nodes for connection
dataOutNodes = Dict()
@ -99,16 +74,16 @@ function gen_graph(process_description::QEDProcessDescription)
# generate data in and U tasks
data_in = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE), String(particle)),
make_node(DataTask(PARTICLE_VALUE_SIZE), String(particle));
track = false,
invalidate_cache = false,
) # read particle data node
compute_u = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false, invalidate_cache = false) # compute U node
compute_u = insert_node!(graph, make_node(ComputeTaskQED_U()); track = false, invalidate_cache = false) # compute U node
data_out =
insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false) # transfer data out from u (one ABCParticleValue object)
insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)); track = false, invalidate_cache = false) # transfer data out from u (one ABCParticleValue object)
insert_edge!(graph, data_in, compute_u, track = false, invalidate_cache = false)
insert_edge!(graph, compute_u, data_out, track = false, invalidate_cache = false)
insert_edge!(graph, data_in, compute_u; track = false, invalidate_cache = false)
insert_edge!(graph, compute_u, data_out; track = false, invalidate_cache = false)
# remember the data_out node for future edges
dataOutNodes[String(particle)] = data_out
@ -124,19 +99,19 @@ function gen_graph(process_description::QEDProcessDescription)
data_in1 = dataOutNodes[String(vertex.in1)]
data_in2 = dataOutNodes[String(vertex.in2)]
compute_V = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false, invalidate_cache = false) # compute vertex
compute_V = insert_node!(graph, make_node(ComputeTaskQED_V()); track = false, invalidate_cache = false) # compute vertex
insert_edge!(graph, data_in1, compute_V, track = false, invalidate_cache = false)
insert_edge!(graph, data_in2, compute_V, track = false, invalidate_cache = false)
insert_edge!(graph, data_in1, compute_V; track = false, invalidate_cache = false)
insert_edge!(graph, data_in2, compute_V; track = false, invalidate_cache = false)
data_V_out = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
make_node(DataTask(PARTICLE_VALUE_SIZE));
track = false,
invalidate_cache = false,
)
insert_edge!(graph, compute_V, data_V_out, track = false, invalidate_cache = false)
insert_edge!(graph, compute_V, data_V_out; track = false, invalidate_cache = false)
if (vertex.out == tie.in1 || vertex.out == tie.in2)
# out particle is part of the tie -> there will be an S2 task with it later, don't make S1 task
@ -146,18 +121,18 @@ function gen_graph(process_description::QEDProcessDescription)
# otherwise, add S1 task
compute_S1 =
insert_node!(graph, make_node(ComputeTaskQED_S1()), track = false, invalidate_cache = false) # compute propagator
insert_node!(graph, make_node(ComputeTaskQED_S1()); track = false, invalidate_cache = false) # compute propagator
insert_edge!(graph, data_V_out, compute_S1, track = false, invalidate_cache = false)
insert_edge!(graph, data_V_out, compute_S1; track = false, invalidate_cache = false)
data_S1_out = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
make_node(DataTask(PARTICLE_VALUE_SIZE));
track = false,
invalidate_cache = false,
)
insert_edge!(graph, compute_S1, data_S1_out, track = false, invalidate_cache = false)
insert_edge!(graph, compute_S1, data_S1_out; track = false, invalidate_cache = false)
# overrides potentially different nodes from previous diagrams, which is intentional
dataOutNodes[String(vertex.out)] = data_S1_out
@ -168,16 +143,16 @@ function gen_graph(process_description::QEDProcessDescription)
data_in1 = dataOutNodes[String(tie.in1)]
data_in2 = dataOutNodes[String(tie.in2)]
compute_S2 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false, invalidate_cache = false)
compute_S2 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track = false, invalidate_cache = false)
data_S2 = insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false)
data_S2 = insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)); track = false, invalidate_cache = false)
insert_edge!(graph, data_in1, compute_S2, track = false, invalidate_cache = false)
insert_edge!(graph, data_in2, compute_S2, track = false, invalidate_cache = false)
insert_edge!(graph, data_in1, compute_S2; track = false, invalidate_cache = false)
insert_edge!(graph, data_in2, compute_S2; track = false, invalidate_cache = false)
insert_edge!(graph, compute_S2, data_S2, track = false, invalidate_cache = false)
insert_edge!(graph, compute_S2, data_S2; track = false, invalidate_cache = false)
insert_edge!(graph, data_S2, sum_node, track = false, invalidate_cache = false)
insert_edge!(graph, data_S2, sum_node; track = false, invalidate_cache = false)
add_child!(task(sum_node))
end


@ -8,10 +8,10 @@ import Base.show
"""
FeynmanParticle
Representation of a particle for use in [`FeynmanDiagram`](@ref)s. Consist of the [`QEDParticle`](@ref) type and an id.
Representation of a particle for use in [`FeynmanDiagram`](@ref)s. Consists of the `ParticleStateful` type and an id.
"""
struct FeynmanParticle
particle::Type{<:QEDParticle}
particle::Type{<:ParticleStateful}
id::Int
end
@ -51,31 +51,21 @@ struct FeynmanDiagram
end
"""
FeynmanDiagram(pd::QEDProcessDescription)
FeynmanDiagram(pd::GenericQEDProcess)
Create an initial [`FeynmanDiagram`](@ref) with only its initial particles set and no vertices or ties.
Use [`gen_diagrams`](@ref) to generate all possible diagrams from this one.
"""
function FeynmanDiagram(pd::QEDProcessDescription)
function FeynmanDiagram(pd::GenericQEDProcess)
parts = Vector{FeynmanParticle}()
for (type, n) in pd.inParticles
for i in 1:n
push!(parts, FeynmanParticle(type, i))
end
end
for (type, n) in pd.outParticles
for i in 1:n
push!(parts, FeynmanParticle(type, i))
end
end
ids = Dict{Type, Int64}()
for t in types(QEDModel())
if (isincoming(t))
ids[t] = get(pd.inParticles, t, 0)
else
ids[t] = get(pd.outParticles, t, 0)
for type in types(model(pd))
for i in 1:number_particles(pd, type)
push!(parts, FeynmanParticle(type, i))
end
ids[type] = number_particles(pd, type)
end
return FeynmanDiagram([], missing, parts, ids)
@ -83,7 +73,7 @@ end
function particle_after_tie(p::FeynmanParticle, t::FeynmanTie)
if p == t.in1 || p == t.in2
return FeynmanParticle(FermionStateful{Incoming, SpinUp}, -1) # placeholder particle and id for tied particles
return FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, -1) # placeholder particle and id for tied particles
end
return p
end
@ -114,7 +104,7 @@ end
Return a string representation of the [`FeynmanParticle`](@ref) in a format that is readable by [`type_index_from_name`](@ref).
"""
function String(p::FeynmanParticle)
return "$(String(p.particle))$(String(direction(p.particle)))$(p.id)"
return "$(String(p.particle))$(String(particle_direction(p.particle)))$(p.id)"
end
function hash(v::FeynmanVertex)
@ -162,15 +152,16 @@ function ==(d1::FeynmanDiagram, d2::FeynmanDiagram)
)=#
end
copy(fd::FeynmanDiagram) =
FeynmanDiagram(deepcopy(fd.vertices), copy(fd.tie[]), deepcopy(fd.particles), copy(fd.type_ids))
function copy(fd::FeynmanDiagram)
return FeynmanDiagram(deepcopy(fd.vertices), copy(fd.tie[]), deepcopy(fd.particles), copy(fd.type_ids))
end
"""
id_for_type(d::FeynmanDiagram, t::Type{<:QEDParticle})
id_for_type(d::FeynmanDiagram, t::Type{<:ParticleStateful})
Return the highest id of any particle of the given type in the diagram + 1.
"""
function id_for_type(d::FeynmanDiagram, t::Type{<:QEDParticle})
function id_for_type(d::FeynmanDiagram, t::Type{<:ParticleStateful})
return d.type_ids[t] + 1
end
@ -439,18 +430,19 @@ function remove_duplicates(compare_set::Set{FeynmanDiagram})
return result
end
"""
is_compton(fd::FeynmanDiagram)
Return true iff the given Feynman diagram is an (empty) diagram of a Compton process like ke->k^ne.
"""
function is_compton(fd::FeynmanDiagram)
return fd.type_ids[FermionStateful{Incoming, SpinUp}] == 1 &&
fd.type_ids[FermionStateful{Outgoing, SpinUp}] == 1 &&
fd.type_ids[AntiFermionStateful{Incoming, SpinUp}] == 0 &&
fd.type_ids[AntiFermionStateful{Outgoing, SpinUp}] == 0 &&
fd.type_ids[PhotonStateful{Incoming, PolX}] >= 1 &&
fd.type_ids[PhotonStateful{Outgoing, PolX}] >= 1
return fd.type_ids[ParticleStateful{Incoming, Electron, SFourMomentum}] == 1 &&
fd.type_ids[ParticleStateful{Outgoing, Electron, SFourMomentum}] == 1 &&
fd.type_ids[ParticleStateful{Incoming, Positron, SFourMomentum}] == 0 &&
fd.type_ids[ParticleStateful{Outgoing, Positron, SFourMomentum}] == 0 &&
fd.type_ids[ParticleStateful{Incoming, Photon, SFourMomentum}] >= 1 &&
fd.type_ids[ParticleStateful{Outgoing, Photon, SFourMomentum}] >= 1
end
"""
@ -460,8 +452,8 @@ Helper function for [`gen_compton_diagrams`](@Ref). Generates a single diagram f
"""
function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
photons = vcat(
[FeynmanParticle(PhotonStateful{Incoming, PolX}, i) for i in 1:n],
[FeynmanParticle(PhotonStateful{Outgoing, PolX}, i) for i in 1:m],
[FeynmanParticle(ParticleStateful{Incoming, Photon, SFourMomentum}, i) for i in 1:n],
[FeynmanParticle(ParticleStateful{Outgoing, Photon, SFourMomentum}, i) for i in 1:m],
)
new_diagram = FeynmanDiagram(
@ -469,10 +461,10 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
missing,
[inFerm, outFerm, photons...],
Dict{Type, Int64}(
FermionStateful{Incoming, SpinUp} => 1,
FermionStateful{Outgoing, SpinUp} => 1,
PhotonStateful{Incoming, PolX} => n,
PhotonStateful{Outgoing, PolX} => m,
ParticleStateful{Incoming, Electron, SFourMomentum} => 1,
ParticleStateful{Outgoing, Electron, SFourMomentum} => 1,
ParticleStateful{Incoming, Photon, SFourMomentum} => n,
ParticleStateful{Outgoing, Photon, SFourMomentum} => m,
),
)
@ -484,9 +476,9 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
while left_index <= right_index
# left side
v_left = FeynmanVertex(
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations),
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations),
photons[order[left_index]],
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations + 1),
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations + 1),
)
left_index += 1
add_vertex!(new_diagram, v_left)
@ -497,9 +489,9 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
# right side
v_right = FeynmanVertex(
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations),
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations),
photons[order[right_index]],
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations + 1),
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations + 1),
)
right_index -= 1
add_vertex!(new_diagram, v_right)
@ -512,7 +504,6 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
return new_diagram
end
"""
gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
@ -520,8 +511,8 @@ Helper function for [`gen_compton_diagrams`](@Ref). Generates a single diagram f
"""
function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
photons = vcat(
[FeynmanParticle(PhotonStateful{Incoming, PolX}, i) for i in 1:n],
[FeynmanParticle(PhotonStateful{Outgoing, PolX}, i) for i in 1:m],
[FeynmanParticle(ParticleStateful{Incoming, Photon, SFourMomentum}, i) for i in 1:n],
[FeynmanParticle(ParticleStateful{Outgoing, Photon, SFourMomentum}, i) for i in 1:m],
)
new_diagram = FeynmanDiagram(
@ -529,10 +520,10 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
missing,
[inFerm, outFerm, photons...],
Dict{Type, Int64}(
FermionStateful{Incoming, SpinUp} => 1,
FermionStateful{Outgoing, SpinUp} => 1,
PhotonStateful{Incoming, PolX} => n,
PhotonStateful{Outgoing, PolX} => m,
ParticleStateful{Incoming, Electron, SFourMomentum} => 1,
ParticleStateful{Outgoing, Electron, SFourMomentum} => 1,
ParticleStateful{Incoming, Photon, SFourMomentum} => n,
ParticleStateful{Outgoing, Photon, SFourMomentum} => m,
),
)
@ -544,9 +535,9 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
while left_index <= right_index
# left side
v_left = FeynmanVertex(
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations),
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations),
photons[order[left_index]],
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations + 1),
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations + 1),
)
left_index += 1
add_vertex!(new_diagram, v_left)
@ -559,9 +550,9 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
if (iterations == 1)
# right side
v_right = FeynmanVertex(
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations),
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations),
photons[order[right_index]],
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations + 1),
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations + 1),
)
right_index -= 1
add_vertex!(new_diagram, v_right)
@ -576,15 +567,14 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
return new_diagram
end
"""
gen_compton_diagrams(n::Int, m::Int)
Special case diagram generation for Compton processes, i.e., processes of the form k^ne->k^me
"""
function gen_compton_diagrams(n::Int, m::Int)
inFerm = FeynmanParticle(FermionStateful{Incoming, SpinUp}, 1)
outFerm = FeynmanParticle(FermionStateful{Outgoing, SpinUp}, 1)
inFerm = FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, 1)
outFerm = FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, 1)
perms = [permutations([i for i in 1:(n + m)])...]
@ -596,15 +586,14 @@ function gen_compton_diagrams(n::Int, m::Int)
return vcat(diagrams...)
end
"""
gen_compton_diagrams_one_side(n::Int, m::Int)
Special case diagram generation for Compton processes, i.e., processes of the form k^ne->k^me, but generating from one end, yielding larger diagrams
"""
function gen_compton_diagrams_one_side(n::Int, m::Int)
inFerm = FeynmanParticle(FermionStateful{Incoming, SpinUp}, 1)
outFerm = FeynmanParticle(FermionStateful{Outgoing, SpinUp}, 1)
inFerm = FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, 1)
outFerm = FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, 1)
perms = [permutations([i for i in 1:(n + m)])...]
@ -623,12 +612,15 @@ From a given feynman diagram in its initial state, e.g. when created through the
"""
function gen_diagrams(fd::FeynmanDiagram)
if is_compton(fd)
return gen_compton_diagrams_one_side(
fd.type_ids[PhotonStateful{Incoming, PolX}],
fd.type_ids[PhotonStateful{Outgoing, PolX}],
return gen_compton_diagrams(
fd.type_ids[ParticleStateful{Incoming, Photon, SFourMomentum}],
fd.type_ids[ParticleStateful{Outgoing, Photon, SFourMomentum}],
)
end
throw(error("Unimplemented for non-compton!"))
#=
working = Set{FeynmanDiagram}()
results = Set{FeynmanDiagram}()
@ -667,4 +659,5 @@ function gen_diagrams(fd::FeynmanDiagram)
end
return remove_duplicates(results)
=#
end
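A usage sketch for the Compton special case (the process string is an illustrative assumption):

initial = FeynmanDiagram(parse_process("ke->kke", QEDModel()))
diagrams = gen_diagrams(initial)   # one diagram per ordering of photon attachments to the fermion line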


@ -0,0 +1,45 @@
"""
parse_process(str::AbstractString, model::QEDModel, inphpol = PolX(), inelspin = SpinUp(), outphpol = PolX(), outelspin = SpinUp())
Parse a string representation of a process, such as "ke->ke", into the corresponding [`GenericQEDProcess`](@ref). The optional arguments set the polarization of incoming/outgoing photons and the spin of incoming/outgoing electrons and positrons.
"""
function parse_process(
str::AbstractString,
model::QEDModel,
inphpol::AbstractDefinitePolarization = PolX(),
inelspin::AbstractDefiniteSpin = SpinUp(),
outphpol::AbstractDefinitePolarization = PolX(),
outelspin::AbstractDefiniteSpin = SpinUp(),
)
if !(contains(str, "->"))
throw("Did not find -> while parsing process \"$str\"")
end
(in_str, out_str) = split(str, "->")
if (isempty(in_str) || isempty(out_str))
throw("Process (\"$str\") input or output part is empty!")
end
in_particles = Vector{AbstractParticleType}()
out_particles = Vector{AbstractParticleType}()
for (particle_vector, s) in ((in_particles, in_str), (out_particles, out_str))
for c in s
if c == 'e'
push!(particle_vector, Electron())
elseif c == 'p'
push!(particle_vector, Positron())
elseif c == 'k'
push!(particle_vector, Photon())
else
throw("Encountered unknown characters in the process \"$str\"")
end
end
end
in_spin_pols = tuple([is_boson(in_particles[i]) ? inphpol : inelspin for i in eachindex(in_particles)]...)
out_spin_pols = tuple([is_boson(out_particles[i]) ? outphpol : outelspin for i in eachindex(out_particles)]...)
return GenericQEDProcess(tuple(in_particles...), tuple(out_particles...), in_spin_pols, out_spin_pols)
end
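Illustrative invocations (a sketch, not taken from the tests): the trailing arguments override the default x-polarization and spin-up choices.

compton = parse_process("ke->ke", QEDModel())                                      # PolX photons, SpinUp electrons
custom = parse_process("ke->ke", QEDModel(), PolY(), SpinDown(), PolX(), SpinUp())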


@ -0,0 +1,305 @@
using QEDprocesses
using StaticArrays
const e = sqrt(4π / 137)
QEDbase.is_incoming(::Type{<:ParticleStateful{Incoming}}) = true
QEDbase.is_outgoing(::Type{<:ParticleStateful{Outgoing}}) = true
QEDbase.is_incoming(::Type{<:ParticleStateful{Outgoing}}) = false
QEDbase.is_outgoing(::Type{<:ParticleStateful{Incoming}}) = false
QEDbase.particle_direction(::Type{<:ParticleStateful{DIR}}) where {DIR <: ParticleDirection} = DIR()
QEDbase.particle_species(
::Type{<:ParticleStateful{DIR, SPECIES}},
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType} = SPECIES()
function spin_or_pol(
process::GenericQEDProcess,
type::Type{ParticleStateful{DIR, SPECIES, EL}},
n::Int,
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType, EL <: AbstractFourMomentum}
i = 0
c = n
for p in particles(process, DIR())
i += 1
if p == SPECIES()
c -= 1
end
if c == 0
break
end
end
if c != 0 || n <= 0
throw(InvalidInputError("could not get the $(n)-th spin/pol of $(DIR()) $(SPECIES()); it does not exist"))
end
if DIR <: Incoming
return process.incoming_spins_pols[i]
elseif DIR <: Outgoing
return process.outgoing_spins_pols[i]
else
throw(InvalidInputError("unknown direction $(DIR()) given"))
end
end
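A small sketch of this lookup for a parsed process (assuming the default spin/polarization choices of `parse_process`):

proc = parse_process("ke->ke", QEDModel())
spin_or_pol(proc, ParticleStateful{Incoming, Photon, SFourMomentum}, 1)   # PolX()
spin_or_pol(proc, ParticleStateful{Outgoing, Electron, SFourMomentum}, 1) # SpinUp()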
function input_type(p::GenericQEDProcess)
in_t = QEDcore._assemble_tuple_type(incoming_particles(p), Incoming(), SFourMomentum)
out_t = QEDcore._assemble_tuple_type(outgoing_particles(p), Outgoing(), SFourMomentum)
return PhaseSpacePoint{
typeof(p),
PerturbativeQED,
PhasespaceDefinition{SphericalCoordinateSystem, ElectronRestFrame},
Tuple{in_t...},
Tuple{out_t...},
}
end
ValueType = Union{BiSpinor, AdjointBiSpinor, DiracMatrix, SLorentzVector{Float64}, ComplexF64}
"""
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ParticleStateful, T2 <: ParticleStateful}
For two given particle types that can interact, return the third.
"""
function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ParticleStateful, T2 <: ParticleStateful}
@assert false "Invalid interaction between particles of types $t1 and $t2"
end
interaction_result(
::Type{ParticleStateful{Incoming, Electron, EL}},
::Type{ParticleStateful{Outgoing, Electron, EL}},
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
interaction_result(
::Type{ParticleStateful{Incoming, Electron, EL}},
::Type{ParticleStateful{Incoming, Positron, EL}},
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
interaction_result(
::Type{ParticleStateful{Incoming, Electron, EL}},
::Type{ParticleStateful{DIR, Photon, EL}},
) where {EL <: AbstractFourMomentum, DIR <: ParticleDirection} = ParticleStateful{Outgoing, Electron, EL}
interaction_result(
::Type{ParticleStateful{Outgoing, Electron, EL}},
::Type{ParticleStateful{Incoming, Electron, EL}},
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
interaction_result(
::Type{ParticleStateful{Outgoing, Electron, EL}},
::Type{ParticleStateful{Outgoing, Positron, EL}},
) where {EL <: AbstractFourMomentum} = ParticleStateful{Incoming, Photon, EL}
interaction_result(
::Type{ParticleStateful{Outgoing, Electron, EL}},
::Type{ParticleStateful{DIR, Photon, EL}},
) where {EL <: AbstractFourMomentum, DIR <: ParticleDirection} = ParticleStateful{Incoming, Electron, EL}
# antifermion mirror
interaction_result(
::Type{ParticleStateful{Incoming, Positron, EL}},
t2::Type{<:ParticleStateful},
) where {EL <: AbstractFourMomentum} = interaction_result(ParticleStateful{Outgoing, Electron, EL}, t2)
interaction_result(
::Type{ParticleStateful{Outgoing, Positron, EL}},
t2::Type{<:ParticleStateful},
) where {EL <: AbstractFourMomentum} = interaction_result(ParticleStateful{Incoming, Electron, EL}, t2)
# photon commutativity
interaction_result(
t1::Type{ParticleStateful{DIR, Photon, EL}},
t2::Type{<:ParticleStateful},
) where {EL <: AbstractFourMomentum, DIR <: ParticleDirection} = interaction_result(t2, t1)
# but prevent stack overflow
function interaction_result(
t1::Type{ParticleStateful{DIR1, Photon, EL}},
t2::Type{ParticleStateful{DIR2, Photon, EL}},
) where {DIR1 <: ParticleDirection, DIR2 <: ParticleDirection, EL <: AbstractFourMomentum}
@assert false "Invalid interaction between particles of types $t1 and $t2"
end
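For illustration, these methods resolve purely at the type level, e.g.:

interaction_result(
    ParticleStateful{Incoming, Electron, SFourMomentum},
    ParticleStateful{Incoming, Photon, SFourMomentum},
) # ParticleStateful{Outgoing, Electron, SFourMomentum}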
"""
propagation_result(t1::Type{T}) where {T <: ParticleStateful}
Return the particle type with the direction inverted, e.g. an incoming electron becomes an outgoing electron.
"""
propagation_result(::Type{ParticleStateful{Incoming, Electron, EL}}) where {EL <: AbstractFourMomentum} =
ParticleStateful{Outgoing, Electron, EL}
propagation_result(::Type{ParticleStateful{Outgoing, Electron, EL}}) where {EL <: AbstractFourMomentum} =
ParticleStateful{Incoming, Electron, EL}
propagation_result(::Type{ParticleStateful{Incoming, Positron, EL}}) where {EL <: AbstractFourMomentum} =
ParticleStateful{Outgoing, Positron, EL}
propagation_result(::Type{ParticleStateful{Outgoing, Positron, EL}}) where {EL <: AbstractFourMomentum} =
ParticleStateful{Incoming, Positron, EL}
propagation_result(::Type{ParticleStateful{Incoming, Photon, EL}}) where {EL <: AbstractFourMomentum} =
ParticleStateful{Outgoing, Photon, EL}
propagation_result(::Type{ParticleStateful{Outgoing, Photon, EL}}) where {EL <: AbstractFourMomentum} =
ParticleStateful{Incoming, Photon, EL}
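Sketch: propagation flips the direction while keeping the species, e.g.:

propagation_result(ParticleStateful{Incoming, Electron, SFourMomentum})
# -> ParticleStateful{Outgoing, Electron, SFourMomentum}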
"""
types(::QEDModel)
Return a Vector of the possible types of particle in the [`QEDModel`](@ref).
"""
function types(::QEDModel)
return [
ParticleStateful{Incoming, Photon, SFourMomentum},
ParticleStateful{Outgoing, Photon, SFourMomentum},
ParticleStateful{Incoming, Electron, SFourMomentum},
ParticleStateful{Outgoing, Electron, SFourMomentum},
ParticleStateful{Incoming, Positron, SFourMomentum},
ParticleStateful{Outgoing, Positron, SFourMomentum},
]
end
# type piracy?
String(::Type{Incoming}) = "Incoming"
String(::Type{Outgoing}) = "Outgoing"
String(::Type{PolX}) = "polx"
String(::Type{PolY}) = "poly"
String(::Type{SpinUp}) = "spinup"
String(::Type{SpinDown}) = "spindown"
String(::Incoming) = "i"
String(::Outgoing) = "o"
function String(::Type{<:ParticleStateful{DIR, Photon}}) where {DIR <: ParticleDirection}
return "k"
end
function String(::Type{<:ParticleStateful{DIR, Electron}}) where {DIR <: ParticleDirection}
return "e"
end
function String(::Type{<:ParticleStateful{DIR, Positron}}) where {DIR <: ParticleDirection}
return "p"
end
"""
caninteract(T1::Type{<:ParticleStateful}, T2::Type{<:ParticleStateful})
For two given `ParticleStateful` types, return whether they can interact at a vertex. This is equivalent to `!issame(T1, T2)`.
See also: [`issame`](@ref) and [`interaction_result`](@ref)
"""
function caninteract(
T1::Type{<:ParticleStateful{D1, S1}},
T2::Type{<:ParticleStateful{D2, S2}},
) where {D1 <: ParticleDirection, S1 <: AbstractParticleType, D2 <: ParticleDirection, S2 <: AbstractParticleType}
if (T1 == T2)
return false
end
if (S1 == Photon && S2 == Photon)
return false
end
for (P1, P2) in [(T1, T2), (T2, T1)]
if (P1 <: ParticleStateful{Incoming, Electron} && P2 <: ParticleStateful{Outgoing, Positron})
return false
end
if (P1 <: ParticleStateful{Outgoing, Electron} && P2 <: ParticleStateful{Incoming, Positron})
return false
end
end
return true
end
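Two quick checks as a sketch of the rules encoded above:

caninteract(ParticleStateful{Incoming, Electron, SFourMomentum}, ParticleStateful{Incoming, Photon, SFourMomentum}) # true
caninteract(ParticleStateful{Incoming, Photon, SFourMomentum}, ParticleStateful{Outgoing, Photon, SFourMomentum})   # false, no two-photon vertex in QED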
function type_index_from_name(::QEDModel, name::String)
if startswith(name, "ki")
return (ParticleStateful{Incoming, Photon, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "ko")
return (ParticleStateful{Outgoing, Photon, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "ei")
return (ParticleStateful{Incoming, Electron, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "eo")
return (ParticleStateful{Outgoing, Electron, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "pi")
return (ParticleStateful{Incoming, Positron, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "po")
return (ParticleStateful{Outgoing, Positron, SFourMomentum}, parse(Int, name[3:end]))
else
throw("Invalid name for a particle in the QED model")
end
end
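For example:

type_index_from_name(QEDModel(), "ki1") # (ParticleStateful{Incoming, Photon, SFourMomentum}, 1)
type_index_from_name(QEDModel(), "eo2") # (ParticleStateful{Outgoing, Electron, SFourMomentum}, 2)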
"""
issame(T1::Type{<:ParticleStateful}, T2::Type{<:ParticleStateful})
For two given `ParticleStateful` types, return whether they are equivalent for the purpose of a Feynman Diagram. That means e.g. an `Incoming` `AntiFermion` is the same as an `Outgoing` `Fermion`. This is equivalent to `!caninteract(T1, T2)`.
See also: [`caninteract`](@ref) and [`interaction_result`](@ref)
"""
function issame(T1::Type{<:ParticleStateful}, T2::Type{<:ParticleStateful})
return !caninteract(T1, T2)
end
"""
QED_vertex()
Return the factor of a vertex in a QED feynman diagram.
"""
@inline function QED_vertex()::SLorentzVector{DiracMatrix}
# Peskin-Schroeder notation
return -1im * e * gamma()
end
@inline function QED_inner_edge(p::ParticleStateful)
return propagator(particle_species(p), momentum(p))
end
"""
QED_conserve_momentum(p1::ParticleStateful, p2::ParticleStateful)
Calculate and return a new particle from two given interacting ones at a vertex.
"""
function QED_conserve_momentum(
p1::P1,
p2::P2,
) where {
Dir1 <: ParticleDirection,
Dir2 <: ParticleDirection,
Species1 <: AbstractParticleType,
Species2 <: AbstractParticleType,
P1 <: ParticleStateful{Dir1, Species1},
P2 <: ParticleStateful{Dir2, Species2},
}
P3 = interaction_result(P1, P2)
p1_mom = momentum(p1)
if (Dir1 <: Outgoing)
p1_mom *= -1
end
p2_mom = momentum(p2)
if (Dir2 <: Outgoing)
p2_mom *= -1
end
p3_mom = p1_mom + p2_mom
if (particle_direction(P3) isa Incoming)
return ParticleStateful(particle_direction(P3), particle_species(P3), -p3_mom)
end
return ParticleStateful(particle_direction(P3), particle_species(P3), p3_mom)
end
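A worked sketch with illustrative on-shell momenta (the values are assumptions, not taken from the tests):

e_in = ParticleStateful(Incoming(), Electron(), SFourMomentum(1.0, 0.0, 0.0, 0.0))
k_in = ParticleStateful(Incoming(), Photon(), SFourMomentum(1.0, 0.0, 0.0, 1.0))
e_out = QED_conserve_momentum(e_in, k_in) # outgoing electron with momentum (2.0, 0.0, 0.0, 1.0)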
"""
model(::GenericQEDProcess)
model(::PhaseSpacePoint)
Return the model of this process description or phase space point.
"""
model(::GenericQEDProcess) = QEDModel()
model(::PhaseSpacePoint) = QEDModel()
function get_particle(
input::PhaseSpacePoint,
t::Type{ParticleStateful{DIR, SPECIES}},
n::Int,
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType}
i = 0
for p in particles(input, DIR())
if p isa t
i += 1
if i == n
return p
end
end
end
@assert false "Invalid type given"
end
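Hypothetical usage, assuming `psp` is a PhaseSpacePoint of some "ke->kke" process:

get_particle(psp, ParticleStateful{Outgoing, Photon}, 2) # the second outgoing photon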


@ -1,8 +1,9 @@
#=
"""
show(io::IO, process::QEDProcessDescription)
show(io::IO, process::GenericQEDProcess)
Pretty print an [`QEDProcessDescription`](@ref) (no newlines).
Pretty print a [`GenericQEDProcess`](@ref) (no newlines).
```jldoctest
julia> using MetagraphOptimization
@ -14,7 +15,7 @@ julia> print(parse_process("kk->ep", QEDModel()))
QED Process: 'kk->ep'
```
"""
function show(io::IO, process::QEDProcessDescription)
function show(io::IO, process::GenericQEDProcess)
# types() gives the types in order (QED) instead of random like keys() would
print(io, "QED Process: \'")
for type in types(QEDModel())
@ -34,7 +35,7 @@ end
"""
String(process::QEDProcessDescription)
String(process::GenericQEDProcess)
Create a short string suitable as a filename or similar, describing the given process.
@ -64,7 +65,9 @@ function String(process::QEDProcessDescription)
end
return str
end
=#
#=
"""
show(io::IO, processInput::QEDProcessInput)
@ -92,7 +95,9 @@ function show(io::IO, processInput::QEDProcessInput)
end
return nothing
end
=#
#=
"""
show(io::IO, particle::T) where {T <: QEDParticle}
@ -102,13 +107,14 @@ function show(io::IO, particle::T) where {T <: QEDParticle}
print(io, "$(String(typeof(particle))): $(particle.momentum)")
return nothing
end
=#
"""
show(io::IO, particle::FeynmanParticle)
Pretty print a [`FeynmanParticle`](@ref) (no newlines).
"""
show(io::IO, p::FeynmanParticle) = print(io, "$(String(p.particle))_$(String(direction(p.particle)))_$(p.id)")
show(io::IO, p::FeynmanParticle) = print(io, "$(String(p.particle))_$(String(particle_direction(p.particle)))_$(p.id)")
"""
show(io::IO, particle::FeynmanVertex)


@ -1,3 +1,10 @@
"""
QEDModel <: AbstractPhysicsModel
Singleton definition for identification of the QED-Model.
"""
struct QEDModel <: AbstractPhysicsModel end
"""
ComputeTaskQED_S1 <: AbstractComputeTask


@ -0,0 +1,14 @@
using QEDprocesses
# add type overload for number_particles function
@inline function QEDprocesses.number_particles(
proc_def::QEDbase.AbstractProcessDefinition,
::Type{PS},
) where {
DIR <: QEDbase.ParticleDirection,
PT <: QEDbase.AbstractParticleType,
EL <: AbstractFourMomentum,
PS <: ParticleStateful{DIR, PT, EL},
}
return QEDprocesses.number_particles(proc_def, DIR(), PT())
end
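With this overload, particle counts can be queried directly by `ParticleStateful` type; `proc` is a hypothetical process here:

number_particles(proc, ParticleStateful{Incoming, Photon, SFourMomentum}) # forwards to number_particles(proc, Incoming(), Photon())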


@ -1,10 +0,0 @@
"""
show(io::IO, particleValue::ParticleValue)
Pretty print a [`ParticleValue`](@ref), no newlines.
"""
function show(io::IO, particleValue::ParticleValue)
print(io, "($(particleValue.p), value: $(particleValue.v))")
return nothing
end


@ -1,44 +0,0 @@
"""
parse_process(string::AbstractString, model::QEDModel)
Parse a string representation of a process, such as "ke->ke" into the corresponding [`QEDProcessDescription`](@ref).
"""
function parse_process(str::AbstractString, model::QEDModel)
inParticles = Dict{Type, Int}()
outParticles = Dict{Type, Int}()
if !(contains(str, "->"))
throw("Did not find -> while parsing process \"$str\"")
end
(inStr, outStr) = split(str, "->")
if (isempty(inStr) || isempty(outStr))
throw("Process (\"$str\") input or output part is empty!")
end
for t in types(model)
if (isincoming(t))
inCount = count(x -> x == String(t)[1], inStr)
if inCount != 0
inParticles[t] = inCount
end
end
if (isoutgoing(t))
outCount = count(x -> x == String(t)[1], outStr)
if outCount != 0
outParticles[t] = outCount
end
end
end
if length(inStr) != sum(values(inParticles))
throw("Encountered unknown characters in the input part of process \"$str\"")
elseif length(outStr) != sum(values(outParticles))
throw("Encountered unknown characters in the output part of process \"$str\"")
end
return QEDProcessDescription(inParticles, outParticles)
end


@ -1,408 +0,0 @@
using QEDprocesses
using StaticArrays
import QEDbase.mass
# TODO check
const e = sqrt(4π / 137)
"""
QEDModel <: AbstractPhysicsModel
Singleton definition for identification of the QED-Model.
"""
struct QEDModel <: AbstractPhysicsModel end
"""
QEDParticle
Base type for all particles in the [`QEDModel`](@ref).
Its template parameter specifies the particle's direction.
The concrete types contain singletons of the types that they are, like `Photon` and `Electron` from QEDbase, and their state descriptions.
"""
abstract type QEDParticle{Direction <: ParticleDirection} <: AbstractParticle end
"""
QEDProcessDescription <: AbstractProcessDescription
A description of a process in the QED-Model. Contains the input and output particles.
See also: [`in_particles`](@ref), [`out_particles`](@ref), [`parse_process`](@ref)
"""
struct QEDProcessDescription <: AbstractProcessDescription
inParticles::Dict{Type{<:QEDParticle{Incoming}}, Int}
outParticles::Dict{Type{<:QEDParticle{Outgoing}}, Int}
end
QEDParticleValue{ParticleType <: QEDParticle} = Union{
ParticleValue{ParticleType, BiSpinor},
ParticleValue{ParticleType, AdjointBiSpinor},
ParticleValue{ParticleType, DiracMatrix},
ParticleValue{ParticleType, SLorentzVector{Float64}},
ParticleValue{ParticleType, ComplexF64},
}
"""
PhotonStateful <: QEDParticle
A photon of the [`QEDModel`](@ref) with its state.
"""
struct PhotonStateful{Direction <: ParticleDirection, Pol <: AbstractDefinitePolarization} <: QEDParticle{Direction}
momentum::SFourMomentum
end
PhotonStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
PhotonStateful{Direction, PolX}(mom)
PhotonStateful{Dir, Pol}(ph::PhotonStateful) where {Dir, Pol} = PhotonStateful{Dir, Pol}(ph.momentum)
"""
FermionStateful <: QEDParticle
A fermion of the [`QEDModel`](@ref) with its state.
"""
struct FermionStateful{Direction <: ParticleDirection, Spin <: AbstractDefiniteSpin} <: QEDParticle{Direction}
momentum::SFourMomentum
# TODO: mass for electron/muon/tauon representation?
end
FermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
FermionStateful{Direction, SpinUp}(mom)
FermionStateful{Dir, Spin}(f::FermionStateful) where {Dir, Spin} = FermionStateful{Dir, Spin}(f.momentum)
"""
AntiFermionStateful <: QEDParticle
An anti-fermion of the [`QEDModel`](@ref) with its state.
"""
struct AntiFermionStateful{Direction <: ParticleDirection, Spin <: AbstractDefiniteSpin} <: QEDParticle{Direction}
momentum::SFourMomentum
# TODO: mass for electron/muon/tauon representation?
end
AntiFermionStateful{Direction}(mom::SFourMomentum) where {Direction <: ParticleDirection} =
AntiFermionStateful{Direction, SpinUp}(mom)
AntiFermionStateful{Dir, Spin}(f::AntiFermionStateful) where {Dir, Spin} = AntiFermionStateful{Dir, Spin}(f.momentum)
"""
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
For two given particle types that can interact, return the third.
"""
function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: QEDParticle, T2 <: QEDParticle}
@assert false "Invalid interaction between particles of types $t1 and $t2"
end
interaction_result(
::Type{FermionStateful{Incoming, Spin1}},
::Type{FermionStateful{Outgoing, Spin2}},
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
interaction_result(
::Type{FermionStateful{Incoming, Spin1}},
::Type{AntiFermionStateful{Incoming, Spin2}},
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
interaction_result(::Type{FermionStateful{Incoming, Spin1}}, ::Type{<:PhotonStateful}) where {Spin1} =
FermionStateful{Outgoing, SpinUp}
interaction_result(
::Type{FermionStateful{Outgoing, Spin1}},
::Type{FermionStateful{Incoming, Spin2}},
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
interaction_result(
::Type{FermionStateful{Outgoing, Spin1}},
::Type{AntiFermionStateful{Outgoing, Spin2}},
) where {Spin1, Spin2} = PhotonStateful{Incoming, PolX}
interaction_result(::Type{FermionStateful{Outgoing, Spin1}}, ::Type{<:PhotonStateful}) where {Spin1} =
FermionStateful{Incoming, SpinUp}
# antifermion mirror
interaction_result(::Type{AntiFermionStateful{Incoming, Spin}}, t2::Type{<:QEDParticle}) where {Spin} =
interaction_result(FermionStateful{Outgoing, Spin}, t2)
interaction_result(::Type{AntiFermionStateful{Outgoing, Spin}}, t2::Type{<:QEDParticle}) where {Spin} =
interaction_result(FermionStateful{Incoming, Spin}, t2)
# photon commutativity
interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:QEDParticle}) = interaction_result(t2, t1)
# but prevent stack overflow
function interaction_result(t1::Type{<:PhotonStateful}, t2::Type{<:PhotonStateful})
@assert false "Invalid interaction between particles of types $t1 and $t2"
end
"""
propagation_result(t1::Type{T}) where {T <: QEDParticle}
Return the type of the inverted direction. E.g.
"""
propagation_result(::Type{FermionStateful{Incoming, Spin}}) where {Spin <: AbstractDefiniteSpin} =
FermionStateful{Outgoing, Spin}
propagation_result(::Type{FermionStateful{Outgoing, Spin}}) where {Spin <: AbstractDefiniteSpin} =
FermionStateful{Incoming, Spin}
propagation_result(::Type{AntiFermionStateful{Incoming, Spin}}) where {Spin <: AbstractDefiniteSpin} =
AntiFermionStateful{Outgoing, Spin}
propagation_result(::Type{AntiFermionStateful{Outgoing, Spin}}) where {Spin <: AbstractDefiniteSpin} =
AntiFermionStateful{Incoming, Spin}
propagation_result(::Type{PhotonStateful{Incoming, Pol}}) where {Pol <: AbstractDefinitePolarization} =
PhotonStateful{Outgoing, Pol}
propagation_result(::Type{PhotonStateful{Outgoing, Pol}}) where {Pol <: AbstractDefinitePolarization} =
PhotonStateful{Incoming, Pol}
"""
types(::QEDModel)
Return a Vector of the possible types of particle in the [`QEDModel`](@ref).
"""
function types(::QEDModel)
return [
PhotonStateful{Incoming, PolX},
PhotonStateful{Outgoing, PolX},
FermionStateful{Incoming, SpinUp},
FermionStateful{Outgoing, SpinUp},
AntiFermionStateful{Incoming, SpinUp},
AntiFermionStateful{Outgoing, SpinUp},
]
end
# type piracy?
String(::Type{Incoming}) = "Incoming"
String(::Type{Outgoing}) = "Outgoing"
String(::Type{PolX}) = "polx"
String(::Type{PolY}) = "poly"
String(::Type{SpinUp}) = "spinup"
String(::Type{SpinDown}) = "spindown"
String(::Incoming) = "i"
String(::Outgoing) = "o"
function String(::Type{<:PhotonStateful})
return "k"
end
function String(::Type{<:FermionStateful})
return "e"
end
function String(::Type{<:AntiFermionStateful})
return "p"
end
function unique_name(::Type{PhotonStateful{Dir, Pol}}) where {Dir, Pol}
return String(PhotonStateful) * String(Dir) * String(Pol)
end
function unique_name(::Type{FermionStateful{Dir, Spin}}) where {Dir, Spin}
return String(FermionStateful) * String(Dir) * String(Spin)
end
function unique_name(::Type{AntiFermionStateful{Dir, Spin}}) where {Dir, Spin}
return String(AntiFermionStateful) * String(Dir) * String(Spin)
end
@inline particle(::PhotonStateful) = Photon()
@inline particle(::FermionStateful) = Electron()
@inline particle(::AntiFermionStateful) = Positron()
@inline momentum(p::PhotonStateful)::SFourMomentum = p.momentum
@inline momentum(p::FermionStateful)::SFourMomentum = p.momentum
@inline momentum(p::AntiFermionStateful)::SFourMomentum = p.momentum
@inline spin_or_pol(p::PhotonStateful{Dir, Pol}) where {Dir, Pol <: AbstractDefinitePolarization} = Pol()
@inline spin_or_pol(p::FermionStateful{Dir, Spin}) where {Dir, Spin <: AbstractDefiniteSpin} = Spin()
@inline spin_or_pol(p::AntiFermionStateful{Dir, Spin}) where {Dir, Spin <: AbstractDefiniteSpin} = Spin()
@inline direction(
::Type{P},
) where {P <: Union{FermionStateful{Incoming}, AntiFermionStateful{Incoming}, PhotonStateful{Incoming}}} = Incoming()
@inline direction(
::Type{P},
) where {P <: Union{FermionStateful{Outgoing}, AntiFermionStateful{Outgoing}, PhotonStateful{Outgoing}}} = Outgoing()
@inline direction(
::P,
) where {P <: Union{FermionStateful{Incoming}, AntiFermionStateful{Incoming}, PhotonStateful{Incoming}}} = Incoming()
@inline direction(
::P,
) where {P <: Union{FermionStateful{Outgoing}, AntiFermionStateful{Outgoing}, PhotonStateful{Outgoing}}} = Outgoing()
@inline isincoming(::QEDParticle{Incoming}) = true
@inline isincoming(::QEDParticle{Outgoing}) = false
@inline isoutgoing(::QEDParticle{Incoming}) = false
@inline isoutgoing(::QEDParticle{Outgoing}) = true
@inline isincoming(::Type{<:QEDParticle{Incoming}}) = true
@inline isincoming(::Type{<:QEDParticle{Outgoing}}) = false
@inline isoutgoing(::Type{<:QEDParticle{Incoming}}) = false
@inline isoutgoing(::Type{<:QEDParticle{Outgoing}}) = true
@inline mass(::Type{<:FermionStateful}) = 1.0
@inline mass(::Type{<:AntiFermionStateful}) = 1.0
@inline mass(::Type{<:PhotonStateful}) = 0.0
@inline invert_momentum(p::FermionStateful{Dir, Spin}) where {Dir, Spin} =
FermionStateful{Dir, Spin}(-p.momentum, p.spin)
@inline invert_momentum(p::AntiFermionStateful{Dir, Spin}) where {Dir, Spin} =
AntiFermionStateful{Dir, Spin}(-p.momentum, p.spin)
@inline invert_momentum(k::PhotonStateful{Dir, Spin}) where {Dir, Spin} =
PhotonStateful{Dir, Spin}(-k.momentum, k.polarization)
"""
caninteract(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
For two given [`QEDParticle`](@ref) types, return whether they can interact at a vertex. This is equivalent to `!issame(T1, T2)`.
See also: [`issame`](@ref) and [`interaction_result`](@ref)
"""
function caninteract(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
if (T1 == T2)
return false
end
if (T1 <: PhotonStateful && T2 <: PhotonStateful)
return false
end
for (P1, P2) in [(T1, T2), (T2, T1)]
if (P1 <: FermionStateful{Incoming} && P2 <: AntiFermionStateful{Outgoing})
return false
end
if (P1 <: FermionStateful{Outgoing} && P2 <: AntiFermionStateful{Incoming})
return false
end
end
return true
end
function type_index_from_name(::QEDModel, name::String)
if startswith(name, "ki")
return (PhotonStateful{Incoming, PolX}, parse(Int, name[3:end]))
elseif startswith(name, "ko")
return (PhotonStateful{Outgoing, PolX}, parse(Int, name[3:end]))
elseif startswith(name, "ei")
return (FermionStateful{Incoming, SpinUp}, parse(Int, name[3:end]))
elseif startswith(name, "eo")
return (FermionStateful{Outgoing, SpinUp}, parse(Int, name[3:end]))
elseif startswith(name, "pi")
return (AntiFermionStateful{Incoming, SpinUp}, parse(Int, name[3:end]))
elseif startswith(name, "po")
return (AntiFermionStateful{Outgoing, SpinUp}, parse(Int, name[3:end]))
else
throw("Invalid name for a particle in the QED model")
end
end
"""
issame(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
For two given [`QEDParticle`](@ref) types, return whether they are equivalent for the purpose of a Feynman Diagram. That means e.g. an `Incoming` `AntiFermion` is the same as an `Outgoing` `Fermion`. This is equivalent to `!caninteract(T1, T2)`.
See also: [`caninteract`](@ref) and [`interaction_result`](@ref)
"""
function issame(T1::Type{<:QEDParticle}, T2::Type{<:QEDParticle})
return !caninteract(T1, T2)
end
"""
QED_vertex()
Return the factor of a vertex in a QED feynman diagram.
"""
@inline function QED_vertex()::SLorentzVector{DiracMatrix}
# Peskin-Schroeder notation
return -1im * e * gamma()
end
@inline function QED_inner_edge(p::QEDParticle)
return propagator(particle(p), p.momentum)
end
"""
QED_conserve_momentum(p1::QEDParticle, p2::QEDParticle)
Calculate and return a new particle from two given interacting ones at a vertex.
"""
function QED_conserve_momentum(
p1::P1,
p2::P2,
) where {
Dir1 <: ParticleDirection,
Dir2 <: ParticleDirection,
SpinPol1 <: AbstractSpinOrPolarization,
SpinPol2 <: AbstractSpinOrPolarization,
P1 <: Union{FermionStateful{Dir1, SpinPol1}, AntiFermionStateful{Dir1, SpinPol1}, PhotonStateful{Dir1, SpinPol1}},
P2 <: Union{FermionStateful{Dir2, SpinPol2}, AntiFermionStateful{Dir2, SpinPol2}, PhotonStateful{Dir2, SpinPol2}},
}
P3 = interaction_result(P1, P2)
p1_mom = p1.momentum
if (Dir1 <: Outgoing)
p1_mom *= -1
end
p2_mom = p2.momentum
if (Dir2 <: Outgoing)
p2_mom *= -1
end
p3_mom = p1_mom + p2_mom
if (typeof(direction(P3)) <: Incoming)
return P3(-p3_mom)
end
return P3(p3_mom)
end
"""
QEDProcessInput <: AbstractProcessInput
Input for a QED Process. Contains the [`QEDProcessDescription`](@ref) of the process it is an input for, and the values of the in and out particles.
See also: [`gen_process_input`](@ref)
"""
struct QEDProcessInput{N1, N2, N3, N4, N5, N6} <: AbstractProcessInput
process::QEDProcessDescription
inFerms::SVector{N1, FermionStateful{Incoming, SpinUp}}
outFerms::SVector{N2, FermionStateful{Outgoing, SpinUp}}
inAntiferms::SVector{N3, AntiFermionStateful{Incoming, SpinUp}}
outAntiferms::SVector{N4, AntiFermionStateful{Outgoing, SpinUp}}
inPhotons::SVector{N5, PhotonStateful{Incoming, PolX}}
outPhotons::SVector{N6, PhotonStateful{Outgoing, PolX}}
end
"""
model(::AbstractProcessDescription)
Return the model of this process description.
"""
model(::QEDProcessDescription) = QEDModel()
model(::QEDProcessInput) = QEDModel()
function copy(process::QEDProcessDescription)
return QEDProcessDescription(copy(process.inParticles), copy(process.outParticles))
end
==(p1::QEDProcessDescription, p2::QEDProcessDescription) =
p1.inParticles == p2.inParticles && p1.outParticles == p2.outParticles
function in_particles(process::QEDProcessDescription)
return process.inParticles
end
function out_particles(process::QEDProcessDescription)
return process.outParticles
end
function get_particle(input::QEDProcessInput, t::Type{Particle}, n::Int)::Particle where {Particle}
if (t <: FermionStateful{Incoming})
return input.inFerms[n]
elseif (t <: FermionStateful{Outgoing})
return input.outFerms[n]
elseif (t <: AntiFermionStateful{Incoming})
return input.inAntiferms[n]
elseif (t <: AntiFermionStateful{Outgoing})
return input.outAntiferms[n]
elseif (t <: PhotonStateful{Incoming})
return input.inPhotons[n]
elseif (t <: PhotonStateful{Outgoing})
return input.outPhotons[n]
end
@assert false "Invalid type given"
end


@ -1,6 +1,6 @@
DataTaskNode(t::AbstractDataTask, name = "") =
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, name)
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
t, # task
Vector{Node}(), # parents
@ -8,7 +8,6 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
UUIDs.uuid1(rng[threadid()]), # id
missing, # node reduction
missing, # node split
Vector{NodeFusion}(), # node fusions
missing, # device
)


@ -30,7 +30,6 @@ Any node that transfers data and does no computation.
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
"""
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
@ -51,9 +50,6 @@ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
# the NodeSplit involving this node, if it exists
nodeSplit::Union{Operation, Missing}
# the node fusion involving this node, if it exists
nodeFusion::Union{Operation, Missing}
# for input nodes we need a name for the node to distinguish between them
name::String
end
@ -70,7 +66,6 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
"""
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
@ -82,9 +77,6 @@ mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
nodeReduction::Union{Operation, Missing}
nodeSplit::Union{Operation, Missing}
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
nodeFusions::Vector{<:Operation}
# the device this node is assigned to execute on
device::Union{AbstractDevice, Missing}
end


@ -29,17 +29,6 @@ function is_valid_node(graph::DAG, node::Node)
@assert is_valid(graph, node.nodeSplit)
end=#
if !(typeof(task(node)) <: FusedComputeTask)
# the remaining checks are only necessary for fused compute tasks
return true
end
# every child must be in some input of the task
for child in node.children
str = Symbol(to_var_name(child.id))
@assert (str in task(node).t1_inputs) || (str in task(node).t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(task(node).t1_inputs)\nt2_inputs: $(task(node).t2_inputs)"
end
return true
end
@ -53,9 +42,6 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::ComputeTaskNode)
@assert is_valid_node(graph, node)
#=for nf in node.nodeFusions
@assert is_valid(graph, nf)
end=#
return true
end
@ -69,8 +55,5 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::DataTaskNode)
@assert is_valid_node(graph, node)
#=if !ismissing(node.nodeFusion)
@assert is_valid(graph, node.nodeFusion)
end=#
return true
end


@ -26,21 +26,6 @@ function apply_operation!(graph::DAG, operation::Operation)
return error("Unknown operation type!")
end
"""
apply_operation!(graph::DAG, operation::NodeFusion)
Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node_fusion!`](@ref).
Return an [`AppliedNodeFusion`](@ref) object generated from the graph's [`Diff`](@ref).
"""
function apply_operation!(graph::DAG, operation::NodeFusion)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
graph.properties += GraphProperties(diff)
return AppliedNodeFusion(operation, diff)
end
"""
apply_operation!(graph::DAG, operation::NodeReduction)
@ -80,20 +65,10 @@ function revert_operation!(graph::DAG, operation::AppliedOperation)
return error("Unknown operation type!")
end
"""
revert_operation!(graph::DAG, operation::AppliedNodeFusion)
Revert the applied node fusion on the graph. Return the original [`NodeFusion`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
revert_diff!(graph, operation.diff)
return operation.operation
end
"""
revert_operation!(graph::DAG, operation::AppliedNodeReduction)
Revert the applied node fusion on the graph. Return the original [`NodeReduction`](@ref) operation.
Revert the applied node reduction on the graph. Return the original [`NodeReduction`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
revert_diff!(graph, operation.diff)
@ -103,7 +78,7 @@ end
"""
revert_operation!(graph::DAG, operation::AppliedNodeSplit)
Revert the applied node fusion on the graph. Return the original [`NodeSplit`](@ref) operation.
Revert the applied node split on the graph. Return the original [`NodeSplit`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
revert_diff!(graph, operation.diff)
@ -132,88 +107,11 @@ function revert_diff!(graph::DAG, diff::Diff)
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for (node, t) in diff.updatedChildren
# node must be fused compute task at this point
@assert typeof(task(node)) <: FusedComputeTask
node.task = t
end
graph.properties -= GraphProperties(diff)
return nothing
end
"""
node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph.
For details see [`NodeFusion`](@ref).
"""
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
@assert is_valid_node_fusion_input(graph, n1, n2, n3)
# clear snapshot
get_snapshot_diff(graph)
# save children and parents
n1Children = copy(children(n1))
n3Parents = copy(parents(n3))
n1Task = copy(task(n1))
n3Task = copy(task(n3))
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
n1Inputs = Vector{Symbol}()
for child in n1Children
push!(n1Inputs, Symbol(to_var_name(child.id)))
end
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, n1, n2)
remove_edge!(graph, n2, n3)
remove_node!(graph, n1)
remove_node!(graph, n2)
# get n3's children now so it automatically excludes n2
n3Children = copy(children(n3))
n3Inputs = Vector{Symbol}()
for child in n3Children
push!(n3Inputs, Symbol(to_var_name(child.id)))
end
remove_node!(graph, n3)
# create new node with the fused compute task
newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
insert_node!(graph, newNode)
for child in n1Children
remove_edge!(graph, child, n1)
insert_edge!(graph, child, newNode)
end
for child in n3Children
remove_edge!(graph, child, n3)
if !(child in n1Children)
insert_edge!(graph, child, newNode)
end
end
for parent in n3Parents
remove_edge!(graph, n3, parent)
insert_edge!(graph, newNode, parent)
# important! update the parent node's child names in case they are fused compute tasks
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(newNode.id)))
end
return get_snapshot_diff(graph)
end
"""
node_reduction!(graph::DAG, nodes::Vector{Node})
@ -265,7 +163,6 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
# this has to be done for all parents, even the ones of n1 because they can be duplicate
prevChild = newParentsChildNames[parent]
update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
end
return get_snapshot_diff(graph)
@ -307,8 +204,6 @@ function node_split!(
for child in n1Children
insert_edge!(graph, child, nCopy)
end
update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(nCopy.id)))
end
return get_snapshot_diff(graph)

View File

@ -1,60 +1,5 @@
# These are functions for "cleaning" nodes, i.e. regenerating the possible operations for a node
"""
find_fusions!(graph::DAG, node::DataTaskNode)
Find node fusions involving the given data node. The function pushes the found [`NodeFusion`](@ref) (if any) everywhere it needs to be and returns nothing.
Does nothing if the node already has a node fusion set. Since it's a data node, only one node fusion can be possible with it.
"""
function find_fusions!(graph::DAG, node::DataTaskNode)
# if there is already a fusion here, skip to avoid duplicates
if !ismissing(node.nodeFusion)
return nothing
end
if length(parents(node)) != 1 || length(children(node)) != 1
return nothing
end
child_node = first(children(node))
parent_node = first(parents(node))
if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end
if length(parents(child_node)) != 1
return nothing
end
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.nodeFusions, nf)
node.nodeFusion = nf
push!(parent_node.nodeFusions, nf)
return nothing
end
"""
find_fusions!(graph::DAG, node::ComputeTaskNode)
Find node fusions involving the given compute node. The function pushes the found [`NodeFusion`](@ref)s (if any) everywhere they need to be and returns nothing.
"""
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# just find fusions in neighbouring DataTaskNodes
for child in children(node)
find_fusions!(graph, child)
end
for parent in parents(node)
find_fusions!(graph, parent)
end
return nothing
end
"""
find_reductions!(graph::DAG, node::Node)
@ -121,7 +66,7 @@ end
"""
clean_node!(graph::DAG, node::Node)
Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way.
Sort this node's parent and child sets, then find reductions and splits involving it. Needs to be called after the node was changed in some way.
"""
function clean_node!(
graph::DAG,
@ -129,7 +74,6 @@ function clean_node!(
) where {TaskType <: AbstractTask}
sort_node!(node)
find_fusions!(graph, node)
find_reductions!(graph, node)
find_splits!(graph, node)

View File

@ -2,26 +2,6 @@
using Base.Threads
"""
insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
Insert the given node fusion into its input nodes' operation caches. For the compute nodes, locking via the given `locks` is employed to ensure thread safety. For a large set of nodes, contention on the locks should be minimal.
"""
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
n1 = nf.input[1]
n2 = nf.input[2]
n3 = nf.input[3]
lock(locks[n1]) do
return push!(nf.input[1].nodeFusions, nf)
end
n2.nodeFusion = nf
lock(locks[n3]) do
return push!(nf.input[3].nodeFusions, nf)
end
return nothing
end
"""
insert_operation!(nf::NodeReduction)
@ -72,41 +52,6 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
return nothing
end
"""
nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
Insert the node fusions into the graph and the nodes' caches. Employs multithreading for speedup.
"""
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
total_len = 0
for vec in nodeFusions
total_len += length(vec)
end
sizehint!(operations.nodeFusions, total_len)
t = @task for vec in nodeFusions
union!(operations.nodeFusions, Set(vec))
end
schedule(t)
locks = Dict{ComputeTaskNode, SpinLock}()
for n in graph.nodes
if (typeof(n) <: ComputeTaskNode)
locks[n] = SpinLock()
end
end
@threads for vec in nodeFusions
for op in vec
insert_operation!(op, locks)
end
end
wait(t)
return nothing
end
"""
ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
@ -143,7 +88,6 @@ Generate all possible operations on the graph. Used initially when the graph is
Safely inserts all the found operations into the graph and its nodes.
"""
function generate_operations(graph::DAG)
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
@ -199,31 +143,6 @@ function generate_operations(graph::DAG)
# remove duplicates
nr_task = @spawn nr_insertion!(graph.possibleOperations, generatedReductions)
# --- find possible node fusions ---
@threads for node in nodeArray
if (typeof(node) <: DataTaskNode)
if length(parents(node)) != 1
# data node can only have a single parent
continue
end
parent_node = first(parents(node))
if length(children(node)) != 1
# this node is an entry node or has multiple children which should not be possible
continue
end
child_node = first(children(node))
if (length(parents(child_node)) != 1)
continue
end
push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
end
end
# launch thread for node fusion insertion
nf_task = @spawn nf_insertion!(graph, graph.possibleOperations, generatedFusions)
# find possible node splits
@threads for node in nodeArray
if (can_split(node))
@ -237,7 +156,6 @@ function generate_operations(graph::DAG)
empty!(graph.dirtyNodes)
wait(nr_task)
wait(nf_task)
wait(ns_task)
return nothing

View File

@ -2,8 +2,7 @@ import Base.iterate
const _POSSIBLE_OPERATIONS_FIELDS = fieldnames(PossibleOperations)
_POIteratorStateType =
NamedTuple{(:result, :state), Tuple{Union{NodeFusion, NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
_POIteratorStateType = NamedTuple{(:result, :state), Tuple{Union{NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
@inline function iterate(possibleOperations::PossibleOperations)::Union{Nothing, _POIteratorStateType}
for fieldname in _POSSIBLE_OPERATIONS_FIELDS

View File

@ -4,11 +4,6 @@
Print a string representation of the set of possible operations to io.
"""
function show(io::IO, ops::PossibleOperations)
print(io, length(ops.nodeFusions))
println(io, " Node Fusions: ")
for nf in ops.nodeFusions
println(io, " - ", nf)
end
print(io, length(ops.nodeReductions))
println(io, " Node Reductions: ")
for nr in ops.nodeReductions
@ -42,17 +37,3 @@ function show(io::IO, op::NodeSplit)
print(io, "NS: ")
return print(io, task(op.input))
end
"""
show(io::IO, op::NodeFusion)
Print a string representation of the node fusion to io.
"""
function show(io::IO, op::NodeFusion)
print(io, "NF: ")
print(io, task(op.input[1]))
print(io, "->")
print(io, task(op.input[2]))
print(io, "->")
return print(io, task(op.input[3]))
end

View File

@ -20,45 +20,6 @@ See also: [`revert_operation!`](@ref).
"""
abstract type AppliedOperation end
"""
NodeFusion <: Operation
The NodeFusion operation. Represents the fusing of a chain of compute node -> data node -> compute node.
After the node fusion is applied, the graph has 2 fewer nodes and edges, and a new [`FusedComputeTask`](@ref) with the two input compute nodes as parts.
# Requirements for successful application
A chain of (n1, n2, n3) can be fused if:
- All nodes are in the graph.
- (n1, n2) is an edge in the graph.
- (n2, n3) is an edge in the graph.
- n2 has exactly one parent (n3) and exactly one child (n1).
- n1 has exactly one parent (n2).
[`is_valid_node_fusion_input`](@ref) can be used to `@assert` these requirements.
See also: [`can_fuse`](@ref)
"""
struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
Operation
input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
end
"""
AppliedNodeFusion <: AppliedOperation
The applied version of the [`NodeFusion`](@ref).
"""
struct AppliedNodeFusion{
TaskType1 <: AbstractComputeTask,
TaskType2 <: AbstractDataTask,
TaskType3 <: AbstractComputeTask,
} <: AppliedOperation
operation::NodeFusion{TaskType1, TaskType2, TaskType3}
diff::Diff
end
"""
NodeReduction <: Operation

View File

@ -4,7 +4,7 @@
Return whether `operations` is empty, i.e. all of its fields are empty.
"""
function isempty(operations::PossibleOperations)
return isempty(operations.nodeFusions) && isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
return isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
end
"""
@ -13,21 +13,7 @@ end
Return the number of each operation type as a named tuple. The fields are named the same as those of [`PossibleOperations`](@ref).
"""
function length(operations::PossibleOperations)
return (
nodeFusions = length(operations.nodeFusions),
nodeReductions = length(operations.nodeReductions),
nodeSplits = length(operations.nodeSplits),
)
end
"""
delete!(operations::PossibleOperations, op::NodeFusion)
Delete the given node fusion from the possible operations.
"""
function delete!(operations::PossibleOperations, op::NodeFusion)
delete!(operations.nodeFusions, op)
return operations
return (nodeReductions = length(operations.nodeReductions), nodeSplits = length(operations.nodeSplits))
end
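A short sketch of how these accessors fit together after the NodeFusion removal (assuming `graph` is an already-parsed DAG):
ops = get_operations(graph)
counts = length(ops)        # e.g. (nodeReductions = 1, nodeSplits = 1)
total = sum(length(ops))    # total number of possible operations
@assert isempty(ops) == (total == 0)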
"""
@ -50,24 +36,6 @@ function delete!(operations::PossibleOperations, op::NodeSplit)
return operations
end
"""
can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
Return whether the given nodes can be fused. See [`NodeFusion`](@ref) for the requirements.
"""
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !is_child(n1, n2) || !is_child(n2, n3)
# the checks are redundant but maybe a good sanity check
return false
end
if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
return false
end
return true
end
"""
can_reduce(n1::Node, n2::Node)
@ -136,23 +104,6 @@ function ==(op1::Operation, op2::Operation)
return false
end
"""
==(op1::NodeFusion, op2::NodeFusion)
Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
"""
function ==(
op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
) where {
ComputeTaskType1 <: AbstractComputeTask,
DataTaskType <: AbstractDataTask,
ComputeTaskType2 <: AbstractComputeTask,
}
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
return op1.input[2] == op2.input[2]
end
"""
==(op1::NodeReduction, op2::NodeReduction)

View File

@ -2,43 +2,6 @@
# should be called with @assert
# the functions throw their own errors though, to still have helpful error messages
"""
is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
Assert for a given node fusion input whether the nodes can be fused. For the requirements of a node fusion see [`NodeFusion`](@ref).
Intended for use with `@assert` or `@test`.
"""
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
end
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
throw(
AssertionError(
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
),
)
end
if length(n2.parents) > 1
throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
end
if length(n2.children) > 1
throw(AssertionError("[Node Fusion] The given data node has more than one child"))
end
if length(n1.parents) > 1
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
end
@assert is_valid(graph, n1)
@assert is_valid(graph, n2)
@assert is_valid(graph, n3)
return true
end
"""
is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
@ -131,16 +94,3 @@ function is_valid(graph::DAG, ns::NodeSplit)
#@assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!"
return true
end
"""
is_valid(graph::DAG, nr::NodeFusion)
Assert for a given [`NodeFusion`](@ref) whether it is a valid operation in the graph.
Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, nf::NodeFusion)
@assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
#@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
return true
end

View File

@ -1,36 +0,0 @@
"""
FusionOptimizer
An optimizer that simply applies an available [`NodeFusion`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeFusion`](@ref)s in the graph.
See also: [`SplitOptimizer`](@ref), [`ReductionOptimizer`](@ref)
"""
struct FusionOptimizer <: AbstractOptimizer end
function optimize_step!(optimizer::FusionOptimizer, graph::DAG)
# generate all options
operations = get_operations(graph)
if fixpoint_reached(optimizer, graph)
return false
end
push_operation!(graph, first(operations.nodeFusions))
return true
end
function fixpoint_reached(optimizer::FusionOptimizer, graph::DAG)
operations = get_operations(graph)
return isempty(operations.nodeFusions)
end
function optimize_to_fixpoint!(optimizer::FusionOptimizer, graph::DAG)
while !fixpoint_reached(optimizer, graph)
optimize_step!(optimizer, graph)
end
return nothing
end
function String(::FusionOptimizer)
return "fusion_optimizer"
end

View File

@ -26,16 +26,12 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG)
if rand(r, Bool)
# push
# choose one of fuse/split/reduce
# TODO refactor fusions so they actually work
option = rand(r, 2:3)
if option == 1 && !isempty(operations.nodeFusions)
push_operation!(graph, rand(r, collect(operations.nodeFusions)))
return true
elseif option == 2 && !isempty(operations.nodeReductions)
# choose one of split/reduce
option = rand(r, 1:2)
if option == 1 && !isempty(operations.nodeReductions)
push_operation!(graph, rand(r, collect(operations.nodeReductions)))
return true
elseif option == 3 && !isempty(operations.nodeSplits)
elseif option == 2 && !isempty(operations.nodeSplits)
push_operation!(graph, rand(r, collect(operations.nodeSplits)))
return true
end
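A hedged usage sketch of the random walk after this change; `graph` is assumed to be an existing DAG:
using Random
optimizer = RandomWalkOptimizer(MersenneTwister(0))
for _ in 1:100
    optimize_step!(optimizer, graph)   # randomly pushes a reduction/split or pops a previous operation
end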

View File

@ -3,7 +3,7 @@
An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
See also: [`FusionOptimizer`](@ref), [`SplitOptimizer`](@ref)
See also: [`SplitOptimizer`](@ref)
"""
struct ReductionOptimizer <: AbstractOptimizer end
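A minimal usage sketch, assuming `graph` is a parsed DAG:
optimizer = ReductionOptimizer()
optimize_to_fixpoint!(optimizer, graph)   # applies NodeReductions until none remain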

View File

@ -3,7 +3,7 @@
An optimizer that simply applies an available [`NodeSplit`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeSplit`](@ref)s in the graph.
See also: [`FusionOptimizer`](@ref), [`ReductionOptimizer`](@ref)
See also: [`ReductionOptimizer`](@ref)
"""
struct SplitOptimizer <: AbstractOptimizer end

View File

@ -4,7 +4,7 @@
A greedy implementation of a scheduler, creating a topological ordering of nodes and naively balancing them onto the different devices.
"""
struct GreedyScheduler end
struct GreedyScheduler <: AbstractScheduler end
function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
nodeQueue = PriorityQueue{Node, Int}()

View File

@ -1,10 +1,10 @@
"""
Scheduler
AbstractScheduler
Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
"""
abstract type Scheduler end
abstract type AbstractScheduler end
"""
schedule_dag(::Scheduler, ::DAG, ::Machine)

View File

@ -5,10 +5,11 @@ using StaticArrays
Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on.
"""
struct FunctionCall{VectorType <: AbstractVector, M}
struct FunctionCall{VectorType <: AbstractVector, N}
func::Function
arguments::VectorType
additional_arguments::SVector{M, Any} # additional arguments (as values) for the function call, will be prepended to the other arguments
# TODO: this should be a tuple
value_arguments::SVector{N, Any} # value arguments for the function call, will be prepended to the other arguments
arguments::VectorType # symbols of the inputs to the function call
return_symbol::Symbol
device::AbstractDevice
end
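A hedged construction sketch matching the new field order; `some_task` and `device` are placeholders, not values taken from the library:
using StaticArrays
fc = FunctionCall(
    compute,                      # func to invoke
    SVector{1, Any}(some_task),   # value_arguments, prepended to the call arguments
    [:in1, :in2],                 # arguments: symbols of the inputs
    :result,                      # return_symbol
    device,                       # an AbstractDevice instance
)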

View File

@ -1,38 +1,20 @@
using StaticArrays
"""
compute(t::FusedComputeTask, data)
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
"""
function compute(t::FusedComputeTask, data...)
inter = compute(t.first_task)
return compute(t.second_task, inter, data2...)
end
"""
get_function_call(n::Node)
get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
For a node or a task together with necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task.
For ordinary compute or data tasks the vector will contain exactly one element; for a [`FusedComputeTask`](@ref) it can contain any number of elements greater than 1.
For ordinary compute or data tasks the vector will contain exactly one element.
"""
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
# sort out the symbols to the correct tasks
return [
get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
]
end
function get_function_call(
t::CompTask,
device::AbstractDevice,
inSymbols::AbstractVector,
outSymbol::Symbol,
) where {CompTask <: AbstractComputeTask}
return [FunctionCall(compute, inSymbols, SVector{1, Any}(t), outSymbol, device)]
return [FunctionCall(compute, SVector{1, Any}(t), inSymbols, outSymbol, device)]
end
function get_function_call(node::ComputeTaskNode)
@ -64,8 +46,8 @@ function get_function_call(node::DataTaskNode)
return [
FunctionCall(
unpack_identity,
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
SVector{0, Any}(),
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
Symbol(to_var_name(node.id)),
first(children(node)).device,
),
@ -77,8 +59,8 @@ function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
return FunctionCall(
unpack_identity,
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
SVector{0, Any}(),
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
Symbol(to_var_name(node.id)),
device,
)

View File

@ -11,22 +11,3 @@ copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks
Return a copy of the given compute task.
"""
copy(t::AbstractComputeTask) = typeof(t)()
"""
copy(t::FusedComputeTask)
Return a copy of the given [`FusedComputeTask`](@ref).
"""
function copy(t::FusedComputeTask)
return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
end
function FusedComputeTask(
T1::Type{<:AbstractComputeTask},
T2::Type{<:AbstractComputeTask},
t1_inputs::Vector{String},
t1_output::String,
t2_inputs::Vector{String},
)
return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
end

View File

@ -3,28 +3,21 @@
Fallback implementation of the compute function of a compute task, throwing an error.
"""
function compute(t::AbstractTask, data...)
return error("Need to implement compute()")
end
function compute end
"""
compute_effort(t::AbstractTask)
Fallback implementation of the compute effort of a task, throwing an error.
"""
function compute_effort(t::AbstractTask)::Float64
# default implementation using compute
return error("Need to implement compute_effort()")
end
function compute_effort end
"""
data(t::AbstractTask)
Fallback implementation of the data of a task, throwing an error.
"""
function data(t::AbstractTask)::Float64
return error("Need to implement data()")
end
function data end
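A hedged sketch of how a concrete task type might implement these interface functions (the type and effort value are hypothetical):
struct ExampleTask <: AbstractComputeTask end
compute(::ExampleTask, a, b) = a + b      # the actual computation on the inputs
compute_effort(::ExampleTask) = 1.0       # estimated effort of one evaluation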
"""
compute_effort(t::AbstractDataTask)
@ -54,34 +47,9 @@ Return the number of children of a data task (always 1).
"""
children(::DataTask) = 1
"""
children(t::FusedComputeTask)
Return the number of children of a FusedComputeTask.
"""
function children(t::FusedComputeTask)
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
end
"""
data(t::AbstractComputeTask)
Return the data of a compute task, always zero, regardless of the specific task.
"""
data(t::AbstractComputeTask)::Float64 = 0.0
"""
compute_effort(t::FusedComputeTask)
Return the compute effort of a fused compute task.
"""
function compute_effort(t::FusedComputeTask)::Float64
return compute_effort(t.first_task) + compute_effort(t.second_task)
end
"""
get_types(::FusedComputeTask{T1, T2})
Return a tuple of the fused compute task's components' types.
"""
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))

View File

@ -27,21 +27,3 @@ Task representing a specific data transfer.
struct DataTask <: AbstractDataTask
data::Float64
end
"""
FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
A fused compute task made up of the computation of first `T1` and then `T2`.
Also see: [`get_types`](@ref).
"""
struct FusedComputeTask <: AbstractComputeTask
first_task::AbstractComputeTask
second_task::AbstractComputeTask
# the names of the inputs for T1
t1_inputs::Vector{Symbol}
# output name of T1
t1_output::Symbol
# t2_inputs doesn't include the output of t1, that's implicit
t2_inputs::Vector{Symbol}
end

View File

@ -1,3 +1,6 @@
using Roots
using ForwardDiff
"""
noop()
@ -71,9 +74,6 @@ function mem(graph::DAG)
size += sizeof(graph.operationsToApply)
size += sizeof(graph.possibleOperations)
for op in graph.possibleOperations.nodeFusions
size += mem(op)
end
for op in graph.possibleOperations.nodeReductions
size += mem(op)
end
@ -249,7 +249,6 @@ end
first_derivative(func) = x -> ForwardDiff.derivative(func, float(x))
function generate_physical_massive_moms(rng, ss, masses; x0 = 0.1)
n = length(masses)
massless_moms = generate_physical_massless_moms(rng, ss, n)

View File

@ -1,10 +0,0 @@
[deps]
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

View File

@ -3,48 +3,15 @@ using Random
RNG = Random.MersenneTwister(321)
function test_known_graph(name::String, n, fusion_test = true)
function test_known_graph(name::String, n)
@testset "Test $name Graph ($n)" begin
graph = parse_dag(joinpath(@__DIR__, "..", "input", "$name.txt"), ABCModel())
props = get_properties(graph)
if (fusion_test)
test_node_fusion(graph)
end
test_random_walk(RNG, graph, n)
end
end
function test_node_fusion(g::DAG)
@testset "Test Node Fusion" begin
props = get_properties(g)
options = get_operations(g)
nodes_number = length(g.nodes)
data = props.data
compute_effort = props.computeEffort
while !isempty(options.nodeFusions)
fusion = first(options.nodeFusions)
@test typeof(fusion) <: NodeFusion
push_operation!(g, fusion)
props = get_properties(g)
@test props.data < data
@test props.computeEffort == compute_effort
nodes_number = length(g.nodes)
data = props.data
compute_effort = props.computeEffort
options = get_operations(g)
end
end
end
function test_random_walk(RNG, g::DAG, n::Int64)
@testset "Test Random Walk ($n)" begin
# the purpose here is to perform "random" operations, reverse them again, and validate that the graph stays the same and doesn't diverge
@ -60,13 +27,11 @@ function test_random_walk(RNG, g::DAG, n::Int64)
# push
opt = get_operations(g)
# choose one of fuse/split/reduce
option = rand(RNG, 1:3)
if option == 1 && !isempty(opt.nodeFusions)
push_operation!(g, rand(RNG, collect(opt.nodeFusions)))
elseif option == 2 && !isempty(opt.nodeReductions)
# choose one of split/reduce
option = rand(RNG, 1:2)
if option == 1 && !isempty(opt.nodeReductions)
push_operation!(g, rand(RNG, collect(opt.nodeReductions)))
elseif option == 3 && !isempty(opt.nodeSplits)
elseif option == 2 && !isempty(opt.nodeSplits)
push_operation!(g, rand(RNG, collect(opt.nodeSplits)))
else
i = i - 1
@ -91,4 +56,4 @@ end
test_known_graph("AB->AB", 10000)
test_known_graph("AB->ABBB", 10000)
test_known_graph("AB->ABBBBB", 1000, false)
test_known_graph("AB->ABBBBB", 1000)

View File

@ -61,17 +61,14 @@ insert_edge!(graph, CD, C1C, track = false)
opt = get_operations(graph)
@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
#println("Initial State:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
nr = first(opt.nodeReductions)
@test Set(nr.input) == Set([B1C_1, B1C_2])
push_operation!(graph, nr)
opt = get_operations(graph)
@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
#println("After 1 Node Reduction:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
nr = first(opt.nodeReductions)
@test Set(nr.input) == Set([B1D_1, B1D_2])
@ -80,19 +77,16 @@ opt = get_operations(graph)
@test is_valid(graph)
@test length(opt) == (nodeFusions = 4, nodeReductions = 0, nodeSplits = 1)
#println("After 2 Node Reductions:\n", opt)
@test length(opt) == (nodeReductions = 0, nodeSplits = 1)
pop_operation!(graph)
opt = get_operations(graph)
@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
#println("After reverting the second Node Reduction:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
reset_graph!(graph)
opt = get_operations(graph)
@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
#println("After reverting to the initial state:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
@test is_valid(graph)

View File

@ -1,5 +1,5 @@
using SafeTestsets
#=
@safetestset "Utility Unit Tests " begin
include("unit_tests_utility.jl")
end
@ -30,6 +30,7 @@ end
@safetestset "Graph Unit Tests " begin
include("unit_tests_graph.jl")
end
=#
@safetestset "Execution Unit Tests " begin
include("unit_tests_execution.jl")
end

View File

@ -1,5 +1,5 @@
using MetagraphOptimization
using QEDbase
using QEDcore
import MetagraphOptimization.interaction_result

View File

@ -1,16 +1,5 @@
using MetagraphOptimization
function test_op_specific(estimator, graph, nf::NodeFusion)
estimate = operation_effect(estimator, graph, nf)
data_reduce = data(nf.input[2].task)
@test isapprox(estimate.data, -data_reduce)
@test isapprox(estimate.computeEffort, 0; atol = eps(Float64))
@test isapprox(estimate.computeIntensity, 0; atol = eps(Float64))
return nothing
end
function test_op_specific(estimator, graph, nr::NodeReduction)
estimate = operation_effect(estimator, graph, nr)
@ -74,13 +63,9 @@ end
@testset "Operation Cost" begin
ops = get_operations(graph)
nfs = copy(ops.nodeFusions)
nrs = copy(ops.nodeReductions)
nss = copy(ops.nodeSplits)
for nf in nfs
test_op(estimator, graph, nf)
end
for nr in nrs
test_op(estimator, graph, nr)
end

View File

@ -1,5 +1,5 @@
using MetagraphOptimization
using QEDbase
using QEDcore
using AccurateArithmetic
using Random
using UUIDs
@ -63,8 +63,14 @@ machine = Machine(
)
process_2_2 = ABCProcessDescription(
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
Dict{Type, Int64}(
ParticleStateful{Incoming, ParticleA, SFourMomentum} => 1,
ParticleStateful{Incoming, ParticleB, SFourMomentum} => 1,
),
Dict{Type, Int64}(
ParticleStateful{Outgoing, ParticleA, SFourMomentum} => 1,
ParticleStateful{Outgoing, ParticleB, SFourMomentum} => 1,
),
)
particles_2_2 = ABCProcessInput(
@ -106,8 +112,14 @@ end
end
process_2_4 = ABCProcessDescription(
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
Dict{Type, Int64}(ParticleA => 1, ParticleB => 3),
Dict{Type, Int64}(
ParticleStateful{Incoming, ParticleA, SFourMomentum} => 1,
ParticleStateful{Incoming, ParticleB, SFourMomentum} => 1,
),
Dict{Type, Int64}(
ParticleStateful{Outgoing, ParticleA, SFourMomentum} => 1,
ParticleStateful{Outgoing, ParticleB, SFourMomentum} => 3,
),
)
particles_2_4 = gen_process_input(process_2_4)
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
@ -136,105 +148,6 @@ TODO: fix precision(?) issues
end
=#
@testset "AB->AB large sum fusion" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "AB->AB large sum fusion" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "AB->AB fusion edge case" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push two fusions with ComputeTaskABC_V
for _ in 1:2
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, ComputeTaskABC_V)
push_operation!(graph, fusion)
break
end
end
end
# push fusions until the end
cont = true
while cont
cont = false
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, FusedComputeTask)
push_operation!(graph, fusion)
cont = true
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "$(process) after random walk" for process in ["ke->ke", "ke->kke", "ke->kkke"]
process = parse_process("ke->kkke", QEDModel())
inputs = [gen_process_input(process) for _ in 1:100]

View File

@ -13,7 +13,7 @@ graph = MetagraphOptimization.DAG()
@test length(graph.operationsToApply) == 0
@test length(graph.dirtyNodes) == 0
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
@test length(get_operations(graph)) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
@test length(get_operations(graph)) == (nodeReductions = 0, nodeSplits = 0)
# s to output (exit node)
d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
@ -133,13 +133,10 @@ insert_edge!(graph, s0, d_exit, track = false)
@test length(siblings(s0)) == 1
operations = get_operations(graph)
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
@test length(graph.dirtyNodes) == 0
@test sum(length(operations)) == 10
@test operations == get_operations(graph)
nf = first(operations.nodeFusions)
properties = get_properties(graph)
@test properties.computeEffort == 28
@ -148,54 +145,19 @@ properties = get_properties(graph)
@test properties.noNodes == 26
@test properties.noEdges == 25
push_operation!(graph, nf)
# **does not immediately apply the operation**
@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0
@test length(graph.operationsToApply) == 1
@test first(graph.operationsToApply) == nf
@test length(graph.dirtyNodes) == 0
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
# this applies pending operations
properties = get_properties(graph)
@test length(graph.nodes) == 24
@test length(graph.appliedOperations) == 1
@test length(graph.operationsToApply) == 0
@test length(graph.dirtyNodes) != 0
@test properties.noNodes == 24
@test properties.noEdges == 23
@test properties.computeEffort == 28
@test properties.data < 62
@test properties.computeIntensity > 28 / 62
operations = get_operations(graph)
@test length(graph.dirtyNodes) == 0
@test length(operations) == (nodeFusions = 9, nodeReductions = 0, nodeSplits = 0)
@test !isempty(operations)
possibleNF = 9
while !isempty(operations.nodeFusions)
push_operation!(graph, first(operations.nodeFusions))
global operations = get_operations(graph)
global possibleNF = possibleNF - 1
@test length(operations) == (nodeFusions = possibleNF, nodeReductions = 0, nodeSplits = 0)
end
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
@test isempty(operations)
@test length(operations) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
@test length(graph.dirtyNodes) == 0
@test length(graph.nodes) == 6
@test length(graph.appliedOperations) == 10
@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0
@test length(graph.operationsToApply) == 0
reset_graph!(graph)
@test length(graph.dirtyNodes) == 26
@test length(graph.dirtyNodes) == 0
@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0
@test length(graph.operationsToApply) == 0
@ -208,6 +170,6 @@ properties = get_properties(graph)
@test properties.computeIntensity ≈ 28 / 62
operations = get_operations(graph)
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
@test is_valid(graph)

View File

@ -6,8 +6,7 @@ RNG = Random.MersenneTwister(0)
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
# create the optimizers
FIXPOINT_OPTIMIZERS =
[GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer(), FusionOptimizer()]
FIXPOINT_OPTIMIZERS = [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer()]
NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
@testset "Optimizer $optimizer" for optimizer in vcat(NO_FIXPOINT_OPTIMIZERS, FIXPOINT_OPTIMIZERS)

View File

@ -1,7 +1,8 @@
using MetagraphOptimization
using QEDcore
import MetagraphOptimization.gen_diagrams
import MetagraphOptimization.isincoming
import MetagraphOptimization.types
@ -9,25 +10,12 @@ model = QEDModel()
compton = ("Compton Scattering", parse_process("ke->ke", model), 2)
compton_3 = ("3-Photon Compton Scattering", parse_process("kkke->ke", QEDModel()), 24)
compton_4 = ("4-Photon Compton Scattering", parse_process("kkkke->ke", QEDModel()), 120)
bhabha = ("Bhabha Scattering", parse_process("ep->ep", model), 2)
moller = ("Møller Scattering", parse_process("ee->ee", model), 2)
pair_production = ("Pair production", parse_process("kk->ep", model), 2)
pair_annihilation = ("Pair annihilation", parse_process("ep->kk", model), 2)
trident = ("Trident", parse_process("ke->epe", model), 8)
@testset "Known Processes" begin
@testset "$name" for (name, process, n) in
[compton, bhabha, moller, pair_production, pair_annihilation, trident, compton_3, compton_4]
@testset "$name" for (name, process, n) in [compton, compton_3, compton_4]
initial_diagram = FeynmanDiagram(process)
n_particles = number_incoming_particles(process) + number_outgoing_particles(process)
n_particles = 0
for type in types(model)
if (isincoming(type))
n_particles += get(process.inParticles, type, 0)
else
n_particles += get(process.outParticles, type, 0)
end
end
@test n_particles == length(initial_diagram.particles)
@test ismissing(initial_diagram.tie[])
@test isempty(initial_diagram.vertices)

View File

@ -1,5 +1,6 @@
using MetagraphOptimization
using QEDbase
using QEDcore
using QEDprocesses
using StatsBase # for countmap
using Random
@ -9,8 +10,6 @@ import MetagraphOptimization.caninteract
import MetagraphOptimization.issame
import MetagraphOptimization.interaction_result
import MetagraphOptimization.propagation_result
import MetagraphOptimization.direction
import MetagraphOptimization.spin_or_pol
import MetagraphOptimization.QED_vertex
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
@ -18,32 +17,32 @@ def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
RNG = Random.MersenneTwister(0)
testparticleTypes = [
PhotonStateful{Incoming, PolX},
PhotonStateful{Outgoing, PolX},
FermionStateful{Incoming, SpinUp},
FermionStateful{Outgoing, SpinUp},
AntiFermionStateful{Incoming, SpinUp},
AntiFermionStateful{Outgoing, SpinUp},
ParticleStateful{Incoming, Photon, SFourMomentum},
ParticleStateful{Outgoing, Photon, SFourMomentum},
ParticleStateful{Incoming, Electron, SFourMomentum},
ParticleStateful{Outgoing, Electron, SFourMomentum},
ParticleStateful{Incoming, Positron, SFourMomentum},
ParticleStateful{Outgoing, Positron, SFourMomentum},
]
testparticleTypesPropagated = [
PhotonStateful{Outgoing, PolX},
PhotonStateful{Incoming, PolX},
FermionStateful{Outgoing, SpinUp},
FermionStateful{Incoming, SpinUp},
AntiFermionStateful{Outgoing, SpinUp},
AntiFermionStateful{Incoming, SpinUp},
ParticleStateful{Outgoing, Photon, SFourMomentum},
ParticleStateful{Incoming, Photon, SFourMomentum},
ParticleStateful{Outgoing, Electron, SFourMomentum},
ParticleStateful{Incoming, Electron, SFourMomentum},
ParticleStateful{Outgoing, Positron, SFourMomentum},
ParticleStateful{Incoming, Positron, SFourMomentum},
]
function compton_groundtruth(input::QEDProcessInput)
function compton_groundtruth(psp::PhaseSpacePoint)
# p1k1 -> p2k2
# formula: (ie)^2 (u(p2) slashed(ε1) S(p2 k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
p1 = input.inFerms[1]
p2 = input.outFerms[1]
p1 = momentum(psp, Incoming(), 2)
p2 = momentum(psp, Outgoing(), 2)
k1 = input.inPhotons[1]
k2 = input.outPhotons[1]
k1 = momentum(psp, Incoming(), 1)
k2 = momentum(psp, Outgoing(), 1)
u_p1 = base_state(Electron(), Incoming(), p1.momentum, spin_or_pol(p1))
u_p2 = base_state(Electron(), Outgoing(), p2.momentum, spin_or_pol(p2))
@ -57,8 +56,8 @@ function compton_groundtruth(input::QEDProcessInput)
virt2_mom = p1.momentum + k1.momentum
@test isapprox(p2.momentum + k2.momentum, virt2_mom)
s_p2_k1 = propagator(Electron(), virt1_mom)
s_p1_k1 = propagator(Electron(), virt2_mom)
s_p2_k1 = QEDbase.propagator(Electron(), virt1_mom)
s_p1_k1 = QEDbase.propagator(Electron(), virt2_mom)
diagram1 = u_p2 * (eps_1 * QED_vertex()) * s_p2_k1 * (eps_2 * QED_vertex()) * u_p1
diagram2 = u_p2 * (eps_2 * QED_vertex()) * s_p1_k1 * (eps_1 * QED_vertex()) * u_p1
@ -66,7 +65,6 @@ function compton_groundtruth(input::QEDProcessInput)
return diagram1 + diagram2
end
@testset "Interaction Result" begin
import MetagraphOptimization.QED_conserve_momentum
@ -88,8 +86,8 @@ end
@test issame(typeof(resultParticle), interaction_result(p1, p2))
totalMom = zero(SFourMomentum)
for (p, mom) in [(p1, testParticle1.momentum), (p2, testParticle2.momentum), (p3, resultParticle.momentum)]
if (typeof(direction(p)) <: Incoming)
for (p, mom) in [(p1, momentum(testParticle1)), (p2, momentum(testParticle2)), (p3, momentum(resultParticle))]
if (typeof(particle_direction(p)) <: Incoming)
totalMom += mom
else
totalMom -= mom
@ -103,54 +101,31 @@ end
@testset "Propagation Result" begin
for (p, propResult) in zip(testparticleTypes, testparticleTypesPropagated)
@test issame(propagation_result(p), propResult)
@test direction(propagation_result(p)(def_momentum)) != direction(p(def_momentum))
@test particle_direction(propagation_result(p)(def_momentum)) != particle_direction(p(def_momentum))
end
end
@testset "Parse Process" begin
@testset "Order invariance" begin
@test parse_process("ke->ke", QEDModel()) == parse_process("ek->ke", QEDModel())
@test parse_process("ke->ke", QEDModel()) == parse_process("ek->ek", QEDModel())
@test parse_process("ke->ke", QEDModel()) == parse_process("ke->ek", QEDModel())
@test parse_process("kkke->eep", QEDModel()) == parse_process("kkek->epe", QEDModel())
end
@testset "Known processes" begin
compton_process = QEDProcessDescription(
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, FermionStateful{Incoming, SpinUp} => 1),
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 1, FermionStateful{Outgoing, SpinUp} => 1),
)
proc = parse_process("ke->ke", QEDModel())
@test incoming_particles(proc) == (Photon(), Electron())
@test outgoing_particles(proc) == (Photon(), Electron())
@test parse_process("ke->ke", QEDModel()) == compton_process
proc = parse_process("kp->kp", QEDModel())
@test incoming_particles(proc) == (Photon(), Positron())
@test outgoing_particles(proc) == (Photon(), Positron())
positron_compton_process = QEDProcessDescription(
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, AntiFermionStateful{Incoming, SpinUp} => 1),
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 1, AntiFermionStateful{Outgoing, SpinUp} => 1),
)
proc = parse_process("ke->eep", QEDModel())
@test incoming_particles(proc) == (Photon(), Electron())
@test outgoing_particles(proc) == (Electron(), Electron(), Positron())
@test parse_process("kp->kp", QEDModel()) == positron_compton_process
proc = parse_process("kk->pe", QEDModel())
@test incoming_particles(proc) == (Photon(), Photon())
@test outgoing_particles(proc) == (Positron(), Electron())
trident_process = QEDProcessDescription(
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 1, FermionStateful{Incoming, SpinUp} => 1),
Dict{Type, Int}(FermionStateful{Outgoing, SpinUp} => 2, AntiFermionStateful{Outgoing, SpinUp} => 1),
)
@test parse_process("ke->eep", QEDModel()) == trident_process
pair_production_process = QEDProcessDescription(
Dict{Type, Int}(PhotonStateful{Incoming, PolX} => 2),
Dict{Type, Int}(FermionStateful{Outgoing, SpinUp} => 1, AntiFermionStateful{Outgoing, SpinUp} => 1),
)
@test parse_process("kk->pe", QEDModel()) == pair_production_process
pair_annihilation_process = QEDProcessDescription(
Dict{Type, Int}(FermionStateful{Incoming, SpinUp} => 1, AntiFermionStateful{Incoming, SpinUp} => 1),
Dict{Type, Int}(PhotonStateful{Outgoing, PolX} => 2),
)
@test parse_process("pe->kk", QEDModel()) == pair_annihilation_process
proc = parse_process("pe->kk", QEDModel())
@test incoming_particles(proc) == (Positron(), Electron())
@test outgoing_particles(proc) == (Photon(), Photon())
end
end
@ -161,30 +136,12 @@ end
for i in 1:100
input = gen_process_input(process)
@test length(input.inFerms) == get(process.inParticles, FermionStateful{Incoming, SpinUp}, 0)
@test length(input.inAntiferms) == get(process.inParticles, AntiFermionStateful{Incoming, SpinUp}, 0)
@test length(input.inPhotons) == get(process.inParticles, PhotonStateful{Incoming, PolX}, 0)
@test length(input.outFerms) == get(process.outParticles, FermionStateful{Outgoing, SpinUp}, 0)
@test length(input.outAntiferms) == get(process.outParticles, AntiFermionStateful{Outgoing, SpinUp}, 0)
@test length(input.outPhotons) == get(process.outParticles, PhotonStateful{Outgoing, PolX}, 0)
@test isapprox(
sum([
getfield.(input.inFerms, :momentum)...,
getfield.(input.inAntiferms, :momentum)...,
getfield.(input.inPhotons, :momentum)...,
]),
sum([
getfield.(input.outFerms, :momentum)...,
getfield.(input.outAntiferms, :momentum)...,
getfield.(input.outPhotons, :momentum)...,
]);
atol = sqrt(eps()),
)
@test isapprox(sum(momenta(input, Incoming())), sum(momenta(input, Outgoing())); atol = sqrt(eps()))
end
end
end
#=
@testset "Compton" begin
import MetagraphOptimization.insert_node!
import MetagraphOptimization.insert_edge!
@ -211,97 +168,97 @@ end
graph = DAG()
# s to output (exit node)
d_exit = insert_node!(graph, make_node(DataTask(16)), track = false)
d_exit = insert_node!(graph, make_node(DataTask(16)); track=false)
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(2)), track = false)
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(2)); track=false)
d_s0_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
d_s1_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
d_s0_sum = insert_node!(graph, make_node(DataTask(16)); track=false)
d_s1_sum = insert_node!(graph, make_node(DataTask(16)); track=false)
# final s compute
s0 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
s1 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
s0 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track=false)
s1 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track=false)
# data from v0 and v1 to s0
d_v0_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v1_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v2_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v3_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v0_s0 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_v1_s0 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_v2_s1 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_v3_s1 = insert_node!(graph, make_node(DataTask(96)); track=false)
# v0 and v1 compute
v0 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v1 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v2 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v3 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v0 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
v1 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
v2 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
v3 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
# data from uPhIn, uPhOut, uElIn, uElOut to v0 and v1
d_uPhIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhIn_v0 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uElIn_v0 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uPhOut_v1 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uElOut_v1 = insert_node!(graph, make_node(DataTask(96)); track=false)
# data from uPhIn, uPhOut, uElIn, uElOut to v2 and v3
d_uPhOut_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElIn_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhIn_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElOut_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhOut_v2 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uElIn_v2 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uPhIn_v3 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uElOut_v3 = insert_node!(graph, make_node(DataTask(96)); track=false)
# uPhIn, uPhOut, uElIn and uElOut computes
uPhIn = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uPhOut = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uElIn = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uElOut = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uPhIn = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
uPhOut = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
uElIn = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
uElOut = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
# data into U
d_uPhIn = insert_node!(graph, make_node(DataTask(16), "ki1"), track = false)
d_uPhOut = insert_node!(graph, make_node(DataTask(16), "ko1"), track = false)
d_uElIn = insert_node!(graph, make_node(DataTask(16), "ei1"), track = false)
d_uElOut = insert_node!(graph, make_node(DataTask(16), "eo1"), track = false)
d_uPhIn = insert_node!(graph, make_node(DataTask(16), "ki1"); track=false)
d_uPhOut = insert_node!(graph, make_node(DataTask(16), "ko1"); track=false)
d_uElIn = insert_node!(graph, make_node(DataTask(16), "ei1"); track=false)
d_uElOut = insert_node!(graph, make_node(DataTask(16), "eo1"); track=false)
# now for all the edges
insert_edge!(graph, d_uPhIn, uPhIn, track = false)
insert_edge!(graph, d_uPhOut, uPhOut, track = false)
insert_edge!(graph, d_uElIn, uElIn, track = false)
insert_edge!(graph, d_uElOut, uElOut, track = false)
insert_edge!(graph, d_uPhIn, uPhIn; track=false)
insert_edge!(graph, d_uPhOut, uPhOut; track=false)
insert_edge!(graph, d_uElIn, uElIn; track=false)
insert_edge!(graph, d_uElOut, uElOut; track=false)
insert_edge!(graph, uPhIn, d_uPhIn_v0, track = false)
insert_edge!(graph, uPhOut, d_uPhOut_v1, track = false)
insert_edge!(graph, uElIn, d_uElIn_v0, track = false)
insert_edge!(graph, uElOut, d_uElOut_v1, track = false)
insert_edge!(graph, uPhIn, d_uPhIn_v0; track=false)
insert_edge!(graph, uPhOut, d_uPhOut_v1; track=false)
insert_edge!(graph, uElIn, d_uElIn_v0; track=false)
insert_edge!(graph, uElOut, d_uElOut_v1; track=false)
insert_edge!(graph, uPhIn, d_uPhIn_v3, track = false)
insert_edge!(graph, uPhOut, d_uPhOut_v2, track = false)
insert_edge!(graph, uElIn, d_uElIn_v2, track = false)
insert_edge!(graph, uElOut, d_uElOut_v3, track = false)
insert_edge!(graph, uPhIn, d_uPhIn_v3; track=false)
insert_edge!(graph, uPhOut, d_uPhOut_v2; track=false)
insert_edge!(graph, uElIn, d_uElIn_v2; track=false)
insert_edge!(graph, uElOut, d_uElOut_v3; track=false)
insert_edge!(graph, d_uPhIn_v0, v0, track = false)
insert_edge!(graph, d_uPhOut_v1, v1, track = false)
insert_edge!(graph, d_uElIn_v0, v0, track = false)
insert_edge!(graph, d_uElOut_v1, v1, track = false)
insert_edge!(graph, d_uPhIn_v0, v0; track=false)
insert_edge!(graph, d_uPhOut_v1, v1; track=false)
insert_edge!(graph, d_uElIn_v0, v0; track=false)
insert_edge!(graph, d_uElOut_v1, v1; track=false)
insert_edge!(graph, d_uPhIn_v3, v3, track = false)
insert_edge!(graph, d_uPhOut_v2, v2, track = false)
insert_edge!(graph, d_uElIn_v2, v2, track = false)
insert_edge!(graph, d_uElOut_v3, v3, track = false)
insert_edge!(graph, d_uPhIn_v3, v3; track=false)
insert_edge!(graph, d_uPhOut_v2, v2; track=false)
insert_edge!(graph, d_uElIn_v2, v2; track=false)
insert_edge!(graph, d_uElOut_v3, v3; track=false)
insert_edge!(graph, v0, d_v0_s0, track = false)
insert_edge!(graph, v1, d_v1_s0, track = false)
insert_edge!(graph, v2, d_v2_s1, track = false)
insert_edge!(graph, v3, d_v3_s1, track = false)
insert_edge!(graph, v0, d_v0_s0; track=false)
insert_edge!(graph, v1, d_v1_s0; track=false)
insert_edge!(graph, v2, d_v2_s1; track=false)
insert_edge!(graph, v3, d_v3_s1; track=false)
insert_edge!(graph, d_v0_s0, s0, track = false)
insert_edge!(graph, d_v1_s0, s0, track = false)
insert_edge!(graph, d_v0_s0, s0; track=false)
insert_edge!(graph, d_v1_s0, s0; track=false)
insert_edge!(graph, d_v2_s1, s1, track = false)
insert_edge!(graph, d_v3_s1, s1, track = false)
insert_edge!(graph, d_v2_s1, s1; track=false)
insert_edge!(graph, d_v3_s1, s1; track=false)
insert_edge!(graph, s0, d_s0_sum, track = false)
insert_edge!(graph, s1, d_s1_sum, track = false)
insert_edge!(graph, s0, d_s0_sum; track=false)
insert_edge!(graph, s1, d_s1_sum; track=false)
insert_edge!(graph, d_s0_sum, sum_node, track = false)
insert_edge!(graph, d_s1_sum, sum_node, track = false)
insert_edge!(graph, d_s0_sum, sum_node; track=false)
insert_edge!(graph, d_s1_sum, sum_node; track=false)
insert_edge!(graph, sum_node, d_exit, track = false)
insert_edge!(graph, sum_node, d_exit; track=false)
input = [gen_process_input(process) for _ in 1:1000]
@ -314,9 +271,12 @@ end
@test isapprox(compton_function.(input), compton_groundtruth.(input))
end
@testset "Equal results after optimization" for optimizer in
[ReductionOptimizer(), RandomWalkOptimizer(MersenneTwister(0))]
@testset "Process $proc_str" for proc_str in ["ke->ke", "kp->kp", "kk->ep", "ep->kk", "ke->kke", "ke->kkke"]
@testset "Equal results after optimization" for optimizer in [
ReductionOptimizer(), RandomWalkOptimizer(MersenneTwister(0))
]
@testset "Process $proc_str" for proc_str in [
"ke->ke", "kp->kp", "kk->ep", "ep->kk", "ke->kke", "ke->kkke"
]
model = QEDModel()
process = parse_process(proc_str, model)
machine = Machine(
@ -347,3 +307,4 @@ end
@test isapprox(compute_function.(input), reduced_compute_function.(input))
end
end
=#