8 Commits

Author SHA1 Message Date
994d4d7cee Start adapting ABCModel implementation to new interfaces
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 7m6s
MetagraphOptimization_CI / docs (push) Failing after 7m15s
2024-08-19 18:37:55 +02:00
0ce98e29ef Reenable tests and fix a lot
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 6m58s
MetagraphOptimization_CI / docs (push) Failing after 7m6s
2024-08-19 17:45:40 +02:00
d553fe8ffc Reenable ABC Model
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 6m46s
MetagraphOptimization_CI / docs (push) Successful in 7m37s
2024-08-19 14:54:34 +02:00
e9bd1f2939 Still remove NodeFusion
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 7m23s
MetagraphOptimization_CI / docs (push) Successful in 7m57s
2024-08-19 14:02:46 +02:00
97ccb3f3fb Remove occurrences of Fusion/Fuse
Some checks failed
MetagraphOptimization_CI / docs (push) Failing after 6m56s
MetagraphOptimization_CI / test (push) Failing after 7m25s
2024-08-13 17:57:16 +02:00
8a5e49429b Don't eval in generated function return
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 6m58s
MetagraphOptimization_CI / docs (push) Successful in 7m27s
2024-08-09 12:21:49 +02:00
5be7ca99e7 Add results and evaluation
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 1m33s
MetagraphOptimization_CI / docs (push) Failing after 1m33s
2024-07-10 14:21:26 +02:00
1ae39a8caa Congruent in photons example (#12)
Some checks failed
MetagraphOptimization_CI / docs (push) Failing after 1m50s
MetagraphOptimization_CI / test (push) Failing after 1m53s
Now targeting the correct branches

Co-authored-by: Rubydragon <anton.reinhard@proton.me>
Reviewed-on: #12
2024-07-10 14:17:39 +02:00
64 changed files with 687 additions and 1501 deletions

4
.gitattributes vendored
View File

@ -1,3 +1,5 @@
input/AB->ABBBBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
input/AB->ABBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs
*.gif filter=lfs diff=lfs merge=lfs
*.jld2 filter=lfs diff=lfs merge=lfs

View File

@ -5,7 +5,7 @@
## Package Features
- Read a DAG from a file
- Analyze its properties
- Mute the graph using the operations NodeFusion, NodeReduction and NodeSplit
- Mute the graph using the operations NodeReduction and NodeSplit
## Coming Soon:
- Add Code Generation from finished DAG

194
examples/congruent_in_ph.jl Normal file
View File

@ -0,0 +1,194 @@
ENV["UCX_ERROR_SIGNALS"] = "SIGILL,SIGBUS,SIGFPE"
using MetagraphOptimization
using QEDbase
using QEDcore
using QEDprocesses
using Random
using UUIDs
using CUDA
using NamedDims
using CSV
using JLD2
using FlexiMaps
RNG = Random.MersenneTwister(123)
function mock_machine()
return Machine(
[
MetagraphOptimization.NumaNode(
0,
1,
MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),
-1.0,
UUIDs.uuid1(),
),
],
[-1.0;;],
)
end
function congruent_input_momenta(processDescription::GenericQEDProcess, omega::Number)
# generate an input sample for given e + nk -> e' + k' process, where the nk are equal
inputMasses = Vector{Float64}()
for particle in incoming_particles(processDescription)
push!(inputMasses, mass(particle))
end
outputMasses = Vector{Float64}()
for particle in outgoing_particles(processDescription)
push!(outputMasses, mass(particle))
end
initial_momenta = [
i == length(inputMasses) ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
i in 1:length(inputMasses)
]
ss = sqrt(sum(initial_momenta) * sum(initial_momenta))
final_momenta = MetagraphOptimization.generate_physical_massive_moms(RNG, ss, outputMasses)
return (tuple(initial_momenta...), tuple(final_momenta...))
end
# theta ∈ [0, 2π] and phi ∈ [0, 2π]
function congruent_input_momenta_scenario_2(
processDescription::GenericQEDProcess,
omega::Number,
theta::Number,
phi::Number,
)
# -------------
# same as above
# generate an input sample for given e + nk -> e' + k' process, where the nk are equal
inputMasses = Vector{Float64}()
for particle in incoming_particles(processDescription)
push!(inputMasses, mass(particle))
end
outputMasses = Vector{Float64}()
for particle in outgoing_particles(processDescription)
push!(outputMasses, mass(particle))
end
initial_momenta = [
i == length(inputMasses) ? SFourMomentum(1, 0, 0, 0) : SFourMomentum(omega, 0, 0, omega) for
i in 1:length(inputMasses)
]
ss = sqrt(sum(initial_momenta) * sum(initial_momenta))
# up to here
# ----------
# now calculate the final_momenta from omega, cos_theta and phi
n = number_particles(processDescription, Incoming(), Photon())
cos_theta = cos(theta)
omega_prime = (n * omega) / (1 + n * omega * (1 - cos_theta))
k_prime =
omega_prime * SFourMomentum(1, sqrt(1 - cos_theta^2) * cos(phi), sqrt(1 - cos_theta^2) * sin(phi), cos_theta)
p_prime = sum(initial_momenta) - k_prime
final_momenta = (k_prime, p_prime)
return (tuple(initial_momenta...), tuple(final_momenta...))
end
function build_psp(processDescription::GenericQEDProcess, momenta)
return PhaseSpacePoint(
processDescription,
PerturbativeQED(),
PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()),
momenta[1],
momenta[2],
)
end
# hack to fix stacksize for threading
with_stacksize(f, n) = fetch(schedule(Task(f, n)))
# scenario 2
N = 1024 # thetas
M = 1024 # phis
K = 64 # omegas
thetas = collect(LinRange(0, 2π, N))
phis = collect(LinRange(0, 2π, M))
omegas = collect(maprange(log, 2e-2, 2e-7, K))
for photons in 1:5
# temp process to generate momenta
println("Generating $(K*N*M) inputs for $photons photons (Scenario 2 grid walk)...")
temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
input_momenta =
Array{typeof(congruent_input_momenta_scenario_2(temp_process, omegas[1], thetas[1], phis[1]))}(undef, (K, N, M))
Threads.@threads for k in 1:K
Threads.@threads for i in 1:N
Threads.@threads for j in 1:M
input_momenta[k, i, j] = congruent_input_momenta_scenario_2(temp_process, omegas[k], thetas[i], phis[j])
end
end
end
cu_results = CuArray{Float64}(undef, size(input_momenta))
fill!(cu_results, 0.0)
i = 1
for (in_pol, in_spin, out_pol, out_spin) in
Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
print(
"[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
)
process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
inputs = Array{typeof(build_psp(process, input_momenta[1, 1, 1]))}(undef, (K, N, M))
#println("input_momenta: $input_momenta")
Threads.@threads for k in 1:K
Threads.@threads for i in 1:N
Threads.@threads for j in 1:M
inputs[k, i, j] = build_psp(process, input_momenta[k, i, j])
end
end
end
cu_inputs = CuArray(inputs)
print("Preparing graph... ")
graph = gen_graph(process)
optimize_to_fixpoint!(ReductionOptimizer(), graph)
print("Preparing function... ")
kernel! = get_cuda_kernel(graph, process, mock_machine())
#func = get_compute_function(graph, process, mock_machine())
print("Calculating... ")
ts = 32
bs = Int64(length(cu_inputs) / 32)
outputs = CuArray{ComplexF64}(undef, size(cu_inputs))
@cuda threads = ts blocks = bs always_inline = true kernel!(cu_inputs, outputs, length(cu_inputs))
CUDA.device_synchronize()
cu_results += abs2.(outputs)
println("Done.")
i += 1
end
println("Writing results")
out_ph_moms = getindex.(getindex.(input_momenta, 2), 1)
out_el_moms = getindex.(getindex.(input_momenta, 2), 2)
results = NamedDimsArray{(:omegas, :thetas, :phis)}(Array(cu_results))
println("Named results array: $(typeof(results))")
@save "$(photons)_congruent_photons_grid.jld2" omegas thetas phis results
end

View File

@ -1,60 +0,0 @@
using MetagraphOptimization
using Plots
using Random
function gen_plot(filepath)
name = basename(filepath)
name, _ = splitext(name)
filepath = joinpath(@__DIR__, "../input/", filepath)
if !isfile(filepath)
println("File ", filepath, " does not exist, skipping")
return
end
g = parse_dag(filepath, ABCModel())
Random.seed!(1)
println("Random Walking... ")
x = Vector{Float64}()
y = Vector{Float64}()
for i in 1:30
print("\r", i)
# push
opt = get_operations(g)
# choose one of fuse/split/reduce
option = rand(1:3)
if option == 1 && !isempty(opt.nodeFusions)
push_operation!(g, rand(collect(opt.nodeFusions)))
println("NF")
elseif option == 2 && !isempty(opt.nodeReductions)
push_operation!(g, rand(collect(opt.nodeReductions)))
println("NR")
elseif option == 3 && !isempty(opt.nodeSplits)
push_operation!(g, rand(collect(opt.nodeSplits)))
println("NS")
else
i = i - 1
end
props = get_properties(g)
push!(x, props.data)
push!(y, props.computeEffort)
end
println("\rDone.")
plot([x[1], x[2]], [y[1], y[2]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
# Create lines connecting the reference point to each data point
for i in 3:length(x)
plot!([x[i - 1], x[i]], [y[i - 1], y[i]], linestyle = :solid, linewidth = 1, color = :red)
end
return gui()
end
gen_plot("AB->ABBB.txt")

View File

@ -1,96 +0,0 @@
using MetagraphOptimization
using Plots
using Random
function gen_plot(filepath)
name = basename(filepath)
name, _ = splitext(name)
filepath = joinpath(@__DIR__, "../input/", filepath)
if !isfile(filepath)
println("File ", filepath, " does not exist, skipping")
return
end
g = parse_dag(filepath, ABCModel())
Random.seed!(1)
println("Random Walking... ")
for i in 1:30
print("\r", i)
# push
opt = get_operations(g)
# choose one of fuse/split/reduce
option = rand(1:3)
if option == 1 && !isempty(opt.nodeFusions)
push_operation!(g, rand(collect(opt.nodeFusions)))
println("NF")
elseif option == 2 && !isempty(opt.nodeReductions)
push_operation!(g, rand(collect(opt.nodeReductions)))
println("NR")
elseif option == 3 && !isempty(opt.nodeSplits)
push_operation!(g, rand(collect(opt.nodeSplits)))
println("NS")
else
i = i - 1
end
end
println("\rDone.")
props = get_properties(g)
x0 = props.data
y0 = props.computeEffort
x = Vector{Float64}()
y = Vector{Float64}()
names = Vector{String}()
opt = get_operations(g)
for op in opt.nodeFusions
push_operation!(g, op)
props = get_properties(g)
push!(x, props.data)
push!(y, props.computeEffort)
pop_operation!(g)
push!(names, "NF: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
end
for op in opt.nodeReductions
push_operation!(g, op)
props = get_properties(g)
push!(x, props.data)
push!(y, props.computeEffort)
pop_operation!(g)
push!(names, "NR: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
end
for op in opt.nodeSplits
push_operation!(g, op)
props = get_properties(g)
push!(x, props.data)
push!(y, props.computeEffort)
pop_operation!(g)
push!(names, "NS: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
end
plot([x0, x[1]], [y0, y[1]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
# Create lines connecting the reference point to each data point
for i in 2:length(x)
plot!([x0, x[i]], [y0, y[i]], linestyle = :solid, linewidth = 1, color = :red)
end
#scatter!(x, y, label=names)
print(names)
return gui()
end
gen_plot("AB->ABBB.txt")

BIN
images/contour_plot_congruent_in_photons.gif (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -1,5 +1,5 @@
# Optimizer Plots
Plots of FusionOptimizer, ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.
Plots of FusionOptimizer (deprecated), ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.
Benchmarked using `notebooks/optimizers.ipynb`.

File diff suppressed because one or more lines are too long

View File

@ -54,8 +54,6 @@
"cu_inputs = CuArray(inputs)\n",
"optimizer = RandomWalkOptimizer(MersenneTwister(0))# GreedyOptimizer(GlobalMetricEstimator())\n",
"\n",
"#done: split, reduce, fuse, greedy\n",
"\n",
"process_str_short = \"qed_k3\"\n",
"optim_str = \"Random Walk Optimization\"\n",
"optim_str_short=\"random\"\n",

BIN
results/1_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
results/2_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
results/3_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
results/4_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
results/5_congruent_photons_grid.jld2 (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -19,7 +19,6 @@ export AbstractTask
export AbstractComputeTask
export AbstractDataTask
export DataTask
export FusedComputeTask
export PossibleOperations
export GraphProperties
@ -44,7 +43,6 @@ export is_valid, is_scheduled
# graph operation related
export Operation
export AppliedOperation
export NodeFusion
export NodeReduction
export NodeSplit
export push_operation!
@ -88,7 +86,7 @@ export GlobalMetricEstimator, CDCost
# optimization
export AbstractOptimizer, GreedyOptimizer, RandomWalkOptimizer
export ReductionOptimizer, SplitOptimizer, FusionOptimizer
export ReductionOptimizer, SplitOptimizer
export optimize_step!, optimize!
export fixpoint_reached, optimize_to_fixpoint!
@ -100,9 +98,6 @@ export ==, in, show, isempty, delete!, length
export bytes_to_human_readable
# TODO: this is probably not good
import QEDprocesses.compute
import Base.length
import Base.show
import Base.==
@ -115,8 +110,6 @@ import Base.delete!
import Base.insert!
import Base.collect
include("QEDprocesses_patch.jl")
include("devices/interface.jl")
include("task/type.jl")
include("node/type.jl")
@ -169,7 +162,6 @@ include("optimization/interface.jl")
include("optimization/greedy.jl")
include("optimization/random_walk.jl")
include("optimization/reduce.jl")
include("optimization/fuse.jl")
include("optimization/split.jl")
include("models/interface.jl")
@ -177,7 +169,6 @@ include("models/print.jl")
include("models/physics_models/interface.jl")
#=
include("models/physics_models/abc/types.jl")
include("models/physics_models/abc/particle.jl")
include("models/physics_models/abc/compute.jl")
@ -185,8 +176,8 @@ include("models/physics_models/abc/create.jl")
include("models/physics_models/abc/properties.jl")
include("models/physics_models/abc/parse.jl")
include("models/physics_models/abc/print.jl")
=#
include("models/physics_models/qed/utility.jl")
include("models/physics_models/qed/types.jl")
include("models/physics_models/qed/particle.jl")
include("models/physics_models/qed/diagrams.jl")

View File

@ -1,71 +0,0 @@
# patch QEDprocesses
# see issue https://github.com/QEDjl-project/QEDprocesses.jl/issues/77
@inline function QEDprocesses.number_particles(
proc_def::QEDbase.AbstractProcessDefinition,
dir::DIR,
::PT,
) where {DIR <: QEDbase.ParticleDirection, PT <: QEDbase.AbstractParticleType}
return count(x -> x isa PT, particles(proc_def, dir))
end
@inline function QEDprocesses.number_particles(
proc_def::QEDbase.AbstractProcessDefinition,
::PS,
) where {
DIR <: QEDbase.ParticleDirection,
PT <: QEDbase.AbstractParticleType,
EL <: AbstractFourMomentum,
PS <: ParticleStateful{DIR, PT, EL},
}
return QEDprocesses.number_particles(proc_def, DIR(), PT())
end
@inline function QEDprocesses.number_particles(
proc_def::QEDbase.AbstractProcessDefinition,
::Type{PS},
) where {
DIR <: QEDbase.ParticleDirection,
PT <: QEDbase.AbstractParticleType,
EL <: AbstractFourMomentum,
PS <: ParticleStateful{DIR, PT, EL},
}
return QEDprocesses.number_particles(proc_def, DIR(), PT())
end
@inline function QEDcore.ParticleStateful{DIR, SPECIES}(
mom::AbstractFourMomentum,
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType}
return ParticleStateful(DIR(), SPECIES(), mom)
end
@inline function QEDcore.ParticleStateful{DIR, SPECIES, EL}(
mom::EL,
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType, EL <: AbstractFourMomentum}
return ParticleStateful(DIR(), SPECIES(), mom)
end
@inline function QEDbase.momentum(
psp::AbstractPhaseSpacePoint{MODEL, PROC, PS_DEF, INT, OUTT},
dir::ParticleDirection,
species::AbstractParticleType,
n::Int,
) where {MODEL, PROC, PS_DEF, INT, OUTT}
# TODO: can be done through fancy template recursion too with 0 overhead
i = 0
c = n
for p in particles(psp, dir)
i += 1
if particle_species(p) isa typeof(species)
c -= 1
end
if c == 0
break
end
end
if c != 0 || n <= 0
throw(InvalidInputError("could not get $n-th momentum of $dir $species, does not exist"))
end
return momenta(psp, dir)[i]
end

View File

@ -7,7 +7,7 @@ function get_compute_function(graph::DAG, instance, machine::Machine)
tape = gen_tape(graph, instance, machine)
initCaches = Expr(:block, tape.initCachesCode...)
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
assignInputs = Expr(:block, tape.inputAssignCode...)
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
@ -16,9 +16,7 @@ function get_compute_function(graph::DAG, instance, machine::Machine)
"function compute_$(functionId)(data_input::$(input_type(instance))) $(initCaches); $(assignInputs); $code; return $resSym; end",
)
func = eval(expr)
return func
return expr
end
"""
@ -30,27 +28,27 @@ function get_cuda_kernel(graph::DAG, instance, machine::Machine)
tape = gen_tape(graph, instance, machine)
initCaches = Expr(:block, tape.initCachesCode...)
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
assignInputs = Expr(:block, tape.inputAssignCode...)
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
expr = Meta.parse("function compute_$(functionId)(input_vector, output_vector, n::Int64)
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
if (id > n)
return
end
@inline data_input = input_vector[id]
$(initCaches)
$(assignInputs)
$code
@inline output_vector[id] = $resSym
return nothing
end")
expr = Meta.parse(
"function compute_$(functionId)(input_vector, output_vector, n::Int64)
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
if (id > n)
return
end
@inline data_input = input_vector[id]
$(initCaches)
$(assignInputs)
$code
@inline output_vector[id] = $resSym
return nothing
end"
)
func = eval(expr)
return func
return expr
end
"""

View File

@ -75,7 +75,7 @@ end
inputSymbols::Dict{String, Vector{Symbol}},
instance::AbstractProblemInstance,
machine::Machine,
problemInputSymbol::Symbol = :input,
problemInputSymbol::Symbol = :data_input,
)
Return a `Vector{Expr}` doing the input assignments from the given `problemInputSymbol` onto the `inputSymbols`.
@ -84,9 +84,9 @@ function gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
instance,
machine::Machine,
problemInputSymbol::Symbol = :input,
problemInputSymbol::Symbol = :data_input,
)
assignInputs = Vector{FunctionCall}()
assignInputs = Vector{Expr}()
for (name, symbols) in inputSymbols
# make a function for this, since we can't use anonymous functions in the FunctionCall
@ -94,7 +94,9 @@ function gen_input_assignment_code(
device = entry_device(machine)
push!(
assignInputs,
FunctionCall(input_expr, SVector{1, Any}(name), SVector{1, Symbol}(problemInputSymbol), symbol, device),
Meta.parse(
"$(eval(gen_access_expr(device, symbol))) = $(input_expr(instance, name, problemInputSymbol))",
),
)
end
end
@ -126,7 +128,7 @@ function gen_tape(graph::DAG, instance, machine::Machine, scheduler::AbstractSch
outSym = Symbol(to_var_name(get_exit_node(graph).id))
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSyms, instance, machine, :input)
assignInputs = gen_input_assignment_code(inputSyms, instance, machine, :data_input)
return Tape{input_type(instance)}(initCaches, assignInputs, schedule, inputSyms, outSym, Dict(), instance, machine)
end
@ -140,7 +142,7 @@ For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of t
"""
function execute_tape(tape::Tape, input)
cache = Dict{Symbol, Any}()
cache[:input] = input
cache[:data_input] = input
# simply execute all the code snippets here
@assert typeof(input) == input_type(tape.instance)
# TODO: `@assert` that input fits the tape.instance

View File

@ -11,7 +11,7 @@ TODO: update docs
"""
struct Tape{INPUT}
initCachesCode::Vector{Expr}
inputAssignCode::Vector{FunctionCall}
inputAssignCode::Vector{Expr}
computeCode::Vector{FunctionCall}
inputSymbols::Dict{String, Vector{Symbol}}
outputSymbol::Symbol

View File

@ -10,8 +10,7 @@ Representation of a [`DAG`](@ref)'s cost as estimated by the [`GlobalMetricEstim
!!! note
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
For example, for node fusions this will always be 0, since the computeEffort is zero.
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
It will still work as intended when adding/subtracting to/from a `graph_cost` estimate.
"""
const CDCost = NamedTuple{(:data, :computeEffort, :computeIntensity), Tuple{Float64, Float64, Float64}}
@ -55,10 +54,6 @@ function graph_cost(estimator::GlobalMetricEstimator, graph::DAG)
)::CDCost
end
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeFusion)
return (data = -data(operation.input[2].task), computeEffort = 0.0, computeIntensity = 0.0)::CDCost
end
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeReduction)
s = length(operation.input) - 1
return (

View File

@ -169,66 +169,6 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
return nothing
end
function replace_children!(task::FusedComputeTask, before, after)
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
replace!(task.t1_inputs, before => after)
replace!(task.t2_inputs, before => after)
# recursively descend down the tree, but only in the tasks where we're replacing things
if replacedIn1 > 0
replace_children!(task.first_task, before, after)
end
if replacedIn2 > 0
replace_children!(task.second_task, before, after)
end
return nothing
end
function replace_children!(task::AbstractTask, before, after)
return nothing
end
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
# only need to update fused compute tasks
if !(typeof(task(n)) <: FusedComputeTask)
return nothing
end
taskBefore = copy(task(n))
#=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
println("------------------ Nothing to replace!! ------------------")
child_ids = Vector{String}()
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end=#
replace_children!(task(n), child_before, child_after)
#=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
println("------------------ Did not replace anything!! ------------------")
child_ids = Vector{String}()
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end=#
# keep track
if (track)
push!(graph.diff.updatedChildren, (n, taskBefore))
end
end
"""
get_snapshot_diff(graph::DAG)
@ -240,31 +180,6 @@ function get_snapshot_diff(graph::DAG)
return swapfield!(graph, :diff, Diff())
end
"""
invalidate_caches!(graph::DAG, operation::NodeFusion)
Invalidate the operation caches for a given [`NodeFusion`](@ref).
This deletes the operation from the graph's possible operations and from the involved nodes' own operation caches.
"""
function invalidate_caches!(graph::DAG, operation::NodeFusion)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
for n in [1, 3]
for i in eachindex(operation.input[n].nodeFusions)
if operation == operation.input[n].nodeFusions[i]
splice!(operation.input[n].nodeFusions, i)
break
end
end
end
operation.input[2].nodeFusion = missing
return nothing
end
"""
invalidate_caches!(graph::DAG, operation::NodeReduction)
@ -311,9 +226,6 @@ function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
if !ismissing(node.nodeSplit)
invalidate_caches!(graph, node.nodeSplit)
end
while !isempty(node.nodeFusions)
invalidate_caches!(graph, pop!(node.nodeFusions))
end
return nothing
end
@ -329,8 +241,5 @@ function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
if !ismissing(node.nodeSplit)
invalidate_caches!(graph, node.nodeSplit)
end
if !ismissing(node.nodeFusion)
invalidate_caches!(graph, node.nodeFusion)
end
return nothing
end

View File

@ -7,7 +7,6 @@ A struct storing all possible operations on a [`DAG`](@ref).
To get the [`PossibleOperations`](@ref) on a [`DAG`](@ref), use [`get_operations`](@ref).
"""
mutable struct PossibleOperations
nodeFusions::Set{NodeFusion}
nodeReductions::Set{NodeReduction}
nodeSplits::Set{NodeSplit}
end
@ -52,7 +51,7 @@ end
Construct and return an empty [`PossibleOperations`](@ref) object.
"""
function PossibleOperations()
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
return PossibleOperations(Set{NodeReduction}(), Set{NodeSplit}())
end
"""

View File

@ -40,19 +40,11 @@ function is_valid(graph::DAG)
for ns in graph.possibleOperations.nodeSplits
@assert is_valid(graph, ns)
end
for nf in graph.possibleOperations.nodeFusions
@assert is_valid(graph, nf)
end
for node in graph.dirtyNodes
@assert node in graph "Dirty Node is not part of the graph!"
@assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!"
@assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!"
if (typeof(node) <: DataTaskNode)
@assert ismissing(node.nodeFusion) "Dirty DataTaskNode has a Node Fusion!"
elseif (typeof(node) <: ComputeTaskNode)
@assert isempty(node.nodeFusions) "Dirty ComputeTaskNode has Node Fusions!"
end
end
@assert is_connected(graph) "Graph is not connected!"

View File

@ -39,8 +39,8 @@ Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Ever
function graph end
"""
input_expr(::AbstractProblemInstance, name::String, input)
input_expr(instance::AbstractProblemInstance, name::String, input_symbol::Symbol)
For a given [`AbstractProblemInstance`](@ref), the problem input (of type `input_type(...)`) and an entry node name, return a that specific input value from the input symbol.
For the given [`AbstractProblemInstance`](@ref), the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol.
"""
function input_expr end

View File

@ -0,0 +1,3 @@
## Deprecation Warning
These models are deprecated and should not be used anymore. They will be dropped entirely soon.

View File

@ -1,6 +1,17 @@
using AccurateArithmetic
using StaticArrays
function input_expr(instance::ABCProcessDescription, name::String, psp_symbol::Symbol)
(type, index) = type_index_from_name(ABCModel(), name)
return Meta.parse(
"ABCParticleValue(
$type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))),
0.0im,
)",
)
end
"""
compute(::ComputeTaskABC_P, data::ABCParticleValue)
@ -8,7 +19,7 @@ Return the particle and value as is.
0 FLOP.
"""
function compute(::ComputeTaskABC_P, data::ABCParticleValue{P})::ABCParticleValue{P} where {P <: ABCParticle}
function compute(::ComputeTaskABC_P, data::ABCParticleValue{P})::ABCParticleValue{P} where {P}
return data
end
@ -19,7 +30,7 @@ Compute an outer edge. Return the particle value with the same particle and the
1 FLOP.
"""
function compute(::ComputeTaskABC_U, data::ABCParticleValue{P})::ABCParticleValue{P} where {P <: ABCParticle}
function compute(::ComputeTaskABC_U, data::ABCParticleValue{P})::ABCParticleValue{P} where {P}
return ABCParticleValue{P}(data.p, data.v * ABC_outer_edge(data.p))
end
@ -34,7 +45,7 @@ function compute(
::ComputeTaskABC_V,
data1::ABCParticleValue{P1},
data2::ABCParticleValue{P2},
)::ABCParticleValue where {P1 <: ABCParticle, P2 <: ABCParticle}
)::ABCParticleValue where {P1, P2}
p3 = ABC_conserve_momentum(data1.p, data2.p)
dataOut = ABCParticleValue{typeof(p3)}(p3, data1.v * ABC_vertex() * data2.v)
return dataOut
@ -49,11 +60,7 @@ For valid inputs, both input particles should have the same momenta at this poin
12 FLOP.
"""
function compute(
::ComputeTaskABC_S2,
data1::ParticleValue{P},
data2::ParticleValue{P},
)::Float64 where {P <: ABCParticle}
function compute(::ComputeTaskABC_S2, data1::ParticleValue{P}, data2::ParticleValue{P})::Float64 where {P}
#=
@assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
@ -85,12 +92,6 @@ Linearly many FLOP with growing data.
"""
function compute(::ComputeTaskABC_Sum, data...)::Float64
return sum_kbn([data...])
#=s = 0.0im
for d in data
s += d
end
return s=#
end
function compute(::ComputeTaskABC_Sum, data::AbstractArray)::Float64

View File

@ -14,34 +14,28 @@ struct ABCModel <: AbstractPhysicsModel end
Base type for all particles in the [`ABCModel`](@ref).
"""
abstract type ABCParticle <: AbstractParticle end
abstract type ABCParticle <: AbstractParticleType end
"""
ParticleA <: ABCParticle
An 'A' particle in the ABC Model.
"""
struct ParticleA <: ABCParticle
momentum::SFourMomentum
end
struct ParticleA <: ABCParticle end
"""
ParticleB <: ABCParticle
A 'B' particle in the ABC Model.
"""
struct ParticleB <: ABCParticle
momentum::SFourMomentum
end
struct ParticleB <: ABCParticle end
"""
ParticleC <: ABCParticle
A 'C' particle in the ABC Model.
"""
struct ParticleC <: ABCParticle
momentum::SFourMomentum
end
struct ParticleC <: ABCParticle end
"""
ABCProcessDescription <: AbstractProcessDescription
@ -72,7 +66,7 @@ struct ABCProcessInput{N1, N2, N3, N4, N5, N6} <: AbstractProcessInput
outC::SVector{N6, ParticleC}
end
ABCParticleValue{ParticleType <: ABCParticle} = ParticleValue{ParticleType, ComplexF64}
ABCParticleValue{ParticleType} = ParticleValue{ParticleType, ComplexF64}
"""
mass(t::Type{T}) where {T <: ABCParticle}
@ -109,39 +103,46 @@ end
Return a Vector of the possible types of particle in the [`ABCModel`](@ref).
"""
function types(::ABCModel)
return [ParticleA, ParticleB, ParticleC]
return [
ParticleStateful{Incoming, ParticleA, SFourMomentum},
ParticleStateful{Incoming, ParticleB, SFourMomentum},
ParticleStateful{Incoming, ParticleC, SFourMomentum},
ParticleStateful{Outgoing, ParticleA, SFourMomentum},
ParticleStateful{Outgoing, ParticleB, SFourMomentum},
ParticleStateful{Outgoing, ParticleC, SFourMomentum},
]
end
"""
square(p::ABCParticle)
square(p::AbstractParticleStateful{Dir, ABCParticle})
Return the square of the particle's momentum as a `Float` value.
Takes 7 effective FLOP.
"""
function square(p::ABCParticle)
return getMass2(p.momentum)
function square(p::AbstractParticleStateful{D, ABCParticle}) where {D}
return getMass2(momentum(p))
end
"""
ABC_inner_edge(p::ABCParticle)
ABC_inner_edge(p::AbstractParticleStateful{Dir, ABCParticle})
Return the factor of the inner edge with the given (virtual) particle.
Takes 10 effective FLOP. (3 here + 7 in square(p))
"""
function ABC_inner_edge(p::ABCParticle)
return 1.0 / (square(p) - mass(p)^2)
function ABC_inner_edge(p::AbstractParticleStateful{D, ABCParticle}) where {D}
return 1.0 / (square(p) - mass(particle(p))^2)
end
"""
ABC_outer_edge(p::ABCParticle)
ABC_outer_edge(p::AbstractParticleStateful{Dir, ABCParticle})
Return the factor of the outer edge with the given (real) particle.
Takes 0 effective FLOP.
"""
function ABC_outer_edge(p::ABCParticle)
function ABC_outer_edge(::AbstractParticleStateful{D, ABCParticle}) where {D}
return 1.0
end
@ -179,17 +180,26 @@ model(::ABCProcessDescription) = ABCModel()
model(::ABCProcessInput) = ABCModel()
function type_index_from_name(::ABCModel, name::String)
if startswith(name, "A")
return (ParticleA, parse(Int, name[2:end]))
elseif startswith(name, "B")
return (ParticleB, parse(Int, name[2:end]))
elseif startswith(name, "C")
return (ParticleC, parse(Int, name[2:end]))
if startswith(name, "Ai")
return (ParticleStateful{Incoming, ParticleA, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Ao")
return (ParticleStateful{Outgoing, ParticleA, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Bi")
return (ParticleStateful{Incoming, ParticleB, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Bo")
return (ParticleStateful{Outgoing, ParticleB, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Ci")
return (ParticleStateful{Incoming, ParticleC, SFourMomentum}, parse(Int, name[3:end]))
elseif startswith(name, "Co")
return (ParticleStateful{Outgoing, ParticleC, SFourMomentum}, parse(Int, name[3:end]))
else
throw("Invalid name for a particle in the ABC model")
end
end
function String(::Type{PS}) where {DIR, P <: ABCParticle, PS <: AbstractParticleStateful{DIR, P}}
return String(P)
end
function String(::Type{ParticleA})
return "A"
end

View File

@ -1,12 +1,26 @@
using StaticArrays
function input_expr(name::String, psp::PhaseSpacePoint)
construction_string(::Incoming) = "Incoming()"
construction_string(::Outgoing) = "Outgoing()"
construction_string(::Electron) = "Electron()"
construction_string(::Positron) = "Positron()"
construction_string(::Photon) = "Photon()"
construction_string(::PolX) = "PolX()"
construction_string(::PolY) = "PolY()"
construction_string(::SpinUp) = "SpinUp()"
construction_string(::SpinDown) = "SpinDown()"
function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbol)
(type, index) = type_index_from_name(QEDModel(), name)
return ParticleValueSP(
type(momentum(psp, particle_direction(type), particle_species(type), index)),
0.0im,
spin_or_pol(process(psp), type, index),
return Meta.parse(
"ParticleValueSP(
$type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))),
0.0im,
$(construction_string(spin_or_pol(instance, type, index))),
)",
)
end

View File

@ -8,7 +8,6 @@ function _svector_from_type(processDescription::GenericQEDProcess, type, particl
if haskey(out_particles(processDescription), type)
return SVector{out_particles(processDescription)[type], type}(filter(x -> typeof(x) <: type, particles))
end
return SVector{0, type}()
end
"""
@ -64,9 +63,9 @@ function gen_graph(process_description::GenericQEDProcess)
# TODO: Not all diagram outputs should always be summed at the end, if they differ by fermion exchange they need to be diffed
# Should not matter for n-Photon Compton processes though
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(0)), track = false, invalidate_cache = false)
global_data_out = insert_node!(graph, make_node(DataTask(COMPLEX_SIZE)), track = false, invalidate_cache = false)
insert_edge!(graph, sum_node, global_data_out, track = false, invalidate_cache = false)
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(0)); track = false, invalidate_cache = false)
global_data_out = insert_node!(graph, make_node(DataTask(COMPLEX_SIZE)); track = false, invalidate_cache = false)
insert_edge!(graph, sum_node, global_data_out; track = false, invalidate_cache = false)
# remember the data out nodes for connection
dataOutNodes = Dict()
@ -75,16 +74,16 @@ function gen_graph(process_description::GenericQEDProcess)
# generate data in and U tasks
data_in = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE), String(particle)),
make_node(DataTask(PARTICLE_VALUE_SIZE), String(particle));
track = false,
invalidate_cache = false,
) # read particle data node
compute_u = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false, invalidate_cache = false) # compute U node
compute_u = insert_node!(graph, make_node(ComputeTaskQED_U()); track = false, invalidate_cache = false) # compute U node
data_out =
insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false) # transfer data out from u (one ABCParticleValue object)
insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)); track = false, invalidate_cache = false) # transfer data out from u (one ABCParticleValue object)
insert_edge!(graph, data_in, compute_u, track = false, invalidate_cache = false)
insert_edge!(graph, compute_u, data_out, track = false, invalidate_cache = false)
insert_edge!(graph, data_in, compute_u; track = false, invalidate_cache = false)
insert_edge!(graph, compute_u, data_out; track = false, invalidate_cache = false)
# remember the data_out node for future edges
dataOutNodes[String(particle)] = data_out
@ -100,19 +99,19 @@ function gen_graph(process_description::GenericQEDProcess)
data_in1 = dataOutNodes[String(vertex.in1)]
data_in2 = dataOutNodes[String(vertex.in2)]
compute_V = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false, invalidate_cache = false) # compute vertex
compute_V = insert_node!(graph, make_node(ComputeTaskQED_V()); track = false, invalidate_cache = false) # compute vertex
insert_edge!(graph, data_in1, compute_V, track = false, invalidate_cache = false)
insert_edge!(graph, data_in2, compute_V, track = false, invalidate_cache = false)
insert_edge!(graph, data_in1, compute_V; track = false, invalidate_cache = false)
insert_edge!(graph, data_in2, compute_V; track = false, invalidate_cache = false)
data_V_out = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
make_node(DataTask(PARTICLE_VALUE_SIZE));
track = false,
invalidate_cache = false,
)
insert_edge!(graph, compute_V, data_V_out, track = false, invalidate_cache = false)
insert_edge!(graph, compute_V, data_V_out; track = false, invalidate_cache = false)
if (vertex.out == tie.in1 || vertex.out == tie.in2)
# out particle is part of the tie -> there will be an S2 task with it later, don't make S1 task
@ -122,18 +121,18 @@ function gen_graph(process_description::GenericQEDProcess)
# otherwise, add S1 task
compute_S1 =
insert_node!(graph, make_node(ComputeTaskQED_S1()), track = false, invalidate_cache = false) # compute propagator
insert_node!(graph, make_node(ComputeTaskQED_S1()); track = false, invalidate_cache = false) # compute propagator
insert_edge!(graph, data_V_out, compute_S1, track = false, invalidate_cache = false)
insert_edge!(graph, data_V_out, compute_S1; track = false, invalidate_cache = false)
data_S1_out = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
make_node(DataTask(PARTICLE_VALUE_SIZE));
track = false,
invalidate_cache = false,
)
insert_edge!(graph, compute_S1, data_S1_out, track = false, invalidate_cache = false)
insert_edge!(graph, compute_S1, data_S1_out; track = false, invalidate_cache = false)
# overrides potentially different nodes from previous diagrams, which is intentional
dataOutNodes[String(vertex.out)] = data_S1_out
@ -144,16 +143,16 @@ function gen_graph(process_description::GenericQEDProcess)
data_in1 = dataOutNodes[String(tie.in1)]
data_in2 = dataOutNodes[String(tie.in2)]
compute_S2 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false, invalidate_cache = false)
compute_S2 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track = false, invalidate_cache = false)
data_S2 = insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false)
data_S2 = insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)); track = false, invalidate_cache = false)
insert_edge!(graph, data_in1, compute_S2, track = false, invalidate_cache = false)
insert_edge!(graph, data_in2, compute_S2, track = false, invalidate_cache = false)
insert_edge!(graph, data_in1, compute_S2; track = false, invalidate_cache = false)
insert_edge!(graph, data_in2, compute_S2; track = false, invalidate_cache = false)
insert_edge!(graph, compute_S2, data_S2, track = false, invalidate_cache = false)
insert_edge!(graph, compute_S2, data_S2; track = false, invalidate_cache = false)
insert_edge!(graph, data_S2, sum_node, track = false, invalidate_cache = false)
insert_edge!(graph, data_S2, sum_node; track = false, invalidate_cache = false)
add_child!(task(sum_node))
end

View File

@ -152,8 +152,9 @@ function ==(d1::FeynmanDiagram, d2::FeynmanDiagram)
)=#
end
copy(fd::FeynmanDiagram) =
FeynmanDiagram(deepcopy(fd.vertices), copy(fd.tie[]), deepcopy(fd.particles), copy(fd.type_ids))
function copy(fd::FeynmanDiagram)
return FeynmanDiagram(deepcopy(fd.vertices), copy(fd.tie[]), deepcopy(fd.particles), copy(fd.type_ids))
end
"""
id_for_type(d::FeynmanDiagram, t::Type{<:ParticleStateful})
@ -503,7 +504,6 @@ function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::
return new_diagram
end
#=
"""
gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
@ -511,8 +511,8 @@ Helper function for [`gen_compton_diagrams`](@Ref). Generates a single diagram f
"""
function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
photons = vcat(
[FeynmanParticle(ParticleStateful{Incoming, Photon}, i) for i in 1:n],
[FeynmanParticle(ParticleStateful{Outgoing, Photon}, i) for i in 1:m],
[FeynmanParticle(ParticleStateful{Incoming, Photon, SFourMomentum}, i) for i in 1:n],
[FeynmanParticle(ParticleStateful{Outgoing, Photon, SFourMomentum}, i) for i in 1:m],
)
new_diagram = FeynmanDiagram(
@ -520,10 +520,10 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
missing,
[inFerm, outFerm, photons...],
Dict{Type, Int64}(
ParticleStateful{Incoming, Electron} => 1,
ParticleStateful{Outgoing, Electron} => 1,
ParticleStateful{Incoming, Photon} => n,
ParticleStateful{Outgoing, Photon} => m,
ParticleStateful{Incoming, Electron, SFourMomentum} => 1,
ParticleStateful{Outgoing, Electron, SFourMomentum} => 1,
ParticleStateful{Incoming, Photon, SFourMomentum} => n,
ParticleStateful{Outgoing, Photon, SFourMomentum} => m,
),
)
@ -535,9 +535,9 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
while left_index <= right_index
# left side
v_left = FeynmanVertex(
FeynmanParticle(ParticleStateful{Incoming, Electron}, iterations),
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations),
photons[order[left_index]],
FeynmanParticle(ParticleStateful{Incoming, Electron}, iterations + 1),
FeynmanParticle(ParticleStateful{Incoming, Electron, SFourMomentum}, iterations + 1),
)
left_index += 1
add_vertex!(new_diagram, v_left)
@ -550,9 +550,9 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
if (iterations == 1)
# right side
v_right = FeynmanVertex(
FeynmanParticle(ParticleStateful{Outgoing, Electron}, iterations),
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations),
photons[order[right_index]],
FeynmanParticle(ParticleStateful{Outgoing, Electron}, iterations + 1),
FeynmanParticle(ParticleStateful{Outgoing, Electron, SFourMomentum}, iterations + 1),
)
right_index -= 1
add_vertex!(new_diagram, v_right)
@ -566,7 +566,6 @@ function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, out
add_tie!(new_diagram, FeynmanTie(ps[1], ps[2]))
return new_diagram
end
=#
"""
gen_compton_diagrams(n::Int, m::Int)
@ -587,7 +586,6 @@ function gen_compton_diagrams(n::Int, m::Int)
return vcat(diagrams...)
end
"""
gen_compton_diagrams_one_side(n::Int, m::Int)

View File

@ -1,5 +1,3 @@
# TODO generalize this somehow, probably using AllSpin/AllPol as defaults
"""
parse_process(string::AbstractString, model::QEDModel)
@ -8,62 +6,40 @@ Parse a string representation of a process, such as "ke->ke" into the correspond
function parse_process(
str::AbstractString,
model::QEDModel,
inphpol::AbstractDefinitePolarization,
inelspin::AbstractDefiniteSpin,
outphpol::AbstractDefinitePolarization,
outelspin::AbstractDefiniteSpin,
inphpol::AbstractDefinitePolarization = PolX(),
inelspin::AbstractDefiniteSpin = SpinUp(),
outphpol::AbstractDefinitePolarization = PolX(),
outelspin::AbstractDefiniteSpin = SpinUp(),
)
inParticles = Dict(
ParticleStateful{Incoming, Photon, SFourMomentum} => 0,
ParticleStateful{Incoming, Electron, SFourMomentum} => 0,
ParticleStateful{Incoming, Positron, SFourMomentum} => 0,
)
outParticles = Dict(
ParticleStateful{Outgoing, Photon, SFourMomentum} => 0,
ParticleStateful{Outgoing, Electron, SFourMomentum} => 0,
ParticleStateful{Outgoing, Positron, SFourMomentum} => 0,
)
if !(contains(str, "->"))
throw("Did not find -> while parsing process \"$str\"")
end
(inStr, outStr) = split(str, "->")
(in_str, out_str) = split(str, "->")
if (isempty(inStr) || isempty(outStr))
if (isempty(in_str) || isempty(out_str))
throw("Process (\"$str\") input or output part is empty!")
end
for t in types(model)
if (is_incoming(t))
inCount = count(x -> x == String(t)[1], inStr)
in_particles = Vector{AbstractParticleType}()
out_particles = Vector{AbstractParticleType}()
if inCount != 0
inParticles[t] = inCount
end
end
if (is_outgoing(t))
outCount = count(x -> x == String(t)[1], outStr)
if outCount != 0
outParticles[t] = outCount
for (particle_vector, s) in ((in_particles, in_str), (out_particles, out_str))
for c in s
if c == 'e'
push!(particle_vector, Electron())
elseif c == 'p'
push!(particle_vector, Positron())
elseif c == 'k'
push!(particle_vector, Photon())
else
throw("Encountered unknown characters in the process \"$str\"")
end
end
end
if length(inStr) != sum(values(inParticles))
throw("Encountered unknown characters in the input part of process \"$str\"")
elseif length(outStr) != sum(values(outParticles))
throw("Encountered unknown characters in the output part of process \"$str\"")
end
in_spin_pols = tuple([is_boson(in_particles[i]) ? inphpol : inelspin for i in eachindex(in_particles)]...)
out_spin_pols = tuple([is_boson(out_particles[i]) ? outphpol : outelspin for i in eachindex(out_particles)]...)
in_ph = inParticles[ParticleStateful{Incoming, Photon, SFourMomentum}]
in_el = inParticles[ParticleStateful{Incoming, Electron, SFourMomentum}]
in_pos = inParticles[ParticleStateful{Incoming, Positron, SFourMomentum}]
out_ph = outParticles[ParticleStateful{Outgoing, Photon, SFourMomentum}]
out_el = outParticles[ParticleStateful{Outgoing, Electron, SFourMomentum}]
out_pos = outParticles[ParticleStateful{Outgoing, Positron, SFourMomentum}]
in_spin_pols = tuple([i <= in_ph ? inphpol : inelspin for i in 1:(in_ph + in_el)]...)
out_spin_pols = tuple([i <= out_ph ? outphpol : outelspin for i in 1:(out_ph + out_el)]...)
return GenericQEDProcess(in_ph, out_ph, in_el, out_el, in_pos, out_pos, in_spin_pols, out_spin_pols)
return GenericQEDProcess(tuple(in_particles...), tuple(out_particles...), in_spin_pols, out_spin_pols)
end

View File

@ -3,13 +3,6 @@ using StaticArrays
const e = sqrt(4π / 137)
"""
QEDModel <: AbstractPhysicsModel
Singleton definition for identification of the QED-Model.
"""
struct QEDModel <: AbstractPhysicsModel end
QEDbase.is_incoming(::Type{<:ParticleStateful{Incoming}}) = true
QEDbase.is_outgoing(::Type{<:ParticleStateful{Outgoing}}) = true
QEDbase.is_incoming(::Type{<:ParticleStateful{Outgoing}}) = false
@ -20,61 +13,6 @@ QEDbase.particle_species(
::Type{<:ParticleStateful{DIR, SPECIES}},
) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType} = SPECIES()
_assert_particle_type_tuple(::Tuple{}) = nothing
_assert_particle_type_tuple(t::Tuple{AbstractParticleType, Vararg}) = _assert_particle_type_tuple(t[2:end])
_assert_particle_type_tuple(t::Any) =
throw(InvalidInputError("invalid input, provide a tuple of AbstractParticleTypes to construct a GenericQEDProcess"))
struct GenericQEDProcess{INT, OUTT, INSP, OUTSP} <:
AbstractProcessDefinition where {INT <: Tuple, OUTT <: Tuple, INSP <: Tuple, OUTSP <: Tuple}
incoming_particles::INT
outgoing_particles::OUTT
incoming_spins_pols::INSP
outgoing_spins_pols::OUTSP
function GenericQEDProcess(
in_particles::INT,
out_particles::OUTT,
in_sp::INSP,
out_sp::OUTSP,
) where {INT <: Tuple, OUTT <: Tuple, INSP <: Tuple, OUTSP <: Tuple}
_assert_particle_type_tuple(in_particles)
_assert_particle_type_tuple(out_particles)
# TODO: check in_sp/out_sp fits to particles (spin->fermion, pol->photon)
return new{INT, OUTT, INSP, OUTSP}(in_particles, out_particles, in_sp, out_sp)
end
"""
GenericQEDProcess{INSP, OUTSP}(in_ph::Int, out_ph::Int, in_el::Int, out_el::Int, in_po::Int, out_po::Int, in_sp::INSP, out_sp::OUTSP)
Convenience constructor from numbers of input/output photons, electrons and positrons.
`in_sp` and `out_sp` are tuples of the spin and polarization configurations of the given particles. Their order is Photons, then Electrons, then Positrons for both incoming and outgoing particles.
# TODO: make default for in_sp and out_sp available for convenience
"""
function GenericQEDProcess(
in_ph::Int,
out_ph::Int,
in_el::Int,
out_el::Int,
in_po::Int,
out_po::Int,
in_sp::INSP,
out_sp::OUTSP,
) where {INSP <: Tuple, OUTSP <: Tuple}
in_p = ntuple(i -> i <= in_ph ? Photon() : i <= in_ph + in_el ? Electron() : Positron(), in_ph + in_el + in_po)
out_p = ntuple(
i -> i <= out_ph ? Photon() : i <= out_ph + out_el ? Electron() : Positron(),
out_ph + out_el + out_po,
)
return GenericQEDProcess(in_p, out_p, in_sp, out_sp)
end
end
function spin_or_pol(
process::GenericQEDProcess,
type::Type{ParticleStateful{DIR, SPECIES, EL}},
@ -105,16 +43,6 @@ function spin_or_pol(
end
end
QEDprocesses.incoming_particles(proc::GenericQEDProcess) = proc.incoming_particles
QEDprocesses.outgoing_particles(proc::GenericQEDProcess) = proc.outgoing_particles
function isphysical(proc::GenericQEDProcess)
return (
number_particles(proc, Incoming(), Electron()) + number_particles(proc, Outgoing(), Positron()) ==
number_particles(proc, Incoming(), Positron()) + number_particles(proc, Outgoing(), Electron())
) && number_particles(proc, Incoming()) + number_particles(proc, Outgoing()) >= 2
end
function input_type(p::GenericQEDProcess)
in_t = QEDcore._assemble_tuple_type(incoming_particles(p), Incoming(), SFourMomentum)
out_t = QEDcore._assemble_tuple_type(outgoing_particles(p), Outgoing(), SFourMomentum)

View File

@ -1,3 +1,10 @@
"""
QEDModel <: AbstractPhysicsModel
Singleton definition for identification of the QED-Model.
"""
struct QEDModel <: AbstractPhysicsModel end
"""
ComputeTaskQED_S1 <: AbstractComputeTask

View File

@ -0,0 +1,14 @@
using QEDprocesses
# add type overload for number_particles function
@inline function QEDprocesses.number_particles(
proc_def::QEDbase.AbstractProcessDefinition,
::Type{PS},
) where {
DIR <: QEDbase.ParticleDirection,
PT <: QEDbase.AbstractParticleType,
EL <: AbstractFourMomentum,
PS <: ParticleStateful{DIR, PT, EL},
}
return QEDprocesses.number_particles(proc_def, DIR(), PT())
end

View File

@ -1,6 +1,6 @@
DataTaskNode(t::AbstractDataTask, name = "") =
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, name)
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
t, # task
Vector{Node}(), # parents
@ -8,7 +8,6 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
UUIDs.uuid1(rng[threadid()]), # id
missing, # node reduction
missing, # node split
Vector{NodeFusion}(), # node fusions
missing, # device
)

View File

@ -30,7 +30,6 @@ Any node that transfers data and does no computation.
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
"""
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
@ -51,9 +50,6 @@ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
# the NodeSplit involving this node, if it exists
nodeSplit::Union{Operation, Missing}
# the node fusion involving this node, if it exists
nodeFusion::Union{Operation, Missing}
# for input nodes we need a name for the node to distinguish between them
name::String
end
@ -70,7 +66,6 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
"""
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
@ -82,9 +77,6 @@ mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
nodeReduction::Union{Operation, Missing}
nodeSplit::Union{Operation, Missing}
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
nodeFusions::Vector{<:Operation}
# the device this node is assigned to execute on
device::Union{AbstractDevice, Missing}
end

View File

@ -29,17 +29,6 @@ function is_valid_node(graph::DAG, node::Node)
@assert is_valid(graph, node.nodeSplit)
end=#
if !(typeof(task(node)) <: FusedComputeTask)
# the remaining checks are only necessary for fused compute tasks
return true
end
# every child must be in some input of the task
for child in node.children
str = Symbol(to_var_name(child.id))
@assert (str in task(node).t1_inputs) || (str in task(node).t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(task(node).t1_inputs)\nt2_inputs: $(task(node).t2_inputs)"
end
return true
end
@ -53,9 +42,6 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::ComputeTaskNode)
@assert is_valid_node(graph, node)
#=for nf in node.nodeFusions
@assert is_valid(graph, nf)
end=#
return true
end
@ -69,8 +55,5 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::DataTaskNode)
@assert is_valid_node(graph, node)
#=if !ismissing(node.nodeFusion)
@assert is_valid(graph, node.nodeFusion)
end=#
return true
end

View File

@ -26,21 +26,6 @@ function apply_operation!(graph::DAG, operation::Operation)
return error("Unknown operation type!")
end
"""
apply_operation!(graph::DAG, operation::NodeFusion)
Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node_fusion!`](@ref).
Return an [`AppliedNodeFusion`](@ref) object generated from the graph's [`Diff`](@ref).
"""
function apply_operation!(graph::DAG, operation::NodeFusion)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
graph.properties += GraphProperties(diff)
return AppliedNodeFusion(operation, diff)
end
"""
apply_operation!(graph::DAG, operation::NodeReduction)
@ -80,20 +65,10 @@ function revert_operation!(graph::DAG, operation::AppliedOperation)
return error("Unknown operation type!")
end
"""
revert_operation!(graph::DAG, operation::AppliedNodeFusion)
Revert the applied node fusion on the graph. Return the original [`NodeFusion`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
revert_diff!(graph, operation.diff)
return operation.operation
end
"""
revert_operation!(graph::DAG, operation::AppliedNodeReduction)
Revert the applied node fusion on the graph. Return the original [`NodeReduction`](@ref) operation.
Revert the applied node reduction on the graph. Return the original [`NodeReduction`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
revert_diff!(graph, operation.diff)
@ -103,7 +78,7 @@ end
"""
revert_operation!(graph::DAG, operation::AppliedNodeSplit)
Revert the applied node fusion on the graph. Return the original [`NodeSplit`](@ref) operation.
Revert the applied node split on the graph. Return the original [`NodeSplit`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
revert_diff!(graph, operation.diff)
@ -132,88 +107,11 @@ function revert_diff!(graph::DAG, diff::Diff)
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for (node, t) in diff.updatedChildren
# node must be fused compute task at this point
@assert typeof(task(node)) <: FusedComputeTask
node.task = t
end
graph.properties -= GraphProperties(diff)
return nothing
end
"""
node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph.
For details see [`NodeFusion`](@ref).
"""
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
@assert is_valid_node_fusion_input(graph, n1, n2, n3)
# clear snapshot
get_snapshot_diff(graph)
# save children and parents
n1Children = copy(children(n1))
n3Parents = copy(parents(n3))
n1Task = copy(task(n1))
n3Task = copy(task(n3))
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
n1Inputs = Vector{Symbol}()
for child in n1Children
push!(n1Inputs, Symbol(to_var_name(child.id)))
end
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, n1, n2)
remove_edge!(graph, n2, n3)
remove_node!(graph, n1)
remove_node!(graph, n2)
# get n3's children now so it automatically excludes n2
n3Children = copy(children(n3))
n3Inputs = Vector{Symbol}()
for child in n3Children
push!(n3Inputs, Symbol(to_var_name(child.id)))
end
remove_node!(graph, n3)
# create new node with the fused compute task
newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
insert_node!(graph, newNode)
for child in n1Children
remove_edge!(graph, child, n1)
insert_edge!(graph, child, newNode)
end
for child in n3Children
remove_edge!(graph, child, n3)
if !(child in n1Children)
insert_edge!(graph, child, newNode)
end
end
for parent in n3Parents
remove_edge!(graph, n3, parent)
insert_edge!(graph, newNode, parent)
# important! update the parent node's child names in case they are fused compute tasks
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(newNode.id)))
end
return get_snapshot_diff(graph)
end
"""
node_reduction!(graph::DAG, nodes::Vector{Node})
@ -265,7 +163,6 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
# this has to be done for all parents, even the ones of n1 because they can be duplicate
prevChild = newParentsChildNames[parent]
update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
end
return get_snapshot_diff(graph)
@ -307,8 +204,6 @@ function node_split!(
for child in n1Children
insert_edge!(graph, child, nCopy)
end
update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(nCopy.id)))
end
return get_snapshot_diff(graph)

View File

@ -1,60 +1,5 @@
# These are functions for "cleaning" nodes, i.e. regenerating the possible operations for a node
"""
find_fusions!(graph::DAG, node::DataTaskNode)
Find node fusions involving the given data node. The function pushes the found [`NodeFusion`](@ref) (if any) everywhere it needs to be and returns nothing.
Does nothing if the node already has a node fusion set. Since it's a data node, only one node fusion can be possible with it.
"""
function find_fusions!(graph::DAG, node::DataTaskNode)
# if there is already a fusion here, skip to avoid duplicates
if !ismissing(node.nodeFusion)
return nothing
end
if length(parents(node)) != 1 || length(children(node)) != 1
return nothing
end
child_node = first(children(node))
parent_node = first(parents(node))
if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end
if length(parents(child_node)) != 1
return nothing
end
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.nodeFusions, nf)
node.nodeFusion = nf
push!(parent_node.nodeFusions, nf)
return nothing
end
"""
find_fusions!(graph::DAG, node::ComputeTaskNode)
Find node fusions involving the given compute node. The function pushes the found [`NodeFusion`](@ref)s (if any) everywhere they need to be and returns nothing.
"""
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# just find fusions in neighbouring DataTaskNodes
for child in children(node)
find_fusions!(graph, child)
end
for parent in parents(node)
find_fusions!(graph, parent)
end
return nothing
end
"""
find_reductions!(graph::DAG, node::Node)
@ -121,7 +66,7 @@ end
"""
clean_node!(graph::DAG, node::Node)
Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way.
Sort this node's parent and child sets, then find reductions and splits involving it. Needs to be called after the node was changed in some way.
"""
function clean_node!(
graph::DAG,
@ -129,7 +74,6 @@ function clean_node!(
) where {TaskType <: AbstractTask}
sort_node!(node)
find_fusions!(graph, node)
find_reductions!(graph, node)
find_splits!(graph, node)

View File

@ -2,26 +2,6 @@
using Base.Threads
"""
insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
Insert the given node fusion into its input nodes' operation caches. For the compute nodes, locking via the given `locks` is employed to have safe multi-threading. For a large set of nodes, contention on the locks should be very small.
"""
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
n1 = nf.input[1]
n2 = nf.input[2]
n3 = nf.input[3]
lock(locks[n1]) do
return push!(nf.input[1].nodeFusions, nf)
end
n2.nodeFusion = nf
lock(locks[n3]) do
return push!(nf.input[3].nodeFusions, nf)
end
return nothing
end
"""
insert_operation!(nf::NodeReduction)
@ -72,41 +52,6 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
return nothing
end
"""
nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
Insert the node fusions into the graph and the nodes' caches. Employs multithreading for speedup.
"""
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
total_len = 0
for vec in nodeFusions
total_len += length(vec)
end
sizehint!(operations.nodeFusions, total_len)
t = @task for vec in nodeFusions
union!(operations.nodeFusions, Set(vec))
end
schedule(t)
locks = Dict{ComputeTaskNode, SpinLock}()
for n in graph.nodes
if (typeof(n) <: ComputeTaskNode)
locks[n] = SpinLock()
end
end
@threads for vec in nodeFusions
for op in vec
insert_operation!(op, locks)
end
end
wait(t)
return nothing
end
"""
ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplits}})
@ -143,7 +88,6 @@ Generate all possible operations on the graph. Used initially when the graph is
Safely inserts all the found operations into the graph and its nodes.
"""
function generate_operations(graph::DAG)
generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]
@ -199,31 +143,6 @@ function generate_operations(graph::DAG)
# remove duplicates
nr_task = @spawn nr_insertion!(graph.possibleOperations, generatedReductions)
# --- find possible node fusions ---
@threads for node in nodeArray
if (typeof(node) <: DataTaskNode)
if length(parents(node)) != 1
# data node can only have a single parent
continue
end
parent_node = first(parents(node))
if length(children(node)) != 1
# this node is an entry node or has multiple children which should not be possible
continue
end
child_node = first(children(node))
if (length(parents(child_node)) != 1)
continue
end
push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
end
end
# launch thread for node fusion insertion
nf_task = @spawn nf_insertion!(graph, graph.possibleOperations, generatedFusions)
# find possible node splits
@threads for node in nodeArray
if (can_split(node))
@ -237,7 +156,6 @@ function generate_operations(graph::DAG)
empty!(graph.dirtyNodes)
wait(nr_task)
wait(nf_task)
wait(ns_task)
return nothing

View File

@ -2,8 +2,7 @@ import Base.iterate
const _POSSIBLE_OPERATIONS_FIELDS = fieldnames(PossibleOperations)
_POIteratorStateType =
NamedTuple{(:result, :state), Tuple{Union{NodeFusion, NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
_POIteratorStateType = NamedTuple{(:result, :state), Tuple{Union{NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
@inline function iterate(possibleOperations::PossibleOperations)::Union{Nothing, _POIteratorStateType}
for fieldname in _POSSIBLE_OPERATIONS_FIELDS

View File

@ -4,11 +4,6 @@
Print a string representation of the set of possible operations to io.
"""
function show(io::IO, ops::PossibleOperations)
print(io, length(ops.nodeFusions))
println(io, " Node Fusions: ")
for nf in ops.nodeFusions
println(io, " - ", nf)
end
print(io, length(ops.nodeReductions))
println(io, " Node Reductions: ")
for nr in ops.nodeReductions
@ -42,17 +37,3 @@ function show(io::IO, op::NodeSplit)
print(io, "NS: ")
return print(io, task(op.input))
end
"""
show(io::IO, op::NodeFusion)
Print a string representation of the node fusion to io.
"""
function show(io::IO, op::NodeFusion)
print(io, "NF: ")
print(io, task(op.input[1]))
print(io, "->")
print(io, task(op.input[2]))
print(io, "->")
return print(io, task(op.input[3]))
end

View File

@ -20,45 +20,6 @@ See also: [`revert_operation!`](@ref).
"""
abstract type AppliedOperation end
"""
NodeFusion <: Operation
The NodeFusion operation. Represents the fusing of a chain of compute node -> data node -> compute node.
After the node fusion is applied, the graph has 2 fewer nodes and edges, and a new [`FusedComputeTask`](@ref) with the two input compute nodes as parts.
# Requirements for successful application
A chain of (n1, n2, n3) can be fused if:
- All nodes are in the graph.
- (n1, n2) is an edge in the graph.
- (n2, n3) is an edge in the graph.
- n2 has exactly one parent (n3) and exactly one child (n1).
- n1 has exactly one parent (n2).
[`is_valid_node_fusion_input`](@ref) can be used to `@assert` these requirements.
See also: [`can_fuse`](@ref)
"""
struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
Operation
input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
end
"""
AppliedNodeFusion <: AppliedOperation
The applied version of the [`NodeFusion`](@ref).
"""
struct AppliedNodeFusion{
TaskType1 <: AbstractComputeTask,
TaskType2 <: AbstractDataTask,
TaskType3 <: AbstractComputeTask,
} <: AppliedOperation
operation::NodeFusion{TaskType1, TaskType2, TaskType3}
diff::Diff
end
"""
NodeReduction <: Operation

View File

@ -4,7 +4,7 @@
Return whether `operations` is empty, i.e. all of its fields are empty.
"""
function isempty(operations::PossibleOperations)
return isempty(operations.nodeFusions) && isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
return isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
end
"""
@ -13,21 +13,7 @@ end
Return a named tuple with the number of each of the operation types as a named tuple. The fields are named the same as the [`PossibleOperations`](@ref)'.
"""
function length(operations::PossibleOperations)
return (
nodeFusions = length(operations.nodeFusions),
nodeReductions = length(operations.nodeReductions),
nodeSplits = length(operations.nodeSplits),
)
end
"""
delete!(operations::PossibleOperations, op::NodeFusion)
Delete the given node fusion from the possible operations.
"""
function delete!(operations::PossibleOperations, op::NodeFusion)
delete!(operations.nodeFusions, op)
return operations
return (nodeReductions = length(operations.nodeReductions), nodeSplits = length(operations.nodeSplits))
end
"""
@ -50,24 +36,6 @@ function delete!(operations::PossibleOperations, op::NodeSplit)
return operations
end
"""
can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
Return whether the given nodes can be fused. See [`NodeFusion`](@ref) for the requirements.
"""
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !is_child(n1, n2) || !is_child(n2, n3)
# the checks are redundant but maybe a good sanity check
return false
end
if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
return false
end
return true
end
"""
can_reduce(n1::Node, n2::Node)
@ -136,23 +104,6 @@ function ==(op1::Operation, op2::Operation)
return false
end
"""
==(op1::NodeFusion, op2::NodeFusion)
Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
"""
function ==(
op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
) where {
ComputeTaskType1 <: AbstractComputeTask,
DataTaskType <: AbstractDataTask,
ComputeTaskType2 <: AbstractComputeTask,
}
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
return op1.input[2] == op2.input[2]
end
"""
==(op1::NodeReduction, op2::NodeReduction)

View File

@ -2,43 +2,6 @@
# should be called with @assert
# the functions throw their own errors though, to still have helpful error messages
"""
is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
Assert for a gven node fusion input whether the nodes can be fused. For the requirements of a node fusion see [`NodeFusion`](@ref).
Intended for use with `@assert` or `@test`.
"""
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
end
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
throw(
AssertionError(
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
),
)
end
if length(n2.parents) > 1
throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
end
if length(n2.children) > 1
throw(AssertionError("[Node Fusion] The given data node has more than one child"))
end
if length(n1.parents) > 1
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
end
@assert is_valid(graph, n1)
@assert is_valid(graph, n2)
@assert is_valid(graph, n3)
return true
end
"""
is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
@ -131,16 +94,3 @@ function is_valid(graph::DAG, ns::NodeSplit)
#@assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!"
return true
end
"""
is_valid(graph::DAG, nr::NodeFusion)
Assert for a given [`NodeFusion`](@ref) whether it is a valid operation in the graph.
Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, nf::NodeFusion)
@assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
#@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
return true
end

View File

@ -1,36 +0,0 @@
"""
FusionOptimizer
An optimizer that simply applies an available [`NodeFusion`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeFusion`](@ref)s in the graph.
See also: [`SplitOptimizer`](@ref), [`ReductionOptimizer`](@ref)
"""
struct FusionOptimizer <: AbstractOptimizer end
function optimize_step!(optimizer::FusionOptimizer, graph::DAG)
# generate all options
operations = get_operations(graph)
if fixpoint_reached(optimizer, graph)
return false
end
push_operation!(graph, first(operations.nodeFusions))
return true
end
function fixpoint_reached(optimizer::FusionOptimizer, graph::DAG)
operations = get_operations(graph)
return isempty(operations.nodeFusions)
end
function optimize_to_fixpoint!(optimizer::FusionOptimizer, graph::DAG)
while !fixpoint_reached(optimizer, graph)
optimize_step!(optimizer, graph)
end
return nothing
end
function String(::FusionOptimizer)
return "fusion_optimizer"
end

View File

@ -26,16 +26,12 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG)
if rand(r, Bool)
# push
# choose one of fuse/split/reduce
# TODO refactor fusions so they actually work
option = rand(r, 2:3)
if option == 1 && !isempty(operations.nodeFusions)
push_operation!(graph, rand(r, collect(operations.nodeFusions)))
return true
elseif option == 2 && !isempty(operations.nodeReductions)
# choose one of split/reduce
option = rand(r, 1:2)
if option == 1 && !isempty(operations.nodeReductions)
push_operation!(graph, rand(r, collect(operations.nodeReductions)))
return true
elseif option == 3 && !isempty(operations.nodeSplits)
elseif option == 2 && !isempty(operations.nodeSplits)
push_operation!(graph, rand(r, collect(operations.nodeSplits)))
return true
end

View File

@ -3,7 +3,7 @@
An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
See also: [`FusionOptimizer`](@ref), [`SplitOptimizer`](@ref)
See also: [`SplitOptimizer`](@ref)
"""
struct ReductionOptimizer <: AbstractOptimizer end

View File

@ -3,7 +3,7 @@
An optimizer that simply applies an available [`NodeSplit`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeSplit`](@ref)s in the graph.
See also: [`FusionOptimizer`](@ref), [`ReductionOptimizer`](@ref)
See also: [`ReductionOptimizer`](@ref)
"""
struct SplitOptimizer <: AbstractOptimizer end

View File

@ -1,31 +1,13 @@
using StaticArrays
"""
compute(t::FusedComputeTask, data)
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
"""
function compute(t::FusedComputeTask, data...)
inter = compute(t.first_task)
return compute(t.second_task, inter, data2...)
end
"""
get_function_call(n::Node)
get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
For a node or a task together with necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task.
For ordinary compute or data tasks the vector will contain exactly one element, for a [`FusedComputeTask`](@ref) there can be any number of tasks greater 1.
For ordinary compute or data tasks the vector will contain exactly one element.
"""
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
# sort out the symbols to the correct tasks
return [
get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
]
end
function get_function_call(
t::CompTask,
device::AbstractDevice,

View File

@ -11,22 +11,3 @@ copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks
Return a copy of the given compute task.
"""
copy(t::AbstractComputeTask) = typeof(t)()
"""
copy(t::FusedComputeTask)
Return a copy of th egiven [`FusedComputeTask`](@ref).
"""
function copy(t::FusedComputeTask)
return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
end
function FusedComputeTask(
T1::Type{<:AbstractComputeTask},
T2::Type{<:AbstractComputeTask},
t1_inputs::Vector{String},
t1_output::String,
t2_inputs::Vector{String},
)
return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
end

View File

@ -47,34 +47,9 @@ Return the number of children of a data task (always 1).
"""
children(::DataTask) = 1
"""
children(t::FusedComputeTask)
Return the number of children of a FusedComputeTask.
"""
function children(t::FusedComputeTask)
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
end
"""
data(t::AbstractComputeTask)
Return the data of a compute task, always zero, regardless of the specific task.
"""
data(t::AbstractComputeTask)::Float64 = 0.0
"""
compute_effort(t::FusedComputeTask)
Return the compute effort of a fused compute task.
"""
function compute_effort(t::FusedComputeTask)::Float64
return compute_effort(t.first_task) + compute_effort(t.second_task)
end
"""
get_types(::FusedComputeTask{T1, T2})
Return a tuple of a the fused compute task's components' types.
"""
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))

View File

@ -27,21 +27,3 @@ Task representing a specific data transfer.
struct DataTask <: AbstractDataTask
data::Float64
end
"""
FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
A fused compute task made up of the computation of first `T1` and then `T2`.
Also see: [`get_types`](@ref).
"""
struct FusedComputeTask <: AbstractComputeTask
first_task::AbstractComputeTask
second_task::AbstractComputeTask
# the names of the inputs for T1
t1_inputs::Vector{Symbol}
# output name of T1
t1_output::Symbol
# t2_inputs doesn't include the output of t1, that's implicit
t2_inputs::Vector{Symbol}
end

View File

@ -74,9 +74,6 @@ function mem(graph::DAG)
size += sizeof(graph.operationsToApply)
size += sizeof(graph.possibleOperations)
for op in graph.possibleOperations.nodeFusions
size += mem(op)
end
for op in graph.possibleOperations.nodeReductions
size += mem(op)
end

View File

@ -3,48 +3,15 @@ using Random
RNG = Random.MersenneTwister(321)
function test_known_graph(name::String, n, fusion_test = true)
function test_known_graph(name::String, n)
@testset "Test $name Graph ($n)" begin
graph = parse_dag(joinpath(@__DIR__, "..", "input", "$name.txt"), ABCModel())
props = get_properties(graph)
if (fusion_test)
test_node_fusion(graph)
end
test_random_walk(RNG, graph, n)
end
end
function test_node_fusion(g::DAG)
@testset "Test Node Fusion" begin
props = get_properties(g)
options = get_operations(g)
nodes_number = length(g.nodes)
data = props.data
compute_effort = props.computeEffort
while !isempty(options.nodeFusions)
fusion = first(options.nodeFusions)
@test typeof(fusion) <: NodeFusion
push_operation!(g, fusion)
props = get_properties(g)
@test props.data < data
@test props.computeEffort == compute_effort
nodes_number = length(g.nodes)
data = props.data
compute_effort = props.computeEffort
options = get_operations(g)
end
end
end
function test_random_walk(RNG, g::DAG, n::Int64)
@testset "Test Random Walk ($n)" begin
# the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
@ -60,13 +27,11 @@ function test_random_walk(RNG, g::DAG, n::Int64)
# push
opt = get_operations(g)
# choose one of fuse/split/reduce
option = rand(RNG, 1:3)
if option == 1 && !isempty(opt.nodeFusions)
push_operation!(g, rand(RNG, collect(opt.nodeFusions)))
elseif option == 2 && !isempty(opt.nodeReductions)
# choose one of split/reduce
option = rand(RNG, 1:2)
if option == 1 && !isempty(opt.nodeReductions)
push_operation!(g, rand(RNG, collect(opt.nodeReductions)))
elseif option == 3 && !isempty(opt.nodeSplits)
elseif option == 2 && !isempty(opt.nodeSplits)
push_operation!(g, rand(RNG, collect(opt.nodeSplits)))
else
i = i - 1
@ -91,4 +56,4 @@ end
test_known_graph("AB->AB", 10000)
test_known_graph("AB->ABBB", 10000)
test_known_graph("AB->ABBBBB", 1000, false)
test_known_graph("AB->ABBBBB", 1000)

View File

@ -61,17 +61,14 @@ insert_edge!(graph, CD, C1C, track = false)
opt = get_operations(graph)
@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
#println("Initial State:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
nr = first(opt.nodeReductions)
@test Set(nr.input) == Set([B1C_1, B1C_2])
push_operation!(graph, nr)
opt = get_operations(graph)
@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
#println("After 1 Node Reduction:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
nr = first(opt.nodeReductions)
@test Set(nr.input) == Set([B1D_1, B1D_2])
@ -80,19 +77,16 @@ opt = get_operations(graph)
@test is_valid(graph)
@test length(opt) == (nodeFusions = 4, nodeReductions = 0, nodeSplits = 1)
#println("After 2 Node Reductions:\n", opt)
@test length(opt) == (nodeReductions = 0, nodeSplits = 1)
pop_operation!(graph)
opt = get_operations(graph)
@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
#println("After reverting the second Node Reduction:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
reset_graph!(graph)
opt = get_operations(graph)
@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
#println("After reverting to the initial state:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)
@test is_valid(graph)

View File

@ -1,5 +1,4 @@
using SafeTestsets
#=
@safetestset "Utility Unit Tests " begin
include("unit_tests_utility.jl")
@ -19,11 +18,9 @@ end
@safetestset "ABC-Model Unit Tests " begin
include("unit_tests_abcmodel.jl")
end
=#
@safetestset "QED-Model Unit Tests " begin
include("unit_tests_qedmodel.jl")
end
#=
@safetestset "QED Feynman Diagram Generation Tests" begin
include("unit_tests_qed_diagrams.jl")
end
@ -33,6 +30,7 @@ end
@safetestset "Graph Unit Tests " begin
include("unit_tests_graph.jl")
end
=#
@safetestset "Execution Unit Tests " begin
include("unit_tests_execution.jl")
end
@ -42,4 +40,3 @@ end
@safetestset "Known Graph Tests " begin
include("known_graphs.jl")
end
=#

View File

@ -1,5 +1,5 @@
using MetagraphOptimization
using QEDbase
using QEDcore
import MetagraphOptimization.interaction_result

View File

@ -1,16 +1,5 @@
using MetagraphOptimization
function test_op_specific(estimator, graph, nf::NodeFusion)
estimate = operation_effect(estimator, graph, nf)
data_reduce = data(nf.input[2].task)
@test isapprox(estimate.data, -data_reduce)
@test isapprox(estimate.computeEffort, 0; atol = eps(Float64))
@test isapprox(estimate.computeIntensity, 0; atol = eps(Float64))
return nothing
end
function test_op_specific(estimator, graph, nr::NodeReduction)
estimate = operation_effect(estimator, graph, nr)
@ -74,13 +63,9 @@ end
@testset "Operation Cost" begin
ops = get_operations(graph)
nfs = copy(ops.nodeFusions)
nrs = copy(ops.nodeReductions)
nss = copy(ops.nodeSplits)
for nf in nfs
test_op(estimator, graph, nf)
end
for nr in nrs
test_op(estimator, graph, nr)
end

View File

@ -1,5 +1,5 @@
using MetagraphOptimization
using QEDbase
using QEDcore
using AccurateArithmetic
using Random
using UUIDs
@ -63,8 +63,14 @@ machine = Machine(
)
process_2_2 = ABCProcessDescription(
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
Dict{Type, Int64}(
ParticleStateful{Incoming, ParticleA, SFourMomentum} => 1,
ParticleStateful{Incoming, ParticleB, SFourMomentum} => 1,
),
Dict{Type, Int64}(
ParticleStateful{Outgoing, ParticleA, SFourMomentum} => 1,
ParticleStateful{Outgoing, ParticleB, SFourMomentum} => 1,
),
)
particles_2_2 = ABCProcessInput(
@ -106,8 +112,14 @@ end
end
process_2_4 = ABCProcessDescription(
Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
Dict{Type, Int64}(ParticleA => 1, ParticleB => 3),
Dict{Type, Int64}(
ParticleStateful{Incoming, ParticleA, SFourMomentum} => 1,
ParticleStateful{Incoming, ParticleB, SFourMomentum} => 1,
),
Dict{Type, Int64}(
ParticleStateful{Outgoing, ParticleA, SFourMomentum} => 1,
ParticleStateful{Outgoing, ParticleB, SFourMomentum} => 3,
),
)
particles_2_4 = gen_process_input(process_2_4)
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
@ -136,105 +148,6 @@ TODO: fix precision(?) issues
end
=#
@testset "AB->AB large sum fusion" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "AB->AB large sum fusion" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push a fusion with the sum node
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, ComputeTaskABC_Sum)
push_operation!(graph, fusion)
break
end
end
# push two more fusions with the fused node
for _ in 1:15
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[3].task, FusedComputeTask)
push_operation!(graph, fusion)
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "AB->AB fusion edge case" begin
for _ in 1:20
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
# push two fusions with ComputeTaskABC_V
for _ in 1:2
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, ComputeTaskABC_V)
push_operation!(graph, fusion)
break
end
end
end
# push fusions until the end
cont = true
while cont
cont = false
ops = get_operations(graph)
for fusion in ops.nodeFusions
if isa(fusion.input[1].task, FusedComputeTask)
push_operation!(graph, fusion)
cont = true
break
end
end
end
# try execute
@test is_valid(graph)
expected_result = ground_truth_graph_result(particles_2_2)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
end
end
@testset "$(process) after random walk" for process in ["ke->ke", "ke->kke", "ke->kkke"]
process = parse_process("ke->kkke", QEDModel())
inputs = [gen_process_input(process) for _ in 1:100]

View File

@ -13,7 +13,7 @@ graph = MetagraphOptimization.DAG()
@test length(graph.operationsToApply) == 0
@test length(graph.dirtyNodes) == 0
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
@test length(get_operations(graph)) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
@test length(get_operations(graph)) == (nodeReductions = 0, nodeSplits = 0)
# s to output (exit node)
d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
@ -133,13 +133,10 @@ insert_edge!(graph, s0, d_exit, track = false)
@test length(siblings(s0)) == 1
operations = get_operations(graph)
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
@test length(graph.dirtyNodes) == 0
@test sum(length(operations)) == 10
@test operations == get_operations(graph)
nf = first(operations.nodeFusions)
properties = get_properties(graph)
@test properties.computeEffort == 28
@ -148,54 +145,19 @@ properties = get_properties(graph)
@test properties.noNodes == 26
@test properties.noEdges == 25
push_operation!(graph, nf)
# **does not immediately apply the operation**
@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0
@test length(graph.operationsToApply) == 1
@test first(graph.operationsToApply) == nf
@test length(graph.dirtyNodes) == 0
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
# this applies pending operations
properties = get_properties(graph)
@test length(graph.nodes) == 24
@test length(graph.appliedOperations) == 1
@test length(graph.operationsToApply) == 0
@test length(graph.dirtyNodes) != 0
@test properties.noNodes == 24
@test properties.noEdges == 23
@test properties.computeEffort == 28
@test properties.data < 62
@test properties.computeIntensity > 28 / 62
operations = get_operations(graph)
@test length(graph.dirtyNodes) == 0
@test length(operations) == (nodeFusions = 9, nodeReductions = 0, nodeSplits = 0)
@test !isempty(operations)
possibleNF = 9
while !isempty(operations.nodeFusions)
push_operation!(graph, first(operations.nodeFusions))
global operations = get_operations(graph)
global possibleNF = possibleNF - 1
@test length(operations) == (nodeFusions = possibleNF, nodeReductions = 0, nodeSplits = 0)
end
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
@test isempty(operations)
@test length(operations) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
@test length(graph.dirtyNodes) == 0
@test length(graph.nodes) == 6
@test length(graph.appliedOperations) == 10
@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0
@test length(graph.operationsToApply) == 0
reset_graph!(graph)
@test length(graph.dirtyNodes) == 26
@test length(graph.dirtyNodes) == 0
@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0
@test length(graph.operationsToApply) == 0
@ -208,6 +170,6 @@ properties = get_properties(graph)
@test properties.computeIntensity 28 / 62
operations = get_operations(graph)
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
@test is_valid(graph)

View File

@ -6,8 +6,7 @@ RNG = Random.MersenneTwister(0)
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
# create the optimizers
FIXPOINT_OPTIMIZERS =
[GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer(), FusionOptimizer()]
FIXPOINT_OPTIMIZERS = [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer()]
NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
@testset "Optimizer $optimizer" for optimizer in vcat(NO_FIXPOINT_OPTIMIZERS, FIXPOINT_OPTIMIZERS)

View File

@ -1,7 +1,8 @@
using MetagraphOptimization
using QEDcore
import MetagraphOptimization.gen_diagrams
import MetagraphOptimization.isincoming
import MetagraphOptimization.types
@ -9,25 +10,12 @@ model = QEDModel()
compton = ("Compton Scattering", parse_process("ke->ke", model), 2)
compton_3 = ("3-Photon Compton Scattering", parse_process("kkke->ke", QEDModel()), 24)
compton_4 = ("4-Photon Compton Scattering", parse_process("kkkke->ke", QEDModel()), 120)
bhabha = ("Bhabha Scattering", parse_process("ep->ep", model), 2)
moller = ("Møller Scattering", parse_process("ee->ee", model), 2)
pair_production = ("Pair production", parse_process("kk->ep", model), 2)
pair_annihilation = ("Pair annihilation", parse_process("ep->kk", model), 2)
trident = ("Trident", parse_process("ke->epe", model), 8)
@testset "Known Processes" begin
@testset "$name" for (name, process, n) in
[compton, bhabha, moller, pair_production, pair_annihilation, trident, compton_3, compton_4]
@testset "$name" for (name, process, n) in [compton, compton_3, compton_4]
initial_diagram = FeynmanDiagram(process)
n_particles = number_incoming_particles(process) + number_outgoing_particles(process)
n_particles = 0
for type in types(model)
if (is_incoming(type))
n_particles += get(process.inParticles, type, 0)
else
n_particles += get(process.outParticles, type, 0)
end
end
@test n_particles == length(initial_diagram.particles)
@test ismissing(initial_diagram.tie[])
@test isempty(initial_diagram.vertices)

View File

@ -56,8 +56,8 @@ function compton_groundtruth(input::PhaseSpacePoint)
virt2_mom = p1.momentum + k1.momentum
@test isapprox(p2.momentum + k2.momentum, virt2_mom)
s_p2_k1 = propagator(Electron(), virt1_mom)
s_p1_k1 = propagator(Electron(), virt2_mom)
s_p2_k1 = QEDbase.propagator(Electron(), virt1_mom)
s_p1_k1 = QEDbase.propagator(Electron(), virt2_mom)
diagram1 = u_p2 * (eps_1 * QED_vertex()) * s_p2_k1 * (eps_2 * QED_vertex()) * u_p1
diagram2 = u_p2 * (eps_2 * QED_vertex()) * s_p1_k1 * (eps_1 * QED_vertex()) * u_p1
@ -65,7 +65,6 @@ function compton_groundtruth(input::PhaseSpacePoint)
return diagram1 + diagram2
end
@testset "Interaction Result" begin
import MetagraphOptimization.QED_conserve_momentum
@ -107,34 +106,26 @@ end
end
@testset "Parse Process" begin
@testset "Order invariance" begin
@test parse_process("ke->ke", QEDModel()) == parse_process("ek->ke", QEDModel())
@test parse_process("ke->ke", QEDModel()) == parse_process("ek->ek", QEDModel())
@test parse_process("ke->ke", QEDModel()) == parse_process("ke->ek", QEDModel())
@test parse_process("kkke->eep", QEDModel()) == parse_process("kkek->epe", QEDModel())
end
@testset "Known processes" begin
compton_process = GenericQEDProcess(1, 1, 1, 1, 0, 0)
proc = parse_process("ke->ke", QEDModel())
@test incoming_particles(proc) == (Photon(), Electron())
@test outgoing_particles(proc) == (Photon(), Electron())
@test parse_process("ke->ke", QEDModel()) == compton_process
proc = parse_process("kp->kp", QEDModel())
@test incoming_particles(proc) == (Photon(), Positron())
@test outgoing_particles(proc) == (Photon(), Positron())
positron_compton_process = GenericQEDProcess(1, 1, 0, 0, 1, 1)
proc = parse_process("ke->eep", QEDModel())
@test incoming_particles(proc) == (Photon(), Electron())
@test outgoing_particles(proc) == (Electron(), Electron(), Positron())
@test parse_process("kp->kp", QEDModel()) == positron_compton_process
proc = parse_process("kk->pe", QEDModel())
@test incoming_particles(proc) == (Photon(), Photon())
@test outgoing_particles(proc) == (Positron(), Electron())
trident_process = GenericQEDProcess(1, 0, 1, 2, 0, 1)
@test parse_process("ke->eep", QEDModel()) == trident_process
pair_production_process = GenericQEDProcess(2, 0, 0, 1, 0, 1)
@test parse_process("kk->pe", QEDModel()) == pair_production_process
pair_annihilation_process = GenericQEDProcess(0, 2, 1, 0, 1, 0)
@test parse_process("pe->kk", QEDModel()) == pair_annihilation_process
proc = parse_process("pe->kk", QEDModel())
@test incoming_particles(proc) == (Positron(), Electron())
@test outgoing_particles(proc) == (Photon(), Photon())
end
end
@ -145,32 +136,7 @@ end
for i in 1:100
input = gen_process_input(process)
@test length(input.inFerms) ==
get(process.inParticles, ParticleStateful{Incoming, Electron, SFourMomentum}, 0)
@test length(input.inAntiferms) ==
get(process.inParticles, ParticleStateful{Incoming, Positron, SFourMomentum}, 0)
@test length(input.inPhotons) ==
get(process.inParticles, ParticleStateful{Incoming, Photon, SFourMomentum}, 0)
@test length(input.outFerms) ==
get(process.outParticles, ParticleStateful{Outgoing, Electron, SFourMomentum}, 0)
@test length(input.outAntiferms) ==
get(process.outParticles, ParticleStateful{Outgoing, Positron, SFourMomentum}, 0)
@test length(input.outPhotons) ==
get(process.outParticles, ParticleStateful{Outgoing, Photon, SFourMomentum}, 0)
@test isapprox(
sum([
getfield.(input.inFerms, :momentum)...,
getfield.(input.inAntiferms, :momentum)...,
getfield.(input.inPhotons, :momentum)...,
]),
sum([
getfield.(input.outFerms, :momentum)...,
getfield.(input.outAntiferms, :momentum)...,
getfield.(input.outPhotons, :momentum)...,
]);
atol = sqrt(eps()),
)
@test isapprox(sum(momenta(input, Incoming())), sum(momenta(input, Outgoing())); atol = sqrt(eps()))
end
end
end
@ -202,97 +168,97 @@ end
graph = DAG()
# s to output (exit node)
d_exit = insert_node!(graph, make_node(DataTask(16)), track = false)
d_exit = insert_node!(graph, make_node(DataTask(16)); track=false)
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(2)), track = false)
sum_node = insert_node!(graph, make_node(ComputeTaskQED_Sum(2)); track=false)
d_s0_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
d_s1_sum = insert_node!(graph, make_node(DataTask(16)), track = false)
d_s0_sum = insert_node!(graph, make_node(DataTask(16)); track=false)
d_s1_sum = insert_node!(graph, make_node(DataTask(16)); track=false)
# final s compute
s0 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
s1 = insert_node!(graph, make_node(ComputeTaskQED_S2()), track = false)
s0 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track=false)
s1 = insert_node!(graph, make_node(ComputeTaskQED_S2()); track=false)
# data from v0 and v1 to s0
d_v0_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v1_s0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v2_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v3_s1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_v0_s0 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_v1_s0 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_v2_s1 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_v3_s1 = insert_node!(graph, make_node(DataTask(96)); track=false)
# v0 and v1 compute
v0 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v1 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v2 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v3 = insert_node!(graph, make_node(ComputeTaskQED_V()), track = false)
v0 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
v1 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
v2 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
v3 = insert_node!(graph, make_node(ComputeTaskQED_V()); track=false)
# data from uPhIn, uPhOut, uElIn, uElOut to v0 and v1
d_uPhIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElIn_v0 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElOut_v1 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhIn_v0 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uElIn_v0 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uPhOut_v1 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uElOut_v1 = insert_node!(graph, make_node(DataTask(96)); track=false)
# data from uPhIn, uPhOut, uElIn, uElOut to v2 and v3
d_uPhOut_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElIn_v2 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhIn_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uElOut_v3 = insert_node!(graph, make_node(DataTask(96)), track = false)
d_uPhOut_v2 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uElIn_v2 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uPhIn_v3 = insert_node!(graph, make_node(DataTask(96)); track=false)
d_uElOut_v3 = insert_node!(graph, make_node(DataTask(96)); track=false)
# uPhIn, uPhOut, uElIn and uElOut computes
uPhIn = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uPhOut = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uElIn = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uElOut = insert_node!(graph, make_node(ComputeTaskQED_U()), track = false)
uPhIn = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
uPhOut = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
uElIn = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
uElOut = insert_node!(graph, make_node(ComputeTaskQED_U()); track=false)
# data into U
d_uPhIn = insert_node!(graph, make_node(DataTask(16), "ki1"), track = false)
d_uPhOut = insert_node!(graph, make_node(DataTask(16), "ko1"), track = false)
d_uElIn = insert_node!(graph, make_node(DataTask(16), "ei1"), track = false)
d_uElOut = insert_node!(graph, make_node(DataTask(16), "eo1"), track = false)
d_uPhIn = insert_node!(graph, make_node(DataTask(16), "ki1"); track=false)
d_uPhOut = insert_node!(graph, make_node(DataTask(16), "ko1"); track=false)
d_uElIn = insert_node!(graph, make_node(DataTask(16), "ei1"); track=false)
d_uElOut = insert_node!(graph, make_node(DataTask(16), "eo1"); track=false)
# now for all the edges
insert_edge!(graph, d_uPhIn, uPhIn, track = false)
insert_edge!(graph, d_uPhOut, uPhOut, track = false)
insert_edge!(graph, d_uElIn, uElIn, track = false)
insert_edge!(graph, d_uElOut, uElOut, track = false)
insert_edge!(graph, d_uPhIn, uPhIn; track=false)
insert_edge!(graph, d_uPhOut, uPhOut; track=false)
insert_edge!(graph, d_uElIn, uElIn; track=false)
insert_edge!(graph, d_uElOut, uElOut; track=false)
insert_edge!(graph, uPhIn, d_uPhIn_v0, track = false)
insert_edge!(graph, uPhOut, d_uPhOut_v1, track = false)
insert_edge!(graph, uElIn, d_uElIn_v0, track = false)
insert_edge!(graph, uElOut, d_uElOut_v1, track = false)
insert_edge!(graph, uPhIn, d_uPhIn_v0; track=false)
insert_edge!(graph, uPhOut, d_uPhOut_v1; track=false)
insert_edge!(graph, uElIn, d_uElIn_v0; track=false)
insert_edge!(graph, uElOut, d_uElOut_v1; track=false)
insert_edge!(graph, uPhIn, d_uPhIn_v3, track = false)
insert_edge!(graph, uPhOut, d_uPhOut_v2, track = false)
insert_edge!(graph, uElIn, d_uElIn_v2, track = false)
insert_edge!(graph, uElOut, d_uElOut_v3, track = false)
insert_edge!(graph, uPhIn, d_uPhIn_v3; track=false)
insert_edge!(graph, uPhOut, d_uPhOut_v2; track=false)
insert_edge!(graph, uElIn, d_uElIn_v2; track=false)
insert_edge!(graph, uElOut, d_uElOut_v3; track=false)
insert_edge!(graph, d_uPhIn_v0, v0, track = false)
insert_edge!(graph, d_uPhOut_v1, v1, track = false)
insert_edge!(graph, d_uElIn_v0, v0, track = false)
insert_edge!(graph, d_uElOut_v1, v1, track = false)
insert_edge!(graph, d_uPhIn_v0, v0; track=false)
insert_edge!(graph, d_uPhOut_v1, v1; track=false)
insert_edge!(graph, d_uElIn_v0, v0; track=false)
insert_edge!(graph, d_uElOut_v1, v1; track=false)
insert_edge!(graph, d_uPhIn_v3, v3, track = false)
insert_edge!(graph, d_uPhOut_v2, v2, track = false)
insert_edge!(graph, d_uElIn_v2, v2, track = false)
insert_edge!(graph, d_uElOut_v3, v3, track = false)
insert_edge!(graph, d_uPhIn_v3, v3; track=false)
insert_edge!(graph, d_uPhOut_v2, v2; track=false)
insert_edge!(graph, d_uElIn_v2, v2; track=false)
insert_edge!(graph, d_uElOut_v3, v3; track=false)
insert_edge!(graph, v0, d_v0_s0, track = false)
insert_edge!(graph, v1, d_v1_s0, track = false)
insert_edge!(graph, v2, d_v2_s1, track = false)
insert_edge!(graph, v3, d_v3_s1, track = false)
insert_edge!(graph, v0, d_v0_s0; track=false)
insert_edge!(graph, v1, d_v1_s0; track=false)
insert_edge!(graph, v2, d_v2_s1; track=false)
insert_edge!(graph, v3, d_v3_s1; track=false)
insert_edge!(graph, d_v0_s0, s0, track = false)
insert_edge!(graph, d_v1_s0, s0, track = false)
insert_edge!(graph, d_v0_s0, s0; track=false)
insert_edge!(graph, d_v1_s0, s0; track=false)
insert_edge!(graph, d_v2_s1, s1, track = false)
insert_edge!(graph, d_v3_s1, s1, track = false)
insert_edge!(graph, d_v2_s1, s1; track=false)
insert_edge!(graph, d_v3_s1, s1; track=false)
insert_edge!(graph, s0, d_s0_sum, track = false)
insert_edge!(graph, s1, d_s1_sum, track = false)
insert_edge!(graph, s0, d_s0_sum; track=false)
insert_edge!(graph, s1, d_s1_sum; track=false)
insert_edge!(graph, d_s0_sum, sum_node, track = false)
insert_edge!(graph, d_s1_sum, sum_node, track = false)
insert_edge!(graph, d_s0_sum, sum_node; track=false)
insert_edge!(graph, d_s1_sum, sum_node; track=false)
insert_edge!(graph, sum_node, d_exit, track = false)
insert_edge!(graph, sum_node, d_exit; track=false)
input = [gen_process_input(process) for _ in 1:1000]
@ -305,9 +271,12 @@ end
@test isapprox(compton_function.(input), compton_groundtruth.(input))
end
@testset "Equal results after optimization" for optimizer in
[ReductionOptimizer(), RandomWalkOptimizer(MersenneTwister(0))]
@testset "Process $proc_str" for proc_str in ["ke->ke", "kp->kp", "kk->ep", "ep->kk", "ke->kke", "ke->kkke"]
@testset "Equal results after optimization" for optimizer in [
ReductionOptimizer(), RandomWalkOptimizer(MersenneTwister(0))
]
@testset "Process $proc_str" for proc_str in [
"ke->ke", "kp->kp", "kk->ep", "ep->kk", "ke->kke", "ke->kkke"
]
model = QEDModel()
process = parse_process(proc_str, model)
machine = Machine(