Start adding code generation
This commit is contained in:
parent
32fcd069d7
commit
f1edce258a
@ -50,6 +50,10 @@ export ComputeTaskV
|
||||
export ComputeTaskU
|
||||
export ComputeTaskSum
|
||||
|
||||
export gen_code
|
||||
export ParticleValue
|
||||
export Particle
|
||||
|
||||
export ==, in, show, isempty, delete!, length
|
||||
|
||||
export bytes_to_human_readable
|
||||
@ -104,12 +108,16 @@ include("operation/validate.jl")
|
||||
include("properties/create.jl")
|
||||
include("properties/utility.jl")
|
||||
|
||||
include("code_gen/main.jl")
|
||||
|
||||
include("task/create.jl")
|
||||
include("task/compare.jl")
|
||||
include("task/print.jl")
|
||||
include("task/properties.jl")
|
||||
|
||||
include("models/abc/types.jl")
|
||||
include("models/abc/particle.jl")
|
||||
include("models/abc/compute.jl")
|
||||
include("models/abc/properties.jl")
|
||||
include("models/abc/parse.jl")
|
||||
|
||||
|
32
src/code_gen/main.jl
Normal file
32
src/code_gen/main.jl
Normal file
@ -0,0 +1,32 @@
|
||||
using DataStructures
|
||||
|
||||
function gen_code(graph::DAG)
|
||||
code = Vector{Expr}()
|
||||
sizehint!(code, length(graph.nodes))
|
||||
|
||||
nodeQueue = PriorityQueue{Node, Int}()
|
||||
inputSyms = Vector{Symbol}()
|
||||
|
||||
for node in get_entry_nodes(graph)
|
||||
enqueue!(nodeQueue, node => 1)
|
||||
push!(inputSyms, Symbol("data_$(replace(string(node.id), "-"=>"_"))_in"))
|
||||
end
|
||||
|
||||
node = nothing
|
||||
while !isempty(nodeQueue)
|
||||
prio = peek(nodeQueue)[2]
|
||||
node = dequeue!(nodeQueue)
|
||||
|
||||
push!(code, get_expression(node))
|
||||
for parent in node.parents
|
||||
if (!haskey(nodeQueue, parent))
|
||||
enqueue!(nodeQueue, parent => prio + length(parent.children))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# node is now the last node we looked at -> the output node
|
||||
outSym = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
|
||||
return (code = Expr(:block, code...), inputSymbols = inputSyms, outputSymbol = outSym)
|
||||
end
|
@ -27,3 +27,18 @@ function get_exit_node(graph::DAG)
|
||||
end
|
||||
@assert false "The given graph has no exit node! It is either empty or not acyclic!"
|
||||
end
|
||||
|
||||
"""
|
||||
get_entry_nodes(graph::DAG)
|
||||
|
||||
Return a vector of the graph's entry nodes.
|
||||
"""
|
||||
function get_entry_nodes(graph::DAG)
|
||||
result = Vector{Node}()
|
||||
for node in graph.nodes
|
||||
if (is_entry_node(node))
|
||||
push!(result, node)
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
106
src/models/abc/compute.jl
Normal file
106
src/models/abc/compute.jl
Normal file
@ -0,0 +1,106 @@
|
||||
|
||||
# Compute Particle, nothing to be done (0 FLOP)
|
||||
function compute(::ComputeTaskP, data::ParticleValue)
|
||||
return data
|
||||
end
|
||||
|
||||
# generate code evaluating ComputeTaskP on inSymbol, providing the output on outSymbol
|
||||
function get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskP(), $inSymbol)")
|
||||
end
|
||||
|
||||
# Compute outer edge
|
||||
function compute(::ComputeTaskU, data::ParticleValue)
|
||||
return ParticleValue(data.p, data.v * outer_edge(data.p))
|
||||
end
|
||||
|
||||
# generate code evaluating ComputeTaskU on inSymbol, providing the output on outSymbol
|
||||
# inSymbol should be of type ParticleValue, outSymbol will be of type ParticleValue
|
||||
function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskU(), $inSymbol)")
|
||||
end
|
||||
|
||||
# compute vertex
|
||||
function compute(::ComputeTaskV, data1, data2)
|
||||
# calculate new particle from the two input particles
|
||||
p3 = preserve_momentum(data1.p, data2.p)
|
||||
dataOut = ParticleValue(p3, data1.v * vertex() * data2.v)
|
||||
return dataOut
|
||||
end
|
||||
|
||||
function get_expression(::ComputeTaskV, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)")
|
||||
end
|
||||
|
||||
# compute final inner edge (no output particle)
|
||||
function compute(::ComputeTaskS2, data1, data2)
|
||||
return data1.v * inner_edge(data1.p) * data2.v
|
||||
end
|
||||
|
||||
function get_expression(::ComputeTaskS2, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)")
|
||||
end
|
||||
|
||||
# compute inner edge
|
||||
function compute(::ComputeTaskS1, data)
|
||||
return (particle = data.p, v = data.v * inner_edge(data.p))
|
||||
end
|
||||
|
||||
function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskS1(), $inSymbol)")
|
||||
end
|
||||
|
||||
function compute(::ComputeTaskSum, data::Vector{Float64})
|
||||
return sum(data)
|
||||
end
|
||||
|
||||
function get_expression(::ComputeTaskSum, inSymbols::Vector{Symbol}, outSymbol::Symbol)
|
||||
return quote
|
||||
$outSymbol = compute(ComputeTaskSum(), [$(inSymbols...)])
|
||||
end
|
||||
end
|
||||
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
t = typeof(node.task)
|
||||
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
|
||||
@assert length(node.children) == 1
|
||||
symbolIn = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
return get_expression(t(), symbolIn, symbolOut)
|
||||
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
|
||||
@assert length(node.children) == 2
|
||||
symbolIn1 = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
symbolIn2 = Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))")
|
||||
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
|
||||
elseif (t <: ComputeTaskSum) # vector input
|
||||
@assert length(node.children) > 0
|
||||
inSymbols = Vector{Symbol}()
|
||||
for child in node.children
|
||||
push!(inSymbols, Symbol("data_$(replace(string(child.id), "-"=>"_"))"))
|
||||
end
|
||||
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
return get_expression(t(), inSymbols, outSymbol)
|
||||
elseif (t <: FusedComputeTask)
|
||||
# uuuuuh
|
||||
else
|
||||
error("Unknown compute task")
|
||||
end
|
||||
end
|
||||
|
||||
function get_expression(node::DataTaskNode)
|
||||
# TODO: do things to transport data from/to gpu, between numa nodes, etc.
|
||||
@assert length(node.children) <= 1
|
||||
|
||||
inSymbol = nothing
|
||||
if (length(node.children) == 1)
|
||||
inSymbol = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
else
|
||||
inSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))_in")
|
||||
end
|
||||
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
|
||||
dataTransportExp = Meta.parse("$outSymbol = $inSymbol")
|
||||
|
||||
return dataTransportExp
|
||||
end
|
49
src/models/abc/particle.jl
Normal file
49
src/models/abc/particle.jl
Normal file
@ -0,0 +1,49 @@
|
||||
|
||||
struct Particle
|
||||
P0::Float64
|
||||
P1::Float64
|
||||
P2::Float64
|
||||
P3::Float64
|
||||
|
||||
m::Float64
|
||||
end
|
||||
|
||||
struct ParticleValue
|
||||
p::Particle
|
||||
v::Float64
|
||||
end
|
||||
|
||||
function square(p::Particle)
|
||||
return p.P0 * p.P0 - p.P1 * p.P1 - p.P2 * p.P2 - p.P3 * p.P3
|
||||
end
|
||||
|
||||
function inner_edge(p::Particle)
|
||||
return 1.0 / (square(p) - p.m * p.m)
|
||||
end
|
||||
|
||||
function outer_edge(p::Particle)
|
||||
return 1.0
|
||||
end
|
||||
|
||||
function vertex()
|
||||
i = 1.0
|
||||
lambda = 1.0/137.0
|
||||
return i * lambda
|
||||
end
|
||||
|
||||
# calculate new particle from two given interacting ones
|
||||
function preserve_momentum(p1::Particle, p2::Particle)
|
||||
# TODO: is this correct?
|
||||
|
||||
p3 = Particle(
|
||||
p1.P0 + p2.P0,
|
||||
p1.P1 + p2.P1,
|
||||
p1.P2 + p2.P2,
|
||||
p1.P3 + p2.P3,
|
||||
1.0
|
||||
)
|
||||
|
||||
# m3 = sqrt(- PC * PC / c^2)
|
||||
|
||||
return p3
|
||||
end
|
@ -7,6 +7,24 @@ function compute(t::AbstractTask; data...)
|
||||
return error("Need to implement compute()")
|
||||
end
|
||||
|
||||
"""
|
||||
compute(t::FusedComputeTask; data...)
|
||||
|
||||
Compute a fused compute task.
|
||||
"""
|
||||
function compute(t::FusedComputeTask; data...)
|
||||
(T1, T2) = collect(typeof(t).parameters)
|
||||
|
||||
return compute(T2(), compute(T1(), data))
|
||||
end
|
||||
|
||||
"""
|
||||
compute(t::AbstractDataTask; data...)
|
||||
|
||||
The compute function of a data task, always the identity function, regardless of the specific task.
|
||||
"""
|
||||
compute(t::AbstractDataTask; data...) = data
|
||||
|
||||
"""
|
||||
compute_effort(t::AbstractTask)
|
||||
|
||||
@ -33,13 +51,6 @@ Return the compute effort of a data task, always zero, regardless of the specifi
|
||||
"""
|
||||
compute_effort(t::AbstractDataTask) = 0
|
||||
|
||||
"""
|
||||
compute(t::AbstractDataTask; data...)
|
||||
|
||||
The compute function of a data task, always the identity function, regardless of the specific task.
|
||||
"""
|
||||
compute(t::AbstractDataTask; data...) = data
|
||||
|
||||
"""
|
||||
data(t::AbstractDataTask)
|
||||
|
||||
@ -64,12 +75,32 @@ function compute_effort(t::FusedComputeTask)
|
||||
return compute_effort(T1()) + compute_effort(T2())
|
||||
end
|
||||
|
||||
# actual compute functions for the tasks can stay undefined for now
|
||||
# compute(t::ComputeTaskU, data::Any) = mycomputation(data)
|
||||
|
||||
"""
|
||||
get_types(::FusedComputeTask{T1, T2})
|
||||
|
||||
Return a tuple of a the fused compute task's components' types.
|
||||
"""
|
||||
get_types(::FusedComputeTask{T1, T2}) where {T1, T2} = (T1, T2)
|
||||
|
||||
"""
|
||||
get_expression(t::AbstractTask)
|
||||
|
||||
Return an expression evaluating the given task on the :dataIn symbol
|
||||
"""
|
||||
function get_expression(t::AbstractTask)
|
||||
return quote
|
||||
dataOut = compute($t, dataIn)
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression()
|
||||
"""
|
||||
function get_expression(t::FusedComputeTask, inSymbol::Symbol, outSymbol::Symbol)
|
||||
#TODO
|
||||
computeExp = quote
|
||||
$outSymbol = compute($t, $inSymbol)
|
||||
end
|
||||
|
||||
return computeExp
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user