From 0f78053ccfbbcd3b04e764738fb42f928d1fc8ed Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Tue, 5 Sep 2023 12:14:41 +0200 Subject: [PATCH] Fix topoligical ordering on the graph --- docs/src/lib/internals/models.md | 21 +++++++++++++++++ src/MetagraphOptimization.jl | 2 ++ src/code_gen/main.jl | 39 +++++++++++++++++++++++++------- src/models/abc/compute.jl | 31 ++++++++++--------------- src/models/abc/create.jl | 30 ++++++++++++++++++++++++ src/node/print.jl | 9 ++++++++ 6 files changed, 105 insertions(+), 27 deletions(-) create mode 100644 src/models/abc/create.jl diff --git a/docs/src/lib/internals/models.md b/docs/src/lib/internals/models.md index d28254d..1f7877f 100644 --- a/docs/src/lib/internals/models.md +++ b/docs/src/lib/internals/models.md @@ -9,6 +9,13 @@ Pages = ["models/abc/types.jl"] Order = [:type, :constant] ``` +### Particle +```@autodocs +Modules = [MetagraphOptimization] +Pages = ["models/abc/particle.jl"] +Order = [:type, :constant] +``` + ### Parse ```@autodocs Modules = [MetagraphOptimization] @@ -23,6 +30,20 @@ Pages = ["models/abc/properties.jl"] Order = [:function] ``` +### Create +```@autodocs +Modules = [MetagraphOptimization] +Pages = ["models/abc/create.jl] +Order = [:function] +``` + +### Compute +```@autodocs +Modules = [MetagraphOptimization] +Pages = ["models/abc/compute.jl] +Order = [:function] +``` + ## QED-Model *To be added* diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl index 9c09392..093536a 100644 --- a/src/MetagraphOptimization.jl +++ b/src/MetagraphOptimization.jl @@ -51,6 +51,7 @@ export ComputeTaskU export ComputeTaskSum export execute +export gen_particles export ParticleValue export Particle @@ -116,6 +117,7 @@ include("task/properties.jl") include("models/abc/types.jl") include("models/abc/particle.jl") include("models/abc/compute.jl") +include("models/abc/create.jl") include("models/abc/properties.jl") include("models/abc/parse.jl") diff --git a/src/code_gen/main.jl b/src/code_gen/main.jl index 919b2a7..e392510 100644 --- a/src/code_gen/main.jl +++ b/src/code_gen/main.jl @@ -7,29 +7,30 @@ function gen_code(graph::DAG) nodeQueue = PriorityQueue{Node, Int}() inputSyms = Vector{Symbol}() + # use a priority equal to the number of unseen children -> 0 are nodes that can be added for node in get_entry_nodes(graph) - enqueue!(nodeQueue, node => 1) - push!( - inputSyms, - Symbol("data_$(replace(string(node.id), "-"=>"_"))_in"), - ) + enqueue!(nodeQueue, node => 0) + push!(inputSyms, Symbol("data_$(to_var_name(node.id))_in")) end node = nothing while !isempty(nodeQueue) - prio = peek(nodeQueue)[2] + @assert peek(nodeQueue)[2] == 0 node = dequeue!(nodeQueue) push!(code, get_expression(node)) for parent in node.parents + # reduce the priority of all parents by one if (!haskey(nodeQueue, parent)) - enqueue!(nodeQueue, parent => prio + length(parent.children)) + enqueue!(nodeQueue, parent => length(parent.children) - 1) + else + nodeQueue[parent] = nodeQueue[parent] - 1 end end end # node is now the last node we looked at -> the output node - outSym = Symbol("data_$(replace(string(node.id), "-"=>"_"))") + outSym = Symbol("data_$(to_var_name(node.id))") return ( code = Expr(:block, code...), @@ -38,6 +39,28 @@ function gen_code(graph::DAG) ) end +function execute(generated_code, input::Vector{Particle}) + (code, inputSymbols, outputSymbol) = generated_code + @assert length(input) == length(inputSymbols) + + assignInputs = Vector{Expr}() + for i in 1:length(input) + push!( + assignInputs, + Meta.parse( + "$(inputSymbols[i]) = ParticleValue(Particle($(input[i]).P0, $(input[i]).P1, $(input[i]).P2, $(input[i]).P3, $(input[i]).m), 1.0)", + ), + ) + end + + assignInputs = Expr(:block, assignInputs...) + eval(assignInputs) + eval(code) + + eval(Meta.parse("result = $outputSymbol")) + return result +end + function execute(graph::DAG, input::Vector{Particle}) (code, inputSymbols, outputSymbol) = gen_code(graph) diff --git a/src/models/abc/compute.jl b/src/models/abc/compute.jl index d3c7f84..e0143dd 100644 --- a/src/models/abc/compute.jl +++ b/src/models/abc/compute.jl @@ -21,7 +21,7 @@ function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol) end # compute vertex -function compute(::ComputeTaskV, data1, data2) +function compute(::ComputeTaskV, data1::ParticleValue, data2::ParticleValue) # calculate new particle from the two input particles p3 = preserve_momentum(data1.p, data2.p) dataOut = ParticleValue(p3, data1.v * vertex() * data2.v) @@ -57,7 +57,7 @@ end # compute inner edge function compute(::ComputeTaskS1, data) - return (particle = data.p, v = data.v * inner_edge(data.p)) + return ParticleValue(data.p, data.v * inner_edge(data.p)) end function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol) @@ -82,28 +82,22 @@ function get_expression(node::ComputeTaskNode) t = typeof(node.task) if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input @assert length(node.children) == 1 - symbolIn = - Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))") - symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))") + symbolIn = Symbol("data_$(to_var_name(node.children[1].id))") + symbolOut = Symbol("data_$(to_var_name(node.id))") return get_expression(t(), symbolIn, symbolOut) elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input @assert length(node.children) == 2 - symbolIn1 = - Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))") - symbolIn2 = - Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))") - symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))") + symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))") + symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))") + symbolOut = Symbol("data_$(to_var_name(node.id))") return get_expression(t(), symbolIn1, symbolIn2, symbolOut) elseif (t <: ComputeTaskSum) # vector input @assert length(node.children) > 0 inSymbols = Vector{Symbol}() for child in node.children - push!( - inSymbols, - Symbol("data_$(replace(string(child.id), "-"=>"_"))"), - ) + push!(inSymbols, Symbol("data_$(to_var_name(child.id))")) end - outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))") + outSymbol = Symbol("data_$(to_var_name(node.id))") return get_expression(t(), inSymbols, outSymbol) elseif (t <: FusedComputeTask) # uuuuuh @@ -118,12 +112,11 @@ function get_expression(node::DataTaskNode) inSymbol = nothing if (length(node.children) == 1) - inSymbol = - Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))") + inSymbol = Symbol("data_$(to_var_name(node.children[1].id))") else - inSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))_in") + inSymbol = Symbol("data_$(to_var_name(node.id))_in") end - outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))") + outSymbol = Symbol("data_$(to_var_name(node.id))") dataTransportExp = Meta.parse("$outSymbol = $inSymbol") diff --git a/src/models/abc/create.jl b/src/models/abc/create.jl new file mode 100644 index 0000000..58b11bb --- /dev/null +++ b/src/models/abc/create.jl @@ -0,0 +1,30 @@ + +""" + Particle(rng) + +Return a randomly generated particle. +""" +function Particle(rng) + return Particle( + rand(rng, Float64), + rand(rng, Float64), + rand(rng, Float64), + rand(rng, Float64), + rand(rng, Float64), + ) +end + +""" + gen_particles(n::Int) + +Return a Vector of `n` randomly generated [`Particle`](@ref)s. +""" +function gen_particles(n::Int) + particles = Vector{Particle}() + sizehint!(particles, n) + rng = MersenneTwister(0) + for i in 1:n + push!(particles, Particle(rng)) + end + return particles +end diff --git a/src/node/print.jl b/src/node/print.jl index 3a6ee1a..c39c1b5 100644 --- a/src/node/print.jl +++ b/src/node/print.jl @@ -15,3 +15,12 @@ Print a short string representation of the edge to io. function show(io::IO, e::Edge) return print(io, "Edge(", e.edge[1], ", ", e.edge[2], ")") end + +""" + to_var_name(id::UUID) + +Return the uuid as a string usable as a variable name in code generation. +""" +function to_var_name(id::UUID) + return replace(string(id), "-" => "_") +end