Fix topoligical ordering on the graph
This commit is contained in:
parent
7a1a97dac8
commit
0f78053ccf
@ -9,6 +9,13 @@ Pages = ["models/abc/types.jl"]
|
||||
Order = [:type, :constant]
|
||||
```
|
||||
|
||||
### Particle
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/particle.jl"]
|
||||
Order = [:type, :constant]
|
||||
```
|
||||
|
||||
### Parse
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
@ -23,6 +30,20 @@ Pages = ["models/abc/properties.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
### Create
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/create.jl]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
### Compute
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/compute.jl]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## QED-Model
|
||||
|
||||
*To be added*
|
||||
|
@ -51,6 +51,7 @@ export ComputeTaskU
|
||||
export ComputeTaskSum
|
||||
|
||||
export execute
|
||||
export gen_particles
|
||||
export ParticleValue
|
||||
export Particle
|
||||
|
||||
@ -116,6 +117,7 @@ include("task/properties.jl")
|
||||
include("models/abc/types.jl")
|
||||
include("models/abc/particle.jl")
|
||||
include("models/abc/compute.jl")
|
||||
include("models/abc/create.jl")
|
||||
include("models/abc/properties.jl")
|
||||
include("models/abc/parse.jl")
|
||||
|
||||
|
@ -7,29 +7,30 @@ function gen_code(graph::DAG)
|
||||
nodeQueue = PriorityQueue{Node, Int}()
|
||||
inputSyms = Vector{Symbol}()
|
||||
|
||||
# use a priority equal to the number of unseen children -> 0 are nodes that can be added
|
||||
for node in get_entry_nodes(graph)
|
||||
enqueue!(nodeQueue, node => 1)
|
||||
push!(
|
||||
inputSyms,
|
||||
Symbol("data_$(replace(string(node.id), "-"=>"_"))_in"),
|
||||
)
|
||||
enqueue!(nodeQueue, node => 0)
|
||||
push!(inputSyms, Symbol("data_$(to_var_name(node.id))_in"))
|
||||
end
|
||||
|
||||
node = nothing
|
||||
while !isempty(nodeQueue)
|
||||
prio = peek(nodeQueue)[2]
|
||||
@assert peek(nodeQueue)[2] == 0
|
||||
node = dequeue!(nodeQueue)
|
||||
|
||||
push!(code, get_expression(node))
|
||||
for parent in node.parents
|
||||
# reduce the priority of all parents by one
|
||||
if (!haskey(nodeQueue, parent))
|
||||
enqueue!(nodeQueue, parent => prio + length(parent.children))
|
||||
enqueue!(nodeQueue, parent => length(parent.children) - 1)
|
||||
else
|
||||
nodeQueue[parent] = nodeQueue[parent] - 1
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# node is now the last node we looked at -> the output node
|
||||
outSym = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
outSym = Symbol("data_$(to_var_name(node.id))")
|
||||
|
||||
return (
|
||||
code = Expr(:block, code...),
|
||||
@ -38,6 +39,28 @@ function gen_code(graph::DAG)
|
||||
)
|
||||
end
|
||||
|
||||
function execute(generated_code, input::Vector{Particle})
|
||||
(code, inputSymbols, outputSymbol) = generated_code
|
||||
@assert length(input) == length(inputSymbols)
|
||||
|
||||
assignInputs = Vector{Expr}()
|
||||
for i in 1:length(input)
|
||||
push!(
|
||||
assignInputs,
|
||||
Meta.parse(
|
||||
"$(inputSymbols[i]) = ParticleValue(Particle($(input[i]).P0, $(input[i]).P1, $(input[i]).P2, $(input[i]).P3, $(input[i]).m), 1.0)",
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
assignInputs = Expr(:block, assignInputs...)
|
||||
eval(assignInputs)
|
||||
eval(code)
|
||||
|
||||
eval(Meta.parse("result = $outputSymbol"))
|
||||
return result
|
||||
end
|
||||
|
||||
function execute(graph::DAG, input::Vector{Particle})
|
||||
(code, inputSymbols, outputSymbol) = gen_code(graph)
|
||||
|
||||
|
@ -21,7 +21,7 @@ function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
|
||||
end
|
||||
|
||||
# compute vertex
|
||||
function compute(::ComputeTaskV, data1, data2)
|
||||
function compute(::ComputeTaskV, data1::ParticleValue, data2::ParticleValue)
|
||||
# calculate new particle from the two input particles
|
||||
p3 = preserve_momentum(data1.p, data2.p)
|
||||
dataOut = ParticleValue(p3, data1.v * vertex() * data2.v)
|
||||
@ -57,7 +57,7 @@ end
|
||||
|
||||
# compute inner edge
|
||||
function compute(::ComputeTaskS1, data)
|
||||
return (particle = data.p, v = data.v * inner_edge(data.p))
|
||||
return ParticleValue(data.p, data.v * inner_edge(data.p))
|
||||
end
|
||||
|
||||
function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
|
||||
@ -82,28 +82,22 @@ function get_expression(node::ComputeTaskNode)
|
||||
t = typeof(node.task)
|
||||
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
|
||||
@assert length(node.children) == 1
|
||||
symbolIn =
|
||||
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
symbolIn = Symbol("data_$(to_var_name(node.children[1].id))")
|
||||
symbolOut = Symbol("data_$(to_var_name(node.id))")
|
||||
return get_expression(t(), symbolIn, symbolOut)
|
||||
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
|
||||
@assert length(node.children) == 2
|
||||
symbolIn1 =
|
||||
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
symbolIn2 =
|
||||
Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))")
|
||||
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))")
|
||||
symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))")
|
||||
symbolOut = Symbol("data_$(to_var_name(node.id))")
|
||||
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
|
||||
elseif (t <: ComputeTaskSum) # vector input
|
||||
@assert length(node.children) > 0
|
||||
inSymbols = Vector{Symbol}()
|
||||
for child in node.children
|
||||
push!(
|
||||
inSymbols,
|
||||
Symbol("data_$(replace(string(child.id), "-"=>"_"))"),
|
||||
)
|
||||
push!(inSymbols, Symbol("data_$(to_var_name(child.id))"))
|
||||
end
|
||||
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
outSymbol = Symbol("data_$(to_var_name(node.id))")
|
||||
return get_expression(t(), inSymbols, outSymbol)
|
||||
elseif (t <: FusedComputeTask)
|
||||
# uuuuuh
|
||||
@ -118,12 +112,11 @@ function get_expression(node::DataTaskNode)
|
||||
|
||||
inSymbol = nothing
|
||||
if (length(node.children) == 1)
|
||||
inSymbol =
|
||||
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
|
||||
inSymbol = Symbol("data_$(to_var_name(node.children[1].id))")
|
||||
else
|
||||
inSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))_in")
|
||||
inSymbol = Symbol("data_$(to_var_name(node.id))_in")
|
||||
end
|
||||
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
|
||||
outSymbol = Symbol("data_$(to_var_name(node.id))")
|
||||
|
||||
dataTransportExp = Meta.parse("$outSymbol = $inSymbol")
|
||||
|
||||
|
30
src/models/abc/create.jl
Normal file
30
src/models/abc/create.jl
Normal file
@ -0,0 +1,30 @@
|
||||
|
||||
"""
|
||||
Particle(rng)
|
||||
|
||||
Return a randomly generated particle.
|
||||
"""
|
||||
function Particle(rng)
|
||||
return Particle(
|
||||
rand(rng, Float64),
|
||||
rand(rng, Float64),
|
||||
rand(rng, Float64),
|
||||
rand(rng, Float64),
|
||||
rand(rng, Float64),
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
gen_particles(n::Int)
|
||||
|
||||
Return a Vector of `n` randomly generated [`Particle`](@ref)s.
|
||||
"""
|
||||
function gen_particles(n::Int)
|
||||
particles = Vector{Particle}()
|
||||
sizehint!(particles, n)
|
||||
rng = MersenneTwister(0)
|
||||
for i in 1:n
|
||||
push!(particles, Particle(rng))
|
||||
end
|
||||
return particles
|
||||
end
|
@ -15,3 +15,12 @@ Print a short string representation of the edge to io.
|
||||
function show(io::IO, e::Edge)
|
||||
return print(io, "Edge(", e.edge[1], ", ", e.edge[2], ")")
|
||||
end
|
||||
|
||||
"""
|
||||
to_var_name(id::UUID)
|
||||
|
||||
Return the uuid as a string usable as a variable name in code generation.
|
||||
"""
|
||||
function to_var_name(id::UUID)
|
||||
return replace(string(id), "-" => "_")
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user