Fix topoligical ordering on the graph

This commit is contained in:
2023-09-05 12:14:41 +02:00
parent 7a1a97dac8
commit 0f78053ccf
6 changed files with 105 additions and 27 deletions

View File

@@ -21,7 +21,7 @@ function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
end
# compute vertex
function compute(::ComputeTaskV, data1, data2)
function compute(::ComputeTaskV, data1::ParticleValue, data2::ParticleValue)
# calculate new particle from the two input particles
p3 = preserve_momentum(data1.p, data2.p)
dataOut = ParticleValue(p3, data1.v * vertex() * data2.v)
@@ -57,7 +57,7 @@ end
# compute inner edge
function compute(::ComputeTaskS1, data)
return (particle = data.p, v = data.v * inner_edge(data.p))
return ParticleValue(data.p, data.v * inner_edge(data.p))
end
function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
@@ -82,28 +82,22 @@ function get_expression(node::ComputeTaskNode)
t = typeof(node.task)
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
@assert length(node.children) == 1
symbolIn =
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
symbolIn = Symbol("data_$(to_var_name(node.children[1].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn, symbolOut)
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
@assert length(node.children) == 2
symbolIn1 =
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
symbolIn2 =
Symbol("data_$(replace(string(node.children[2].id), "-"=>"_"))")
symbolOut = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))")
symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
elseif (t <: ComputeTaskSum) # vector input
@assert length(node.children) > 0
inSymbols = Vector{Symbol}()
for child in node.children
push!(
inSymbols,
Symbol("data_$(replace(string(child.id), "-"=>"_"))"),
)
push!(inSymbols, Symbol("data_$(to_var_name(child.id))"))
end
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
outSymbol = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), inSymbols, outSymbol)
elseif (t <: FusedComputeTask)
# uuuuuh
@@ -118,12 +112,11 @@ function get_expression(node::DataTaskNode)
inSymbol = nothing
if (length(node.children) == 1)
inSymbol =
Symbol("data_$(replace(string(node.children[1].id), "-"=>"_"))")
inSymbol = Symbol("data_$(to_var_name(node.children[1].id))")
else
inSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))_in")
inSymbol = Symbol("data_$(to_var_name(node.id))_in")
end
outSymbol = Symbol("data_$(replace(string(node.id), "-"=>"_"))")
outSymbol = Symbol("data_$(to_var_name(node.id))")
dataTransportExp = Meta.parse("$outSymbol = $inSymbol")

30
src/models/abc/create.jl Normal file
View File

@@ -0,0 +1,30 @@
"""
Particle(rng)
Return a randomly generated particle.
"""
function Particle(rng)
return Particle(
rand(rng, Float64),
rand(rng, Float64),
rand(rng, Float64),
rand(rng, Float64),
rand(rng, Float64),
)
end
"""
gen_particles(n::Int)
Return a Vector of `n` randomly generated [`Particle`](@ref)s.
"""
function gen_particles(n::Int)
particles = Vector{Particle}()
sizehint!(particles, n)
rng = MersenneTwister(0)
for i in 1:n
push!(particles, Particle(rng))
end
return particles
end