diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl index fadf456..9c09392 100644 --- a/src/MetagraphOptimization.jl +++ b/src/MetagraphOptimization.jl @@ -50,7 +50,7 @@ export ComputeTaskV export ComputeTaskU export ComputeTaskSum -export gen_code +export execute export ParticleValue export Particle @@ -108,8 +108,6 @@ include("operation/validate.jl") include("properties/create.jl") include("properties/utility.jl") -include("code_gen/main.jl") - include("task/create.jl") include("task/compare.jl") include("task/print.jl") @@ -121,4 +119,6 @@ include("models/abc/compute.jl") include("models/abc/properties.jl") include("models/abc/parse.jl") +include("code_gen/main.jl") + end # module MetagraphOptimization diff --git a/src/code_gen/main.jl b/src/code_gen/main.jl index 1180bc6..919b2a7 100644 --- a/src/code_gen/main.jl +++ b/src/code_gen/main.jl @@ -9,7 +9,10 @@ function gen_code(graph::DAG) for node in get_entry_nodes(graph) enqueue!(nodeQueue, node => 1) - push!(inputSyms, Symbol("data_$(replace(string(node.id), "-"=>"_"))_in")) + push!( + inputSyms, + Symbol("data_$(replace(string(node.id), "-"=>"_"))_in"), + ) end node = nothing @@ -27,6 +30,33 @@ function gen_code(graph::DAG) # node is now the last node we looked at -> the output node outSym = Symbol("data_$(replace(string(node.id), "-"=>"_"))") - - return (code = Expr(:block, code...), inputSymbols = inputSyms, outputSymbol = outSym) + + return ( + code = Expr(:block, code...), + inputSymbols = inputSyms, + outputSymbol = outSym, + ) +end + +function execute(graph::DAG, input::Vector{Particle}) + (code, inputSymbols, outputSymbol) = gen_code(graph) + + @assert length(input) == length(inputSymbols) + + assignInputs = Vector{Expr}() + for i in 1:length(input) + push!( + assignInputs, + Meta.parse( + "$(inputSymbols[i]) = ParticleValue(Particle($(input[i]).P0, $(input[i]).P1, $(input[i]).P2, $(input[i]).P3, $(input[i]).m), 1.0)", + ), + ) + end + + assignInputs = Expr(:block, assignInputs...) + eval(assignInputs) + eval(code) + + eval(Meta.parse("result = $outputSymbol")) + return result end diff --git a/src/models/abc/compute.jl b/src/models/abc/compute.jl index f931756..d3c7f84 100644 --- a/src/models/abc/compute.jl +++ b/src/models/abc/compute.jl @@ -28,8 +28,15 @@ function compute(::ComputeTaskV, data1, data2) return dataOut end -function get_expression(::ComputeTaskV, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol) - return Meta.parse("$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)") +function get_expression( + ::ComputeTaskV, + inSymbol1::Symbol, + inSymbol2::Symbol, + outSymbol::Symbol, +) + return Meta.parse( + "$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)", + ) end # compute final inner edge (no output particle) @@ -37,8 +44,15 @@ function compute(::ComputeTaskS2, data1, data2) return data1.v * inner_edge(data1.p) * data2.v end -function get_expression(::ComputeTaskS2, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol) - return Meta.parse("$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)") +function get_expression( + ::ComputeTaskS2, + inSymbol1::Symbol, + inSymbol2::Symbol, + outSymbol::Symbol, +) + return Meta.parse( + "$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)", + ) end # compute inner edge @@ -54,7 +68,11 @@ function compute(::ComputeTaskSum, data::Vector{Float64}) return sum(data) end -function get_expression(::ComputeTaskSum, inSymbols::Vector{Symbol}, outSymbol::Symbol) +function get_expression( + ::ComputeTaskSum, + inSymbols::Vector{Symbol}, + outSymbol::Symbol, +) return quote $outSymbol = compute(ComputeTaskSum(), [$(inSymbols...)]) end @@ -64,20 +82,26 @@ function get_expression(node::ComputeTaskNode) t = typeof(node.task) if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input @assert length(node.children) == 1 - symbolIn = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))") + symbolIn = + Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))") symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))") return get_expression(t(), symbolIn, symbolOut) elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input @assert length(node.children) == 2 - symbolIn1 = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))") - symbolIn2 = Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))") + symbolIn1 = + Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))") + symbolIn2 = + Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))") symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))") return get_expression(t(), symbolIn1, symbolIn2, symbolOut) elseif (t <: ComputeTaskSum) # vector input @assert length(node.children) > 0 inSymbols = Vector{Symbol}() for child in node.children - push!(inSymbols, Symbol("data_$(replace(string(child.id), "-"=>"_"))")) + push!( + inSymbols, + Symbol("data_$(replace(string(child.id), "-"=>"_"))"), + ) end outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))") return get_expression(t(), inSymbols, outSymbol) @@ -94,7 +118,8 @@ function get_expression(node::DataTaskNode) inSymbol = nothing if (length(node.children) == 1) - inSymbol = Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))") + inSymbol = + Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))") else inSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))_in") end diff --git a/src/models/abc/particle.jl b/src/models/abc/particle.jl index d515b06..e336de6 100644 --- a/src/models/abc/particle.jl +++ b/src/models/abc/particle.jl @@ -27,7 +27,7 @@ end function vertex() i = 1.0 - lambda = 1.0/137.0 + lambda = 1.0 / 137.0 return i * lambda end @@ -40,7 +40,7 @@ function preserve_momentum(p1::Particle, p2::Particle) p1.P1 + p2.P1, p1.P2 + p2.P2, p1.P3 + p2.P3, - 1.0 + 1.0, ) # m3 = sqrt(- PC * PC / c^2) diff --git a/src/task/properties.jl b/src/task/properties.jl index 11aa59b..4b1d889 100644 --- a/src/task/properties.jl +++ b/src/task/properties.jl @@ -96,7 +96,11 @@ end """ get_expression() """ -function get_expression(t::FusedComputeTask, inSymbol::Symbol, outSymbol::Symbol) +function get_expression( + t::FusedComputeTask, + inSymbol::Symbol, + outSymbol::Symbol, +) #TODO computeExp = quote $outSymbol = compute($t, $inSymbol)