This commit is contained in:
Anton Reinhard 2023-09-28 00:48:57 +02:00
parent 4b44eb5286
commit a69dd6018e
3 changed files with 54 additions and 36 deletions

View File

@ -56,7 +56,8 @@ end
function gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
particles::Tuple{Vector{Particle}, Vector{Particle}},
inOutCount::Dict{ParticleType, Tuple{Int, Int}},
functionInputSymbol::Symbol = :input,
)
@assert !isempty(particles[1]) "Can't have 0 input particles!"
@assert !isempty(particles[2]) "Can't have 0 output particles!"
@ -92,23 +93,19 @@ function gen_input_assignment_code(
p = nothing
condition(x) = x.type == type
if (index > in_out_count[type][1])
index -= in_out_count[type][1]
@assert index <= in_out_count[type][2] "Too few particles of type $type in input particles for this process"
p = particles[2][findall(condition, particles[2])[index]]
p = "$(functionInputSymbol)[2][$(index)]"
else
p = particles[1][findall(condition, particles[1])[index]]
p = "$(functionInputSymbol)[1][$(index)]"
end
for symbol in symbols
push!(
assignInputs,
Meta.parse(
"$(symbol) = ParticleValue(Particle($(p.momentum), $(p.type)), 1.0)",
),
Meta.parse("$(symbol) = ParticleValue($p, 1.0)"),
)
end
end
@ -116,47 +113,51 @@ function gen_input_assignment_code(
return Expr(:block, assignInputs...)
end
"""
execute(generated_code, input::Dict{ParticleType, Vector{Particle}})
Execute the given `generated_code` (as returned by [`gen_code`](@ref)) on the given input particles.
"""
function execute(
generated_code,
function get_compute_function(
graph::DAG,
input::Tuple{Vector{Particle}, Vector{Particle}},
)
(code, inputSymbols, outputSymbol) = generated_code
(code, inputSymbols, outputSymbol) = gen_code(graph)
assignInputs = gen_input_assignment_code(inputSymbols, input)
eval(assignInputs)
eval(code)
assignInputs = gen_input_assignment_code(inputSymbols, input, :input)
eval(Meta.parse("result = $outputSymbol"))
return result
function_id = to_var_name(UUIDs.uuid1(rng[1]))
func = eval(
Meta.parse(
"function compute_$(function_id)(input::Tuple{Vector{Particle}, Vector{Particle}}) $assignInputs; $code; return $outputSymbol; end",
),
)
return func
end
"""
execute(graph::DAG, input::Dict{ParticleType, Vector{Particle}})
Execute the given `generated_code` (as returned by [`gen_code`](@ref)) on the given input particles.
Execute the code of the given `graph` on the given input particles.
The input particles should be sorted correctly into the dictionary to their according [`ParticleType`](@ref)s.
This is essentially shorthand for
```julia
graph = parse_abc(input_file)
particles = gen_particles(...)
compute_graph = get_compute_function(graph, particles)
result = compute_graph(particles)
```
See also: [`gen_particles`](@ref)
"""
function execute(graph::DAG, input::Tuple{Vector{Particle}, Vector{Particle}})
(code, inputSymbols, outputSymbol) = gen_code(graph)
assignInputs = gen_input_assignment_code(inputSymbols, input)
func = get_compute_function(graph, input)
result = 0
try
eval(assignInputs)
eval(code)
eval(Meta.parse("result = $outputSymbol"))
result = @eval $func($input)
catch e
println("Error while evaluating: $e")
# println("Assign Input Code:\n$assignInputs\n")
# println("Code:\n$code")
println("Function: $func")
@assert false
end

View File

@ -131,3 +131,20 @@ function preserve_momentum(p1::Particle, p2::Particle)
return p3
end
"""
type_from_name(name::String)
For a name of a particle, return the [`ParticleType`].
"""
function type_from_name(name::String)
if startswith(name, "A")
return A
elseif startswith(name, "B")
return B
elseif startswith(name, "C")
return C
else
throw("Invalid name for a particle in the ABC model")
end
end

View File

@ -32,12 +32,12 @@ include("../examples/profiling_utilities.jl")
rtol = 0.001,
)
code = MetagraphOptimization.gen_code(graph)
#=code = MetagraphOptimization.gen_code(graph)
@test isapprox(
execute(code, particles_2_2),
expected_result;
rtol = 0.001,
)
)=#
end
end
@ -68,19 +68,19 @@ include("../examples/profiling_utilities.jl")
rtol = 0.001,
)
code = MetagraphOptimization.gen_code(graph)
#=code = MetagraphOptimization.gen_code(graph)
@test isapprox(
execute(code, particles_2_4),
expected_result;
rtol = 0.001,
)
)=#
end
end
@testset "AB->ABBB after random walk" begin
for i in 1:10
for i in 1:20
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"))
random_walk!(graph, 20)
random_walk!(graph, 100)
@test is_valid(graph)
@test isapprox(