Add number of children information to sum tasks
This commit is contained in:
parent
24ade323f0
commit
4b44eb5286
@ -67,6 +67,8 @@ function gen_input_assignment_code(
|
||||
for type in types(particles[1][1])
|
||||
in_out_count[type] = (0, 0)
|
||||
end
|
||||
|
||||
# we assume that the particles with lower numbers are the inputs, and then the output particles follow
|
||||
for p in particles[1]
|
||||
(i, o) = in_out_count[p.type]
|
||||
in_out_count[p.type] = (i + 1, o)
|
||||
|
@ -208,7 +208,7 @@ Generate and return code for a given [`ComputeTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
t = typeof(node.task)
|
||||
# @assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
|
||||
@assert length(node.children) == children(node.task) "Node $(node) has inconsistent number of children"
|
||||
|
||||
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
|
||||
symbolIn = "data_$(to_var_name(node.children[1].id))"
|
||||
|
@ -3,6 +3,8 @@ using Random
|
||||
using Roots
|
||||
using ForwardDiff
|
||||
|
||||
ComputeTaskSum() = ComputeTaskSum(0)
|
||||
|
||||
"""
|
||||
gen_particles(in::Vector{ParticleType}, out::Vector{ParticleType})
|
||||
|
||||
|
@ -65,7 +65,7 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
|
||||
sum_node = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskSum()),
|
||||
make_node(ComputeTaskSum(0)),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
@ -376,6 +376,7 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
add_child!(sum_node.task)
|
||||
elseif occursin(regex_plus, node)
|
||||
if (verbose)
|
||||
println("\rReading Nodes Complete ")
|
||||
@ -399,6 +400,7 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
if (verbose)
|
||||
println("Done")
|
||||
end
|
||||
|
||||
# don't actually need to read the edges
|
||||
return graph
|
||||
end
|
||||
|
@ -147,12 +147,9 @@ children(::ComputeTaskV) = 2
|
||||
"""
|
||||
children(::ComputeTaskSum)
|
||||
|
||||
Return the number of children of a ComputeTaskSum, since this is variable and the task doesn't know
|
||||
how many children it will sum over, return a wildcard -1.
|
||||
|
||||
TODO: this is kind of bad because it means we can't fuse with a sum task
|
||||
Return the number of children of a ComputeTaskSum.
|
||||
"""
|
||||
children(::ComputeTaskSum) = -1
|
||||
children(t::ComputeTaskSum) = t.children_number
|
||||
|
||||
"""
|
||||
children(t::FusedComputeTask)
|
||||
@ -162,3 +159,8 @@ Return the number of children of a FusedComputeTask.
|
||||
function children(t::FusedComputeTask)
|
||||
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
|
||||
end
|
||||
|
||||
function add_child!(t::ComputeTaskSum)
|
||||
t.children_number += 1
|
||||
return nothing
|
||||
end
|
||||
|
@ -47,7 +47,9 @@ struct ComputeTaskU <: AbstractComputeTask end
|
||||
|
||||
Task that sums all its inputs, n children.
|
||||
"""
|
||||
struct ComputeTaskSum <: AbstractComputeTask end
|
||||
mutable struct ComputeTaskSum <: AbstractComputeTask
|
||||
children_number::Int
|
||||
end
|
||||
|
||||
"""
|
||||
ABC_TASKS
|
||||
|
Loading…
x
Reference in New Issue
Block a user