Add tests for AB->ABBB execution and fix errors

This commit is contained in:
Anton Reinhard 2023-09-26 18:30:37 +02:00
parent 95f92f080c
commit 24ade323f0
7 changed files with 62 additions and 12 deletions

View File

@ -153,8 +153,8 @@ function execute(graph::DAG, input::Tuple{Vector{Particle}, Vector{Particle}})
eval(Meta.parse("result = $outputSymbol"))
catch e
println("Error while evaluating: $e")
println("Assign Input Code:\n$assignInputs\n")
println("Code:\n$code")
# println("Assign Input Code:\n$assignInputs\n")
# println("Code:\n$code")
@assert false
end

View File

@ -34,6 +34,7 @@ end
Return a vector of the graph's entry nodes.
"""
function get_entry_nodes(graph::DAG)
apply_all!(graph)
result = Vector{Node}()
for node in graph.nodes
if (is_entry_node(node))

View File

@ -45,6 +45,14 @@ For valid inputs, both input particles should have the same momenta at this poin
12 FLOP.
"""
function compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)
@assert isapprox(
abs(data1.p.momentum.E),
abs(data2.p.momentum.E),
rtol = 0.001,
) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
@assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)"
@assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)"
return data1.v * inner_edge(data1.p) * data2.v
end
@ -200,7 +208,7 @@ Generate and return code for a given [`ComputeTaskNode`](@ref).
"""
function get_expression(node::ComputeTaskNode)
t = typeof(node.task)
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
# @assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
symbolIn = "data_$(to_var_name(node.children[1].id))"

View File

@ -42,7 +42,10 @@ function gen_particles(
output_particles = Vector{Particle}()
final_momenta = generate_physical_massive_moms(rng, mass_sum, output_masses)
for (mom, type) in zip(final_momenta, out_particles)
push!(output_particles, Particle(mom, type))
push!(
output_particles,
Particle(SFourMomentum(-mom.E, mom.px, mom.py, mom.pz), type),
)
end
return (input_particles, output_particles)

View File

@ -89,7 +89,7 @@ end
Return the factor of the inner edge with the given (virtual) particle.
Takes 10 effective FLOP. (3 here + 10 in square(p))
Takes 10 effective FLOP. (3 here + 7 in square(p))
"""
function inner_edge(p::Particle)
return 1.0 / (square(p) - mass(p.type) * mass(p.type))

View File

@ -162,7 +162,7 @@ function node_fusion!(
n2::DataTaskNode,
n3::ComputeTaskNode,
)
# @assert is_valid_node_fusion_input(graph, n1, n2, n3)
@assert is_valid_node_fusion_input(graph, n1, n2, n3)
# clear snapshot
get_snapshot_diff(graph)
@ -241,7 +241,7 @@ Reduce the given nodes together into one node, return the applied difference to
For details see [`NodeReduction`](@ref).
"""
function node_reduction!(graph::DAG, nodes::Vector{Node})
# @assert is_valid_node_reduction_input(graph, nodes)
@assert is_valid_node_reduction_input(graph, nodes)
# clear snapshot
get_snapshot_diff(graph)
@ -301,7 +301,7 @@ Split the given node into one node per parent, return the applied difference to
For details see [`NodeSplit`](@ref).
"""
function node_split!(graph::DAG, n1::Node)
# @assert is_valid_node_split_input(graph, n1)
@assert is_valid_node_split_input(graph, n1)
# clear snapshot
get_snapshot_diff(graph)

View File

@ -7,7 +7,7 @@ using QEDbase
include("../examples/profiling_utilities.jl")
@testset "Unit Tests Execution" begin
particles = Tuple{Vector{Particle}, Vector{Particle}}((
particles_2_2 = Tuple{Vector{Particle}, Vector{Particle}}((
[
Particle(SFourMomentum(0.823648, 0.0, 0.0, 0.823648), A),
Particle(SFourMomentum(0.823648, 0.0, 0.0, -0.823648), B),
@ -27,14 +27,14 @@ include("../examples/profiling_utilities.jl")
for _ in 1:10 # test in a loop because graph layout should not change the result
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
@test isapprox(
execute(graph, particles),
execute(graph, particles_2_2),
expected_result;
rtol = 0.001,
)
code = MetagraphOptimization.gen_code(graph)
@test isapprox(
execute(code, particles),
execute(code, particles_2_2),
expected_result;
rtol = 0.001,
)
@ -45,13 +45,51 @@ include("../examples/profiling_utilities.jl")
for i in 1:100
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
random_walk!(graph, 50)
@test is_valid(graph)
@test isapprox(
execute(graph, particles),
execute(graph, particles_2_2),
expected_result;
rtol = 0.001,
)
end
end
particles_2_4 = gen_particles([A, B], [A, B, B, B])
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"))
expected_result = execute(graph, particles_2_4)
@testset "AB->ABBB no optimization" begin
for _ in 1:5 # test in a loop because graph layout should not change the result
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"))
@test isapprox(
execute(graph, particles_2_4),
expected_result;
rtol = 0.001,
)
code = MetagraphOptimization.gen_code(graph)
@test isapprox(
execute(code, particles_2_4),
expected_result;
rtol = 0.001,
)
end
end
@testset "AB->ABBB after random walk" begin
for i in 1:10
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"))
random_walk!(graph, 20)
@test is_valid(graph)
@test isapprox(
execute(graph, particles_2_4),
expected_result;
rtol = 0.001,
)
end
end
end
println("Execution Unit Tests Complete!")