Add accurate arithmetic for summation, fix order of input particles
This commit is contained in:
@ -5,12 +5,12 @@ function gen_code(graph::DAG)
|
||||
sizehint!(code, length(graph.nodes))
|
||||
|
||||
nodeQueue = PriorityQueue{Node, Int}()
|
||||
inputSyms = Vector{Symbol}()
|
||||
inputSyms = Dict{String, Symbol}()
|
||||
|
||||
# use a priority equal to the number of unseen children -> 0 are nodes that can be added
|
||||
for node in get_entry_nodes(graph)
|
||||
enqueue!(nodeQueue, node => 0)
|
||||
push!(inputSyms, Symbol("data_$(to_var_name(node.id))_in"))
|
||||
push!(inputSyms, node.name => Symbol("data_$(to_var_name(node.id))_in"))
|
||||
end
|
||||
|
||||
node = nothing
|
||||
@ -39,16 +39,26 @@ function gen_code(graph::DAG)
|
||||
)
|
||||
end
|
||||
|
||||
function execute(generated_code, input::Vector{Particle})
|
||||
function execute(generated_code, input::Dict{ParticleType, Vector{Particle}})
|
||||
(code, inputSymbols, outputSymbol) = generated_code
|
||||
# `input` is keyed by particle type, so the total number of supplied particles
# (not the number of types) has to match the number of input symbols
@assert mapreduce(length, +, values(input); init = 0) == length(inputSymbols)

assignInputs = Vector{Expr}()
for (name, symbol) in inputSymbols
    # the name is expected to be the particle type letter followed by the
    # particle's index within that type, e.g. "A1"
    type = if startswith(name, "A")
        A
    elseif startswith(name, "B")
        B
    else
        C
    end
    index = parse(Int, name[2:end])
    particle = input[type][index]

    push!(
        assignInputs,
        Meta.parse(
            "$(symbol) = ParticleValue(Particle($(particle).P0, $(particle).P1, $(particle).P2, $(particle).P3, $(type)), 1.0)",
        ),
    )
end
|
||||
@ -61,17 +71,25 @@ function execute(generated_code, input::Vector{Particle})
|
||||
return result
|
||||
end
|
||||
|
||||
function execute(graph::DAG, input::Vector{Particle})
|
||||
function execute(graph::DAG, input::Dict{ParticleType, Vector{Particle}})
|
||||
(code, inputSymbols, outputSymbol) = gen_code(graph)
|
||||
|
||||
# `input` is keyed by particle type, so the total number of supplied particles
# (not the number of types) has to match the number of input symbols
@assert mapreduce(length, +, values(input); init = 0) == length(inputSymbols)

assignInputs = Vector{Expr}()
for (name, symbol) in inputSymbols
    # the name is expected to be the particle type letter followed by the
    # particle's index within that type, e.g. "A1"
    type = if startswith(name, "A")
        A
    elseif startswith(name, "B")
        B
    else
        C
    end
    index = parse(Int, name[2:end])
    particle = input[type][index]

    push!(
        assignInputs,
        Meta.parse(
            "$(symbol) = ParticleValue(Particle($(particle).P0, $(particle).P1, $(particle).P2, $(particle).P3, $(type)), 1.0)",
        ),
    )
end
|
||||
|
@ -144,6 +144,8 @@ function remove_edge!(
|
||||
# 1: mute
|
||||
pre_length1 = length(node1.parents)
|
||||
pre_length2 = length(node2.children)
|
||||
|
||||
#TODO: filter is very slow
|
||||
filter!(x -> x != node2, node1.parents)
|
||||
filter!(x -> x != node1, node2.children)
|
||||
|
||||
@ -201,6 +203,7 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# TODO: filter is very slow
|
||||
filter!(!=(operation), operation.input[1].nodeFusions)
|
||||
filter!(!=(operation), operation.input[3].nodeFusions)
|
||||
|
||||
|
@ -23,6 +23,7 @@ end
|
||||
Print the given graph to io. If there are too many nodes it will print only a summary of them.
|
||||
"""
|
||||
function show(io::IO, graph::DAG)
|
||||
apply_all!(graph)
|
||||
println(io, "Graph:")
|
||||
print(io, " Nodes: ")
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
using AccurateArithmetic
|
||||
|
||||
# Compute Particle, nothing to be done (0 FLOP)
|
||||
function compute(::ComputeTaskP, data::ParticleValue)
|
||||
@ -40,7 +41,8 @@ function get_expression(
|
||||
end
|
||||
|
||||
"""
    compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)

Compute the final inner edge (no output particle): return the product of both input
values and the inner-edge factor of `data1`'s particle.
"""
function compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)
    # data1 and data2 particles should be equal in a calculation with valid inputs,
    # so it doesn't matter which one is used for inner_edge()
    return data1.v * inner_edge(data1.p) * data2.v
end
|
||||
|
||||
@ -56,7 +58,7 @@ function get_expression(
|
||||
end
|
||||
|
||||
"""
    compute(::ComputeTaskS1, data::ParticleValue)

Compute an inner edge: return a `ParticleValue` with the same particle whose value is
multiplied by the particle's inner-edge factor.
"""
function compute(::ComputeTaskS1, data::ParticleValue)
    return ParticleValue(data.p, data.v * inner_edge(data.p))
end
|
||||
|
||||
@ -65,7 +67,8 @@ function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
|
||||
end
|
||||
|
||||
function compute(::ComputeTaskSum, data::Vector{Float64})
    # use an error correcting (compensated) sum since the vectors may get very large
    # and a plain sum would accumulate rounding errors
    return sum_kbn(data)
end
|
||||
|
||||
function get_expression(
|
||||
@ -78,29 +81,70 @@ function get_expression(
|
||||
end
|
||||
end
|
||||
|
||||
# Not implemented for fused tasks: this method should never be called and always
# fails its assertion.
compute(t::FusedComputeTask, data) =
    @assert false "This is not implemented and should never be called"
|
||||
|
||||
"""
    get_expression(t::FusedComputeTask, inSymbols::Vector{Symbol}, outSymbol::Symbol)

Build the expression block for a fused task. The first subtask writes its result to
the symbol `:intermediate`; the second subtask reads `:intermediate` as its first input
and writes to `outSymbol`.

Expects `inSymbols` to be ordered: the inputs of the first subtask come first, followed
by the remaining inputs of the second subtask.
"""
function get_expression(
    t::FusedComputeTask,
    inSymbols::Vector{Symbol},
    outSymbol::Symbol,
)
    (T1, T2) = get_types(t)
    c1 = children(T1())
    c2 = children(T2())

    # TODO need to figure out how to know which inputs belong to which subtask
    # since we order the vectors with the child nodes we can't just split
    expr1 = if c1 == 1
        get_expression(T1(), inSymbols[begin], :intermediate)
    elseif c1 == 2
        get_expression(T1(), inSymbols[begin], inSymbols[2], :intermediate)
    else
        get_expression(T1(), inSymbols[begin:c1], :intermediate)
    end

    expr2 = if c2 == 1
        get_expression(T2(), :intermediate, outSymbol)
    elseif c2 == 2
        get_expression(T2(), :intermediate, inSymbols[c1 + 1], outSymbol)
    else
        # the vector method needs a Vector{Symbol}; `*` is not defined for a Symbol
        # and a vector, so prepend the intermediate symbol explicitly
        get_expression(
            T2(),
            Symbol[:intermediate; inSymbols[(c1 + 1):end]],
            outSymbol,
        )
    end

    return Expr(:block, expr1, expr2)
end
|
||||
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
t = typeof(node.task)
|
||||
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum
|
||||
|
||||
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
|
||||
@assert length(node.children) == 1
|
||||
symbolIn = Symbol("data_$(to_var_name(node.children[1].id))")
|
||||
symbolOut = Symbol("data_$(to_var_name(node.id))")
|
||||
return get_expression(t(), symbolIn, symbolOut)
|
||||
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
|
||||
@assert length(node.children) == 2
|
||||
symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))")
|
||||
symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))")
|
||||
symbolOut = Symbol("data_$(to_var_name(node.id))")
|
||||
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
|
||||
elseif (t <: ComputeTaskSum) # vector input
|
||||
@assert length(node.children) > 0
|
||||
elseif (t <: ComputeTaskSum || t <: FusedComputeTask) # vector input
|
||||
inSymbols = Vector{Symbol}()
|
||||
for child in node.children
|
||||
push!(inSymbols, Symbol("data_$(to_var_name(child.id))"))
|
||||
end
|
||||
outSymbol = Symbol("data_$(to_var_name(node.id))")
|
||||
return get_expression(t(), inSymbols, outSymbol)
|
||||
elseif (t <: FusedComputeTask)
|
||||
# uuuuuh
|
||||
else
|
||||
error("Unknown compute task")
|
||||
end
|
||||
|
@ -4,14 +4,16 @@
|
||||
|
||||
Return a randomly generated particle.
|
||||
"""
|
||||
function Particle(rng, type::ParticleType)
    # spatial momentum components
    p1 = rand(rng, Float64)
    p2 = rand(rng, Float64)
    p3 = rand(rng, Float64)
    m = mass(type)

    # on-shell energy; it has to be stored in `P0`, the component that `square`
    # counts positively, so that square(p) equals m^2 up to rounding
    energy = sqrt(p1^2 + p2^2 + p3^2 + m^2)

    return Particle(energy, p1, p2, p3, type)
end
|
||||
|
||||
"""
|
||||
@ -19,12 +21,15 @@ end
|
||||
|
||||
Return a dictionary that maps each particle type in `ns` to a Vector of the requested number of randomly generated [`Particle`](@ref)s.
|
||||
"""
|
||||
function gen_particles(ns::Dict{ParticleType, Int})
    particles = Dict{ParticleType, Vector{Particle}}()

    # fixed seed: all particles are drawn from one generator seeded with 0
    rng = MersenneTwister(0)
    for (type, n) in ns
        # typed comprehension so that n == 0 still yields a Vector{Particle}
        particles[type] = Particle[Particle(rng, type) for _ in 1:n]
    end
    return particles
end
|
||||
|
@ -86,7 +86,12 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
end
|
||||
if occursin(regex_a, node)
|
||||
# add nodes and edges for the state reading to u(P(Particle))
|
||||
data_in = insert_node!(graph, make_node(DataTask(4)), false, false) # read particle data node
|
||||
data_in = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(4), string(node)),
|
||||
false,
|
||||
false,
|
||||
) # read particle data node
|
||||
compute_P =
|
||||
insert_node!(graph, make_node(ComputeTaskP()), false, false) # compute P node
|
||||
data_Pu = insert_node!(graph, make_node(DataTask(6)), false, false) # transfer data from P to u
|
||||
|
@ -1,3 +1,7 @@
|
||||
@enum ParticleType A = 1 B = 2 C = 3 ALL = 6
|
||||
|
||||
const PARTICLE_MASSES =
|
||||
Dict{ParticleType, Float64}(A => 1.0, B => 1.0, C => 0.0)
|
||||
|
||||
struct Particle
|
||||
P0::Float64
|
||||
@ -5,7 +9,7 @@ struct Particle
|
||||
P2::Float64
|
||||
P3::Float64
|
||||
|
||||
m::Float64
|
||||
type::ParticleType
|
||||
end
|
||||
|
||||
struct ParticleValue
|
||||
@ -13,12 +17,25 @@ struct ParticleValue
|
||||
v::Float64
|
||||
end
|
||||
|
||||
"""
    mass(t::ParticleType)

Return the mass stored in `PARTICLE_MASSES` for the particle type `t`.
"""
function mass(t::ParticleType)
    return PARTICLE_MASSES[t]
end
|
||||
|
||||
"""
    remaining_type(t1::ParticleType, t2::ParticleType)

Return the first of `A`, `B`, `C` (checked in that order) that equals neither `t1` nor
`t2`. Asserts that `t1` and `t2` differ.
"""
function remaining_type(t1::ParticleType, t2::ParticleType)
    @assert t1 != t2
    given = (t1, t2)
    A in given || return A
    B in given || return B
    return C
end
|
||||
|
||||
"""
    square(p::Particle)

Return `P0^2 - P1^2 - P2^2 - P3^2` of the particle `p`.
"""
function square(p::Particle)
    return abs2(p.P0) - abs2(p.P1) - abs2(p.P2) - abs2(p.P3)
end
|
||||
|
||||
"""
    inner_edge(p::Particle)

Return `1 / (square(p) - m^2)`, where `m` is the mass of the particle's type.
"""
function inner_edge(p::Particle)
    # the mass is looked up from the particle type instead of being stored per particle
    m = mass(p.type)
    return 1.0 / (square(p) - m * m)
end
|
||||
|
||||
function outer_edge(p::Particle)
|
||||
@ -33,17 +50,13 @@ end
|
||||
|
||||
"""
    preserve_momentum(p1::Particle, p2::Particle)

Calculate the new particle resulting from two interacting ones: its components are the
component-wise sums of `p1` and `p2`, and its type is the remaining type that is neither
`p1.type` nor `p2.type`.
"""
function preserve_momentum(p1::Particle, p2::Particle)
    return Particle(
        p1.P0 + p2.P0,
        p1.P1 + p2.P1,
        p1.P2 + p2.P2,
        p1.P3 + p2.P3,
        remaining_type(p1.type, p2.type),
    )
end
|
||||
|
@ -100,3 +100,17 @@ show(io::IO, t::ComputeTaskSum) = print("ComputeSum")
|
||||
Copy the data task and return it.
|
||||
"""
|
||||
copy(t::DataTask) = DataTask(t.data)
|
||||
|
||||
# number of children each task type expects
children(::Union{DataTask, ComputeTaskS1, ComputeTaskP, ComputeTaskU}) = 1
children(::Union{ComputeTaskS2, ComputeTaskV}) = 2

# TODO: this is kind of bad because it means we can't fuse with a sum task
children(::ComputeTaskSum) = -1 # wildcard for "n" children
|
||||
function children(t::FusedComputeTask)
    (first_type, second_type) = get_types(t)
    combined = children(first_type()) + children(second_type())
    # one of the inputs is the output of the first task and thus not a child of the node
    return combined - 1
end
|
||||
|
@ -1,5 +1,5 @@
|
||||
|
||||
DataTaskNode(t::AbstractDataTask) = DataTaskNode(
|
||||
DataTaskNode(t::AbstractDataTask, name = "") = DataTaskNode(
|
||||
t,
|
||||
Vector{Node}(),
|
||||
Vector{Node}(),
|
||||
@ -7,6 +7,7 @@ DataTaskNode(t::AbstractDataTask) = DataTaskNode(
|
||||
missing,
|
||||
missing,
|
||||
missing,
|
||||
name,
|
||||
)
|
||||
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
t,
|
||||
@ -36,6 +37,7 @@ copy(n::DataTaskNode) = DataTaskNode(
|
||||
copy(n.nodeReduction),
|
||||
copy(n.nodeSplit),
|
||||
copy(n.nodeFusion),
|
||||
copy(n.name),
|
||||
)
|
||||
|
||||
"""
|
||||
@ -52,8 +54,8 @@ end
|
||||
|
||||
Construct and return a new [`DataTaskNode`](@ref) with the given task.
|
||||
"""
|
||||
function make_node(t::AbstractDataTask, name::String = "")
    # `name` defaults to the empty string, so existing one-argument calls keep working
    return DataTaskNode(t, name)
end
|
||||
|
||||
"""
|
||||
|
@ -52,6 +52,9 @@ mutable struct DataTaskNode <: Node
|
||||
|
||||
# the node fusion involving this node, if it exists
|
||||
nodeFusion::Union{Operation, Missing}
|
||||
|
||||
# for input nodes we need a name for the node to distinguish between them
|
||||
name::String
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -160,7 +160,6 @@ function node_fusion!(
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
|
||||
# save children and parents
|
||||
n1_children = children(n1)
|
||||
n3_parents = parents(n3)
|
||||
@ -181,26 +180,18 @@ function node_fusion!(
|
||||
ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
|
||||
insert_node!(graph, new_node)
|
||||
|
||||
# use a set for combined children of n1 and n3 to not get duplicates
|
||||
n1and3_children = Set{Node}()
|
||||
|
||||
# remove edges from n1 children to n1
|
||||
for child in n1_children
|
||||
remove_edge!(graph, child, n1)
|
||||
push!(n1and3_children, child)
|
||||
end
|
||||
|
||||
# remove edges from n3 children to n3
|
||||
for child in n3_children
|
||||
remove_edge!(graph, child, n3)
|
||||
push!(n1and3_children, child)
|
||||
end
|
||||
|
||||
for child in n1and3_children
|
||||
insert_edge!(graph, child, new_node)
|
||||
end
|
||||
|
||||
# "repoint" parents of n3 from new node
|
||||
for child in n3_children
|
||||
remove_edge!(graph, child, n3)
|
||||
if !(child in n1_children)
|
||||
insert_edge!(graph, child, new_node)
|
||||
end
|
||||
end
|
||||
|
||||
for parent in n3_parents
|
||||
remove_edge!(graph, n3, parent)
|
||||
insert_edge!(graph, new_node, parent)
|
||||
|
@ -71,14 +71,8 @@ function find_reductions!(graph::DAG, node::Node)
|
||||
partners_ = partners(node)
|
||||
delete!(partners_, node)
|
||||
for partner in partners_
|
||||
if partner ∉ graph.nodes
|
||||
error("Partner is not part of the graph")
|
||||
end
|
||||
|
||||
@assert partner in graph.nodes
|
||||
if can_reduce(node, partner)
|
||||
if Set(node.children) != Set(partner.children)
|
||||
error("Not equal children")
|
||||
end
|
||||
if reductionVector === nothing
|
||||
# only when there's at least one reduction partner, insert the vector
|
||||
reductionVector = Vector{Node}()
|
||||
|
Reference in New Issue
Block a user