Add number of children information to sum tasks

This commit is contained in:
Anton Reinhard 2023-09-27 16:16:33 +02:00
parent 24ade323f0
commit 4b44eb5286
6 changed files with 18 additions and 8 deletions

View File

@ -67,6 +67,8 @@ function gen_input_assignment_code(
for type in types(particles[1][1])
in_out_count[type] = (0, 0)
end
# we assume that the particles with lower numbers are the inputs, and then the output particles follow
for p in particles[1]
(i, o) = in_out_count[p.type]
in_out_count[p.type] = (i + 1, o)

View File

@ -208,7 +208,7 @@ Generate and return code for a given [`ComputeTaskNode`](@ref).
"""
function get_expression(node::ComputeTaskNode)
t = typeof(node.task)
# @assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
@assert length(node.children) == children(node.task) "Node $(node) has inconsistent number of children"
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
symbolIn = "data_$(to_var_name(node.children[1].id))"

View File

@ -3,6 +3,8 @@ using Random
using Roots
using ForwardDiff
ComputeTaskSum() = ComputeTaskSum(0)
"""
gen_particles(in::Vector{ParticleType}, out::Vector{ParticleType})

View File

@ -65,7 +65,7 @@ function parse_abc(filename::String, verbose::Bool = false)
sum_node = insert_node!(
graph,
make_node(ComputeTaskSum()),
make_node(ComputeTaskSum(0)),
track = false,
invalidate_cache = false,
)
@ -376,6 +376,7 @@ function parse_abc(filename::String, verbose::Bool = false)
track = false,
invalidate_cache = false,
)
add_child!(sum_node.task)
elseif occursin(regex_plus, node)
if (verbose)
println("\rReading Nodes Complete ")
@ -399,6 +400,7 @@ function parse_abc(filename::String, verbose::Bool = false)
if (verbose)
println("Done")
end
# don't actually need to read the edges
return graph
end

View File

@ -147,12 +147,9 @@ children(::ComputeTaskV) = 2
"""
children(::ComputeTaskSum)
Return the number of children of a ComputeTaskSum, since this is variable and the task doesn't know
how many children it will sum over, return a wildcard -1.
TODO: this is kind of bad because it means we can't fuse with a sum task
Return the number of children of a ComputeTaskSum.
"""
children(::ComputeTaskSum) = -1
children(t::ComputeTaskSum) = t.children_number
"""
children(t::FusedComputeTask)
@ -162,3 +159,8 @@ Return the number of children of a FusedComputeTask.
function children(t::FusedComputeTask)
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
end
function add_child!(t::ComputeTaskSum)
t.children_number += 1
return nothing
end

View File

@ -47,7 +47,9 @@ struct ComputeTaskU <: AbstractComputeTask end
Task that sums all its inputs, n children.
"""
struct ComputeTaskSum <: AbstractComputeTask end
mutable struct ComputeTaskSum <: AbstractComputeTask
children_number::Int
end
"""
ABC_TASKS