Add basic execution function
This commit is contained in:
parent
f1edce258a
commit
7a1a97dac8
@ -50,7 +50,7 @@ export ComputeTaskV
|
||||
export ComputeTaskU
|
||||
export ComputeTaskSum
|
||||
|
||||
export gen_code
|
||||
export execute
|
||||
export ParticleValue
|
||||
export Particle
|
||||
|
||||
@ -108,8 +108,6 @@ include("operation/validate.jl")
|
||||
include("properties/create.jl")
|
||||
include("properties/utility.jl")
|
||||
|
||||
include("code_gen/main.jl")
|
||||
|
||||
include("task/create.jl")
|
||||
include("task/compare.jl")
|
||||
include("task/print.jl")
|
||||
@ -121,4 +119,6 @@ include("models/abc/compute.jl")
|
||||
include("models/abc/properties.jl")
|
||||
include("models/abc/parse.jl")
|
||||
|
||||
include("code_gen/main.jl")
|
||||
|
||||
end # module MetagraphOptimization
|
||||
|
@ -9,7 +9,10 @@ function gen_code(graph::DAG)
|
||||
|
||||
for node in get_entry_nodes(graph)
|
||||
enqueue!(nodeQueue, node => 1)
|
||||
push!(inputSyms, Symbol("data_$(replace(string(node.id), "-"=>"_"))_in"))
|
||||
push!(
|
||||
inputSyms,
|
||||
Symbol("data_$(replace(string(node.id), "-"=>"_"))_in"),
|
||||
)
|
||||
end
|
||||
|
||||
node = nothing
|
||||
@ -27,6 +30,33 @@ function gen_code(graph::DAG)
|
||||
|
||||
# node is now the last node we looked at -> the output node
|
||||
outSym = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
|
||||
return (code = Expr(:block, code...), inputSymbols = inputSyms, outputSymbol = outSym)
|
||||
|
||||
return (
|
||||
code = Expr(:block, code...),
|
||||
inputSymbols = inputSyms,
|
||||
outputSymbol = outSym,
|
||||
)
|
||||
end
|
||||
|
||||
function execute(graph::DAG, input::Vector{Particle})
|
||||
(code, inputSymbols, outputSymbol) = gen_code(graph)
|
||||
|
||||
@assert length(input) == length(inputSymbols)
|
||||
|
||||
assignInputs = Vector{Expr}()
|
||||
for i in 1:length(input)
|
||||
push!(
|
||||
assignInputs,
|
||||
Meta.parse(
|
||||
"$(inputSymbols[i]) = ParticleValue(Particle($(input[i]).P0, $(input[i]).P1, $(input[i]).P2, $(input[i]).P3, $(input[i]).m), 1.0)",
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
assignInputs = Expr(:block, assignInputs...)
|
||||
eval(assignInputs)
|
||||
eval(code)
|
||||
|
||||
eval(Meta.parse("result = $outputSymbol"))
|
||||
return result
|
||||
end
|
||||
|
@ -28,8 +28,15 @@ function compute(::ComputeTaskV, data1, data2)
|
||||
return dataOut
|
||||
end
|
||||
|
||||
function get_expression(::ComputeTaskV, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)")
|
||||
function get_expression(
|
||||
::ComputeTaskV,
|
||||
inSymbol1::Symbol,
|
||||
inSymbol2::Symbol,
|
||||
outSymbol::Symbol,
|
||||
)
|
||||
return Meta.parse(
|
||||
"$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)",
|
||||
)
|
||||
end
|
||||
|
||||
# compute final inner edge (no output particle)
|
||||
@ -37,8 +44,15 @@ function compute(::ComputeTaskS2, data1, data2)
|
||||
return data1.v * inner_edge(data1.p) * data2.v
|
||||
end
|
||||
|
||||
function get_expression(::ComputeTaskS2, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)")
|
||||
function get_expression(
|
||||
::ComputeTaskS2,
|
||||
inSymbol1::Symbol,
|
||||
inSymbol2::Symbol,
|
||||
outSymbol::Symbol,
|
||||
)
|
||||
return Meta.parse(
|
||||
"$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)",
|
||||
)
|
||||
end
|
||||
|
||||
# compute inner edge
|
||||
@ -54,7 +68,11 @@ function compute(::ComputeTaskSum, data::Vector{Float64})
|
||||
return sum(data)
|
||||
end
|
||||
|
||||
function get_expression(::ComputeTaskSum, inSymbols::Vector{Symbol}, outSymbol::Symbol)
|
||||
function get_expression(
|
||||
::ComputeTaskSum,
|
||||
inSymbols::Vector{Symbol},
|
||||
outSymbol::Symbol,
|
||||
)
|
||||
return quote
|
||||
$outSymbol = compute(ComputeTaskSum(), [$(inSymbols...)])
|
||||
end
|
||||
@ -64,20 +82,26 @@ function get_expression(node::ComputeTaskNode)
|
||||
t = typeof(node.task)
|
||||
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
|
||||
@assert length(node.children) == 1
|
||||
symbolIn = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
symbolIn =
|
||||
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
return get_expression(t(), symbolIn, symbolOut)
|
||||
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
|
||||
@assert length(node.children) == 2
|
||||
symbolIn1 = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
symbolIn2 = Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))")
|
||||
symbolIn1 =
|
||||
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
symbolIn2 =
|
||||
Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))")
|
||||
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
|
||||
elseif (t <: ComputeTaskSum) # vector input
|
||||
@assert length(node.children) > 0
|
||||
inSymbols = Vector{Symbol}()
|
||||
for child in node.children
|
||||
push!(inSymbols, Symbol("data_$(replace(string(child.id), "-"=>"_"))"))
|
||||
push!(
|
||||
inSymbols,
|
||||
Symbol("data_$(replace(string(child.id), "-"=>"_"))"),
|
||||
)
|
||||
end
|
||||
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
return get_expression(t(), inSymbols, outSymbol)
|
||||
@ -94,7 +118,8 @@ function get_expression(node::DataTaskNode)
|
||||
|
||||
inSymbol = nothing
|
||||
if (length(node.children) == 1)
|
||||
inSymbol = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
inSymbol =
|
||||
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
else
|
||||
inSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))_in")
|
||||
end
|
||||
|
@ -27,7 +27,7 @@ end
|
||||
|
||||
function vertex()
|
||||
i = 1.0
|
||||
lambda = 1.0/137.0
|
||||
lambda = 1.0 / 137.0
|
||||
return i * lambda
|
||||
end
|
||||
|
||||
@ -40,7 +40,7 @@ function preserve_momentum(p1::Particle, p2::Particle)
|
||||
p1.P1 + p2.P1,
|
||||
p1.P2 + p2.P2,
|
||||
p1.P3 + p2.P3,
|
||||
1.0
|
||||
1.0,
|
||||
)
|
||||
|
||||
# m3 = sqrt(- PC * PC / c^2)
|
||||
|
@ -96,7 +96,11 @@ end
|
||||
"""
|
||||
get_expression()
|
||||
"""
|
||||
function get_expression(t::FusedComputeTask, inSymbol::Symbol, outSymbol::Symbol)
|
||||
function get_expression(
|
||||
t::FusedComputeTask,
|
||||
inSymbol::Symbol,
|
||||
outSymbol::Symbol,
|
||||
)
|
||||
#TODO
|
||||
computeExp = quote
|
||||
$outSymbol = compute($t, $inSymbol)
|
||||
|
Loading…
x
Reference in New Issue
Block a user