Add tests for AB->ABBB execution and fix errors

This commit is contained in:
Anton Reinhard
2023-09-26 18:30:37 +02:00
parent 95f92f080c
commit 24ade323f0
7 changed files with 62 additions and 12 deletions

View File

@@ -153,8 +153,8 @@ function execute(graph::DAG, input::Tuple{Vector{Particle}, Vector{Particle}})
eval(Meta.parse("result = $outputSymbol"))
catch e
println("Error while evaluating: $e")
println("Assign Input Code:\n$assignInputs\n")
println("Code:\n$code")
# println("Assign Input Code:\n$assignInputs\n")
# println("Code:\n$code")
@assert false
end

View File

@@ -34,6 +34,7 @@ end
Return a vector of the graph's entry nodes.
"""
function get_entry_nodes(graph::DAG)
apply_all!(graph)
result = Vector{Node}()
for node in graph.nodes
if (is_entry_node(node))

View File

@@ -45,6 +45,14 @@ For valid inputs, both input particles should have the same momenta at this poin
12 FLOP.
"""
function compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)
@assert isapprox(
abs(data1.p.momentum.E),
abs(data2.p.momentum.E),
rtol = 0.001,
) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
@assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)"
@assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)"
return data1.v * inner_edge(data1.p) * data2.v
end
@@ -200,7 +208,7 @@ Generate and return code for a given [`ComputeTaskNode`](@ref).
"""
function get_expression(node::ComputeTaskNode)
t = typeof(node.task)
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
# @assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
symbolIn = "data_$(to_var_name(node.children[1].id))"

View File

@@ -42,7 +42,10 @@ function gen_particles(
output_particles = Vector{Particle}()
final_momenta = generate_physical_massive_moms(rng, mass_sum, output_masses)
for (mom, type) in zip(final_momenta, out_particles)
push!(output_particles, Particle(mom, type))
push!(
output_particles,
Particle(SFourMomentum(-mom.E, mom.px, mom.py, mom.pz), type),
)
end
return (input_particles, output_particles)

View File

@@ -89,7 +89,7 @@ end
Return the factor of the inner edge with the given (virtual) particle.
Takes 10 effective FLOP. (3 here + 10 in square(p))
Takes 10 effective FLOP. (3 here + 7 in square(p))
"""
function inner_edge(p::Particle)
return 1.0 / (square(p) - mass(p.type) * mass(p.type))

View File

@@ -162,7 +162,7 @@ function node_fusion!(
n2::DataTaskNode,
n3::ComputeTaskNode,
)
# @assert is_valid_node_fusion_input(graph, n1, n2, n3)
@assert is_valid_node_fusion_input(graph, n1, n2, n3)
# clear snapshot
get_snapshot_diff(graph)
@@ -241,7 +241,7 @@ Reduce the given nodes together into one node, return the applied difference to
For details see [`NodeReduction`](@ref).
"""
function node_reduction!(graph::DAG, nodes::Vector{Node})
# @assert is_valid_node_reduction_input(graph, nodes)
@assert is_valid_node_reduction_input(graph, nodes)
# clear snapshot
get_snapshot_diff(graph)
@@ -301,7 +301,7 @@ Split the given node into one node per parent, return the applied difference to
For details see [`NodeSplit`](@ref).
"""
function node_split!(graph::DAG, n1::Node)
# @assert is_valid_node_split_input(graph, n1)
@assert is_valid_node_split_input(graph, n1)
# clear snapshot
get_snapshot_diff(graph)