Start adding code generation

This commit is contained in:
Anton Reinhard
2023-08-31 18:24:48 +02:00
parent 32fcd069d7
commit f1edce258a
6 changed files with 251 additions and 10 deletions

106
src/models/abc/compute.jl Normal file
View File

@@ -0,0 +1,106 @@
# Compute Particle, nothing to be done (0 FLOP)
function compute(::ComputeTaskP, data::ParticleValue)
return data
end
# generate code evaluating ComputeTaskP on inSymbol, providing the output on outSymbol
function get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskP(), $inSymbol)")
end
# Compute outer edge
function compute(::ComputeTaskU, data::ParticleValue)
return ParticleValue(data.p, data.v * outer_edge(data.p))
end
# generate code evaluating ComputeTaskU on inSymbol, providing the output on outSymbol
# inSymbol should be of type ParticleValue, outSymbol will be of type ParticleValue
function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskU(), $inSymbol)")
end
# compute vertex
function compute(::ComputeTaskV, data1, data2)
# calculate new particle from the two input particles
p3 = preserve_momentum(data1.p, data2.p)
dataOut = ParticleValue(p3, data1.v * vertex() * data2.v)
return dataOut
end
function get_expression(::ComputeTaskV, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)")
end
# compute final inner edge (no output particle)
function compute(::ComputeTaskS2, data1, data2)
return data1.v * inner_edge(data1.p) * data2.v
end
function get_expression(::ComputeTaskS2, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)")
end
# compute inner edge
function compute(::ComputeTaskS1, data)
return (particle = data.p, v = data.v * inner_edge(data.p))
end
function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskS1(), $inSymbol)")
end
function compute(::ComputeTaskSum, data::Vector{Float64})
return sum(data)
end
function get_expression(::ComputeTaskSum, inSymbols::Vector{Symbol}, outSymbol::Symbol)
return quote
$outSymbol = compute(ComputeTaskSum(), [$(inSymbols...)])
end
end
function get_expression(node::ComputeTaskNode)
t = typeof(node.task)
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
@assert length(node.children) == 1
symbolIn = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
return get_expression(t(), symbolIn, symbolOut)
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
@assert length(node.children) == 2
symbolIn1 = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
symbolIn2 = Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))")
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
elseif (t <: ComputeTaskSum) # vector input
@assert length(node.children) > 0
inSymbols = Vector{Symbol}()
for child in node.children
push!(inSymbols, Symbol("data_$(replace(string(child.id), "-"=>"_"))"))
end
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
return get_expression(t(), inSymbols, outSymbol)
elseif (t <: FusedComputeTask)
# uuuuuh
else
error("Unknown compute task")
end
end
function get_expression(node::DataTaskNode)
# TODO: do things to transport data from/to gpu, between numa nodes, etc.
@assert length(node.children) <= 1
inSymbol = nothing
if (length(node.children) == 1)
inSymbol = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
else
inSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))_in")
end
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
dataTransportExp = Meta.parse("$outSymbol = $inSymbol")
return dataTransportExp
end

View File

@@ -0,0 +1,49 @@
struct Particle
P0::Float64
P1::Float64
P2::Float64
P3::Float64
m::Float64
end
struct ParticleValue
p::Particle
v::Float64
end
function square(p::Particle)
return p.P0 * p.P0 - p.P1 * p.P1 - p.P2 * p.P2 - p.P3 * p.P3
end
function inner_edge(p::Particle)
return 1.0 / (square(p) - p.m * p.m)
end
function outer_edge(p::Particle)
return 1.0
end
function vertex()
i = 1.0
lambda = 1.0/137.0
return i * lambda
end
# calculate new particle from two given interacting ones
function preserve_momentum(p1::Particle, p2::Particle)
# TODO: is this correct?
p3 = Particle(
p1.P0 + p2.P0,
p1.P1 + p2.P1,
p1.P2 + p2.P2,
p1.P3 + p2.P3,
1.0
)
# m3 = sqrt(- PC * PC / c^2)
return p3
end