Add accurate arithmetic for summation, fix order of input particles

This commit is contained in:
Anton Reinhard
2023-09-07 15:15:21 +02:00
parent 0f78053ccf
commit d1666de432
14 changed files with 183 additions and 66 deletions

View File

@ -5,12 +5,12 @@ function gen_code(graph::DAG)
sizehint!(code, length(graph.nodes))
nodeQueue = PriorityQueue{Node, Int}()
inputSyms = Vector{Symbol}()
inputSyms = Dict{String, Symbol}()
# use a priority equal to the number of unseen children -> 0 are nodes that can be added
for node in get_entry_nodes(graph)
enqueue!(nodeQueue, node => 0)
push!(inputSyms, Symbol("data_$(to_var_name(node.id))_in"))
push!(inputSyms, node.name => Symbol("data_$(to_var_name(node.id))_in"))
end
node = nothing
@ -39,16 +39,26 @@ function gen_code(graph::DAG)
)
end
function execute(generated_code, input::Vector{Particle})
function execute(generated_code, input::Dict{ParticleType, Vector{Particle}})
(code, inputSymbols, outputSymbol) = generated_code
@assert length(input) == length(inputSymbols)
assignInputs = Vector{Expr}()
for i in 1:length(input)
for (name, symbol) in inputSymbols
# Pick the particle type from the leading letter of the input node's name;
# the rest of the name is parsed as the particle index right after this.
# Note the argument order: `startswith(string, prefix)`. With the arguments
# swapped, names such as "A1" or "B2" never match and everything becomes `C`.
type = if startswith(name, "A")
    A
elseif startswith(name, "B")
    B
else
    C
end
index = parse(Int, name[2:end])
push!(
assignInputs,
Meta.parse(
"$(inputSymbols[i]) = ParticleValue(Particle($(input[i]).P0, $(input[i]).P1, $(input[i]).P2, $(input[i]).P3, $(input[i]).m), 1.0)",
"$(symbol) = ParticleValue(Particle($(input[type][index]).P0, $(input[type][index]).P1, $(input[type][index]).P2, $(input[type][index]).P3, $(type)), 1.0)",
),
)
end
@ -61,17 +71,25 @@ function execute(generated_code, input::Vector{Particle})
return result
end
function execute(graph::DAG, input::Vector{Particle})
function execute(graph::DAG, input::Dict{ParticleType, Vector{Particle}})
(code, inputSymbols, outputSymbol) = gen_code(graph)
@assert length(input) == length(inputSymbols)
assignInputs = Vector{Expr}()
for i in 1:length(input)
for (name, symbol) in inputSymbols
type = nothing
if startswith(name, "A")
type = A
elseif startswith(name, "B")
type = B
else
type = C
end
index = parse(Int, name[2:end])
push!(
assignInputs,
Meta.parse(
"$(inputSymbols[i]) = ParticleValue(Particle($(input[i]).P0, $(input[i]).P1, $(input[i]).P2, $(input[i]).P3, $(input[i]).m), 1.0)",
"$(symbol) = ParticleValue(Particle($(input[type][index]).P0, $(input[type][index]).P1, $(input[type][index]).P2, $(input[type][index]).P3, $(type)), 1.0)",
),
)
end

View File

@ -144,6 +144,8 @@ function remove_edge!(
# 1: mute
pre_length1 = length(node1.parents)
pre_length2 = length(node2.children)
#TODO: filter is very slow
filter!(x -> x != node2, node1.parents)
filter!(x -> x != node1, node2.children)
@ -201,6 +203,7 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
# TODO: filter is very slow
filter!(!=(operation), operation.input[1].nodeFusions)
filter!(!=(operation), operation.input[3].nodeFusions)

View File

@ -23,6 +23,7 @@ end
Print the given graph to io. If there are too many nodes it will print only a summary of them.
"""
function show(io::IO, graph::DAG)
apply_all!(graph)
println(io, "Graph:")
print(io, " Nodes: ")

View File

@ -1,3 +1,4 @@
using AccurateArithmetic
# Compute Particle, nothing to be done (0 FLOP)
function compute(::ComputeTaskP, data::ParticleValue)
@ -40,7 +41,8 @@ function get_expression(
end
# compute final inner edge (no output particle)
function compute(::ComputeTaskS2, data1, data2)
function compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)
# data1 and data2 particles should be equal in a calculation with valid inputs, so it doesn't matter which one is used for inner_edge()
return data1.v * inner_edge(data1.p) * data2.v
end
@ -56,7 +58,7 @@ function get_expression(
end
# compute inner edge
function compute(::ComputeTaskS1, data)
function compute(::ComputeTaskS1, data::ParticleValue)
return ParticleValue(data.p, data.v * inner_edge(data.p))
end
@ -65,7 +67,8 @@ function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
end
function compute(::ComputeTaskSum, data::Vector{Float64})
return sum(data)
# use an error correcting sum since the vectors may get very large
return sum_kbn(data)
end
function get_expression(
@ -78,29 +81,70 @@ function get_expression(
end
end
# Direct evaluation of a fused task is not supported: this method
# unconditionally raises an `AssertionError` with the message below.
function compute(t::FusedComputeTask, data)
    throw(AssertionError("This is not implemented and should never be called"))
end
"""
    get_expression(t::FusedComputeTask, inSymbols::Vector{Symbol}, outSymbol::Symbol)

Build the expression for a fused task as a `:block` of two sub-expressions:
the first sub-task writes its result to the symbol `:intermediate`, the second
sub-task reads `:intermediate` as its first input and writes to `outSymbol`.

Expects `inSymbols` to be ordered. The inputs are split positionally: the first
`children(T1())` symbols are handed to the first sub-task, all remaining symbols
to the second sub-task.
"""
function get_expression(
    t::FusedComputeTask,
    inSymbols::Vector{Symbol},
    outSymbol::Symbol,
)
    (T1, T2) = get_types(t)
    c1 = children(T1())
    c2 = children(T2())

    # TODO need to figure out how to know which inputs belong to which subtask
    # since we order the vectors with the child nodes we can't just split
    # NOTE(review): for a wildcard child count (-1, see ComputeTaskSum) as first
    # task the slice `begin:c1` is empty -- confirm such fusions are never built.
    expr1 = if c1 == 1
        get_expression(T1(), inSymbols[begin], :intermediate)
    elseif c1 == 2
        get_expression(T1(), inSymbols[begin], inSymbols[2], :intermediate)
    else
        get_expression(T1(), inSymbols[begin:c1], :intermediate)
    end

    expr2 = if c2 == 1
        get_expression(T2(), :intermediate, outSymbol)
    elseif c2 == 2
        get_expression(T2(), :intermediate, inSymbols[c1 + 1], outSymbol)
    else
        # vector input: the intermediate result followed by the remaining inputs
        # (`Symbol * Vector{Symbol}` has no method, so concatenate explicitly)
        get_expression(
            T2(),
            vcat(:intermediate, inSymbols[(c1 + 1):end]),
            outSymbol,
        )
    end

    return Expr(:block, expr1, expr2)
end
function get_expression(node::ComputeTaskNode)
t = typeof(node.task)
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
@assert length(node.children) == 1
symbolIn = Symbol("data_$(to_var_name(node.children[1].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn, symbolOut)
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
@assert length(node.children) == 2
symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))")
symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
elseif (t <: ComputeTaskSum) # vector input
@assert length(node.children) > 0
elseif (t <: ComputeTaskSum || t <: FusedComputeTask) # vector input
inSymbols = Vector{Symbol}()
for child in node.children
push!(inSymbols, Symbol("data_$(to_var_name(child.id))"))
end
outSymbol = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), inSymbols, outSymbol)
elseif (t <: FusedComputeTask)
# uuuuuh
else
error("Unknown compute task")
end

View File

@ -4,14 +4,16 @@
Return a randomly generated particle.
"""
function Particle(rng)
return Particle(
rand(rng, Float64),
rand(rng, Float64),
rand(rng, Float64),
rand(rng, Float64),
rand(rng, Float64),
)
function Particle(rng, type::ParticleType)
p1 = rand(rng, Float64)
p2 = rand(rng, Float64)
p3 = rand(rng, Float64)
m = mass(type)
p4 = sqrt(p1^2 + p2^2 + p3^2 + m^2)
return Particle(p1, p2, p3, p4, type)
end
"""
@ -19,12 +21,15 @@ end
Return a Vector of `n` randomly generated [`Particle`](@ref)s.
"""
function gen_particles(n::Int)
particles = Vector{Particle}()
sizehint!(particles, n)
function gen_particles(ns::Dict{ParticleType, Int})
particles = Dict{ParticleType, Vector{Particle}}()
rng = MersenneTwister(0)
for i in 1:n
push!(particles, Particle(rng))
for (type, n) in ns
particles[type] = Vector{Particle}()
for i in 1:n
push!(particles[type], Particle(rng, type))
end
end
return particles
end

View File

@ -86,7 +86,12 @@ function parse_abc(filename::String, verbose::Bool = false)
end
if occursin(regex_a, node)
# add nodes and edges for the state reading to u(P(Particle))
data_in = insert_node!(graph, make_node(DataTask(4)), false, false) # read particle data node
data_in = insert_node!(
graph,
make_node(DataTask(4), string(node)),
false,
false,
) # read particle data node
compute_P =
insert_node!(graph, make_node(ComputeTaskP()), false, false) # compute P node
data_Pu = insert_node!(graph, make_node(DataTask(6)), false, false) # transfer data from P to u

View File

@ -1,3 +1,7 @@
# Particle species with their explicit integer codes.
@enum ParticleType begin
    A = 1
    B = 2
    C = 3
    ALL = 6
end

# Mass per particle type; note that there is no entry for `ALL`.
const PARTICLE_MASSES = Dict{ParticleType, Float64}(
    A => 1.0,
    B => 1.0,
    C => 0.0,
)
struct Particle
P0::Float64
@ -5,7 +9,7 @@ struct Particle
P2::Float64
P3::Float64
m::Float64
type::ParticleType
end
struct ParticleValue
@ -13,12 +17,25 @@ struct ParticleValue
v::Float64
end
"""
    mass(t::ParticleType)

Return the mass stored for particle type `t` in `PARTICLE_MASSES`.
"""
function mass(t::ParticleType)
    return PARTICLE_MASSES[t]
end
"""
    remaining_type(t1::ParticleType, t2::ParticleType)

Return the first of `A`, `B` (checked in that order) that equals neither `t1`
nor `t2`, falling back to `C`. The two given types must differ.
"""
function remaining_type(t1::ParticleType, t2::ParticleType)
    @assert t1 != t2
    given = (t1, t2)
    A in given || return A
    B in given || return B
    return C
end
"""
    square(p::Particle)

Return `P0^2 - P1^2 - P2^2 - P3^2` computed from the components of `p`.
"""
function square(p::Particle)
    # subtract left to right, exactly one component at a time
    result = p.P0 * p.P0
    result -= p.P1 * p.P1
    result -= p.P2 * p.P2
    result -= p.P3 * p.P3
    return result
end
function inner_edge(p::Particle)
return 1.0 / (square(p) - p.m * p.m)
return 1.0 / (square(p) - mass(p.type) * mass(p.type))
end
function outer_edge(p::Particle)
@ -33,17 +50,13 @@ end
# calculate new particle from two given interacting ones
function preserve_momentum(p1::Particle, p2::Particle)
# TODO: is this correct?
p3 = Particle(
p1.P0 + p2.P0,
p1.P1 + p2.P1,
p1.P2 + p2.P2,
p1.P3 + p2.P3,
1.0,
remaining_type(p1.type, p2.type),
)
# m3 = sqrt(- PC * PC / c^2)
return p3
end

View File

@ -100,3 +100,17 @@ show(io::IO, t::ComputeTaskSum) = print("ComputeSum")
Copy the data task and return it.
"""
copy(t::DataTask) = DataTask(t.data)
# Number of inputs (child nodes) each task type expects. The code generation
# checks a node's child count against this value and uses it to split the
# inputs of a fused task between its two sub-tasks.
children(::DataTask) = 1
children(::ComputeTaskS1) = 1
children(::ComputeTaskS2) = 2
children(::ComputeTaskP) = 1
children(::ComputeTaskU) = 1
children(::ComputeTaskV) = 2
# TODO: this is kind of bad because it means we can't fuse with a sum task
# A sum task accepts any number of inputs, signalled by the sentinel -1.
children(::ComputeTaskSum) = -1 # wildcard for "n" children
# Child count of a fused task: the inputs of both sub-tasks combined, minus
# the one input of the second sub-task that is fed by the first sub-task's
# output and therefore is not a child of the fused node.
function children(t::FusedComputeTask)
    (FirstTask, SecondTask) = get_types(t)
    first_inputs = children(FirstTask())
    second_inputs = children(SecondTask())
    return first_inputs + second_inputs - 1
end

View File

@ -1,5 +1,5 @@
DataTaskNode(t::AbstractDataTask) = DataTaskNode(
DataTaskNode(t::AbstractDataTask, name = "") = DataTaskNode(
t,
Vector{Node}(),
Vector{Node}(),
@ -7,6 +7,7 @@ DataTaskNode(t::AbstractDataTask) = DataTaskNode(
missing,
missing,
missing,
name,
)
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
t,
@ -36,6 +37,7 @@ copy(n::DataTaskNode) = DataTaskNode(
copy(n.nodeReduction),
copy(n.nodeSplit),
copy(n.nodeFusion),
copy(n.name),
)
"""
@ -52,8 +54,8 @@ end
Construct and return a new [`DataTaskNode`](@ref) with the given task.
"""
function make_node(t::AbstractDataTask)
return DataTaskNode(t)
function make_node(t::AbstractDataTask, name::String = "")
return DataTaskNode(t, name)
end
"""

View File

@ -52,6 +52,9 @@ mutable struct DataTaskNode <: Node
# the node fusion involving this node, if it exists
nodeFusion::Union{Operation, Missing}
# for input nodes we need a name for the node to distinguish between them
name::String
end
"""

View File

@ -160,7 +160,6 @@ function node_fusion!(
# clear snapshot
get_snapshot_diff(graph)
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
@ -181,26 +180,18 @@ function node_fusion!(
ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
insert_node!(graph, new_node)
# use a set for combined children of n1 and n3 to not get duplicates
n1and3_children = Set{Node}()
# remove edges from n1 children to n1
for child in n1_children
remove_edge!(graph, child, n1)
push!(n1and3_children, child)
end
# remove edges from n3 children to n3
for child in n3_children
remove_edge!(graph, child, n3)
push!(n1and3_children, child)
end
for child in n1and3_children
insert_edge!(graph, child, new_node)
end
# "repoint" parents of n3 from new node
for child in n3_children
remove_edge!(graph, child, n3)
if !(child in n1_children)
insert_edge!(graph, child, new_node)
end
end
for parent in n3_parents
remove_edge!(graph, n3, parent)
insert_edge!(graph, new_node, parent)

View File

@ -71,14 +71,8 @@ function find_reductions!(graph::DAG, node::Node)
partners_ = partners(node)
delete!(partners_, node)
for partner in partners_
if partner graph.nodes
error("Partner is not part of the graph")
end
@assert partner in graph.nodes
if can_reduce(node, partner)
if Set(node.children) != Set(partner.children)
error("Not equal children")
end
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()