Add tests for AB->ABBB execution and fix errors
This commit is contained in:
@@ -153,8 +153,8 @@ function execute(graph::DAG, input::Tuple{Vector{Particle}, Vector{Particle}})
|
||||
eval(Meta.parse("result = $outputSymbol"))
|
||||
catch e
|
||||
println("Error while evaluating: $e")
|
||||
println("Assign Input Code:\n$assignInputs\n")
|
||||
println("Code:\n$code")
|
||||
# println("Assign Input Code:\n$assignInputs\n")
|
||||
# println("Code:\n$code")
|
||||
@assert false
|
||||
end
|
||||
|
||||
|
@@ -34,6 +34,7 @@ end
|
||||
Return a vector of the graph's entry nodes.
|
||||
"""
|
||||
function get_entry_nodes(graph::DAG)
|
||||
apply_all!(graph)
|
||||
result = Vector{Node}()
|
||||
for node in graph.nodes
|
||||
if (is_entry_node(node))
|
||||
|
@@ -45,6 +45,14 @@ For valid inputs, both input particles should have the same momenta at this poin
|
||||
12 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)
|
||||
@assert isapprox(
|
||||
abs(data1.p.momentum.E),
|
||||
abs(data2.p.momentum.E),
|
||||
rtol = 0.001,
|
||||
) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
|
||||
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
|
||||
@assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)"
|
||||
@assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)"
|
||||
return data1.v * inner_edge(data1.p) * data2.v
|
||||
end
|
||||
|
||||
@@ -200,7 +208,7 @@ Generate and return code for a given [`ComputeTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
t = typeof(node.task)
|
||||
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
|
||||
# @assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
|
||||
|
||||
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
|
||||
symbolIn = "data_$(to_var_name(node.children[1].id))"
|
||||
|
@@ -42,7 +42,10 @@ function gen_particles(
|
||||
output_particles = Vector{Particle}()
|
||||
final_momenta = generate_physical_massive_moms(rng, mass_sum, output_masses)
|
||||
for (mom, type) in zip(final_momenta, out_particles)
|
||||
push!(output_particles, Particle(mom, type))
|
||||
push!(
|
||||
output_particles,
|
||||
Particle(SFourMomentum(-mom.E, mom.px, mom.py, mom.pz), type),
|
||||
)
|
||||
end
|
||||
|
||||
return (input_particles, output_particles)
|
||||
|
@@ -89,7 +89,7 @@ end
|
||||
|
||||
Return the factor of the inner edge with the given (virtual) particle.
|
||||
|
||||
Takes 10 effective FLOP. (3 here + 10 in square(p))
|
||||
Takes 10 effective FLOP. (3 here + 7 in square(p))
|
||||
"""
|
||||
function inner_edge(p::Particle)
|
||||
return 1.0 / (square(p) - mass(p.type) * mass(p.type))
|
||||
|
@@ -162,7 +162,7 @@ function node_fusion!(
|
||||
n2::DataTaskNode,
|
||||
n3::ComputeTaskNode,
|
||||
)
|
||||
# @assert is_valid_node_fusion_input(graph, n1, n2, n3)
|
||||
@assert is_valid_node_fusion_input(graph, n1, n2, n3)
|
||||
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
@@ -241,7 +241,7 @@ Reduce the given nodes together into one node, return the applied difference to
|
||||
For details see [`NodeReduction`](@ref).
|
||||
"""
|
||||
function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
# @assert is_valid_node_reduction_input(graph, nodes)
|
||||
@assert is_valid_node_reduction_input(graph, nodes)
|
||||
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
@@ -301,7 +301,7 @@ Split the given node into one node per parent, return the applied difference to
|
||||
For details see [`NodeSplit`](@ref).
|
||||
"""
|
||||
function node_split!(graph::DAG, n1::Node)
|
||||
# @assert is_valid_node_split_input(graph, n1)
|
||||
@assert is_valid_node_split_input(graph, n1)
|
||||
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
|
Reference in New Issue
Block a user