Fix topoligical ordering on the graph

This commit is contained in:
Anton Reinhard 2023-09-05 12:14:41 +02:00
parent 7a1a97dac8
commit 0f78053ccf
6 changed files with 105 additions and 27 deletions

View File

@ -9,6 +9,13 @@ Pages = ["models/abc/types.jl"]
Order = [:type, :constant]
```
### Particle
```@autodocs
Modules = [MetagraphOptimization]
Pages = ["models/abc/particle.jl"]
Order = [:type, :constant]
```
### Parse
```@autodocs
Modules = [MetagraphOptimization]
@ -23,6 +30,20 @@ Pages = ["models/abc/properties.jl"]
Order = [:function]
```
### Create
```@autodocs
Modules = [MetagraphOptimization]
Pages = ["models/abc/create.jl]
Order = [:function]
```
### Compute
```@autodocs
Modules = [MetagraphOptimization]
Pages = ["models/abc/compute.jl]
Order = [:function]
```
## QED-Model
*To be added*

View File

@ -51,6 +51,7 @@ export ComputeTaskU
export ComputeTaskSum
export execute
export gen_particles
export ParticleValue
export Particle
@ -116,6 +117,7 @@ include("task/properties.jl")
include("models/abc/types.jl")
include("models/abc/particle.jl")
include("models/abc/compute.jl")
include("models/abc/create.jl")
include("models/abc/properties.jl")
include("models/abc/parse.jl")

View File

@ -7,29 +7,30 @@ function gen_code(graph::DAG)
nodeQueue = PriorityQueue{Node, Int}()
inputSyms = Vector{Symbol}()
# use a priority equal to the number of unseen children -> 0 are nodes that can be added
for node in get_entry_nodes(graph)
enqueue!(nodeQueue, node => 1)
push!(
inputSyms,
Symbol("data_$(replace(string(node.id), "-"=>"_"))_in"),
)
enqueue!(nodeQueue, node => 0)
push!(inputSyms, Symbol("data_$(to_var_name(node.id))_in"))
end
node = nothing
while !isempty(nodeQueue)
prio = peek(nodeQueue)[2]
@assert peek(nodeQueue)[2] == 0
node = dequeue!(nodeQueue)
push!(code, get_expression(node))
for parent in node.parents
# reduce the priority of all parents by one
if (!haskey(nodeQueue, parent))
enqueue!(nodeQueue, parent => prio + length(parent.children))
enqueue!(nodeQueue, parent => length(parent.children) - 1)
else
nodeQueue[parent] = nodeQueue[parent] - 1
end
end
end
# node is now the last node we looked at -> the output node
outSym = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
outSym = Symbol("data_$(to_var_name(node.id))")
return (
code = Expr(:block, code...),
@ -38,6 +39,28 @@ function gen_code(graph::DAG)
)
end
function execute(generated_code, input::Vector{Particle})
(code, inputSymbols, outputSymbol) = generated_code
@assert length(input) == length(inputSymbols)
assignInputs = Vector{Expr}()
for i in 1:length(input)
push!(
assignInputs,
Meta.parse(
"$(inputSymbols[i]) = ParticleValue(Particle($(input[i]).P0, $(input[i]).P1, $(input[i]).P2, $(input[i]).P3, $(input[i]).m), 1.0)",
),
)
end
assignInputs = Expr(:block, assignInputs...)
eval(assignInputs)
eval(code)
eval(Meta.parse("result = $outputSymbol"))
return result
end
function execute(graph::DAG, input::Vector{Particle})
(code, inputSymbols, outputSymbol) = gen_code(graph)

View File

@ -21,7 +21,7 @@ function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
end
# compute vertex
function compute(::ComputeTaskV, data1, data2)
function compute(::ComputeTaskV, data1::ParticleValue, data2::ParticleValue)
# calculate new particle from the two input particles
p3 = preserve_momentum(data1.p, data2.p)
dataOut = ParticleValue(p3, data1.v * vertex() * data2.v)
@ -57,7 +57,7 @@ end
# compute inner edge
function compute(::ComputeTaskS1, data)
return (particle = data.p, v = data.v * inner_edge(data.p))
return ParticleValue(data.p, data.v * inner_edge(data.p))
end
function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
@ -82,28 +82,22 @@ function get_expression(node::ComputeTaskNode)
t = typeof(node.task)
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
@assert length(node.children) == 1
symbolIn =
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
symbolIn = Symbol("data_$(to_var_name(node.children[1].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn, symbolOut)
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
@assert length(node.children) == 2
symbolIn1 =
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
symbolIn2 =
Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))")
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))")
symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
elseif (t <: ComputeTaskSum) # vector input
@assert length(node.children) > 0
inSymbols = Vector{Symbol}()
for child in node.children
push!(
inSymbols,
Symbol("data_$(replace(string(child.id), "-"=>"_"))"),
)
push!(inSymbols, Symbol("data_$(to_var_name(child.id))"))
end
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
outSymbol = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), inSymbols, outSymbol)
elseif (t <: FusedComputeTask)
# uuuuuh
@ -118,12 +112,11 @@ function get_expression(node::DataTaskNode)
inSymbol = nothing
if (length(node.children) == 1)
inSymbol =
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
inSymbol = Symbol("data_$(to_var_name(node.children[1].id))")
else
inSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))_in")
inSymbol = Symbol("data_$(to_var_name(node.id))_in")
end
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
outSymbol = Symbol("data_$(to_var_name(node.id))")
dataTransportExp = Meta.parse("$outSymbol = $inSymbol")

30
src/models/abc/create.jl Normal file
View File

@ -0,0 +1,30 @@
"""
Particle(rng)
Return a randomly generated particle.
"""
function Particle(rng)
return Particle(
rand(rng, Float64),
rand(rng, Float64),
rand(rng, Float64),
rand(rng, Float64),
rand(rng, Float64),
)
end
"""
gen_particles(n::Int)
Return a Vector of `n` randomly generated [`Particle`](@ref)s.
"""
function gen_particles(n::Int)
particles = Vector{Particle}()
sizehint!(particles, n)
rng = MersenneTwister(0)
for i in 1:n
push!(particles, Particle(rng))
end
return particles
end

View File

@ -15,3 +15,12 @@ Print a short string representation of the edge to io.
function show(io::IO, e::Edge)
return print(io, "Edge(", e.edge[1], ", ", e.edge[2], ")")
end
"""
to_var_name(id::UUID)
Return the uuid as a string usable as a variable name in code generation.
"""
function to_var_name(id::UUID)
return replace(string(id), "-" => "_")
end