2024-02-28 13:52:46 +01:00

85 lines
3.2 KiB
Julia

using StaticArrays
"""
compute(t::FusedComputeTask, data)
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
"""
function compute(t::FusedComputeTask, data...)
@assert false
end
"""
get_function_call(n::Node)
get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
For a node or a task together with necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task.
For ordinary compute or data tasks the vector will contain exactly one element, for a [`FusedComputeTask`](@ref) there can be any number of tasks greater 1.
"""
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
# sort out the symbols to the correct tasks
return [
get_function_call(t.first_func, device, t.t1_inputs, t.t1_output)...,
get_function_call(t.second_func, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
]
end
# Function call for a single (non-fused) compute task: one `FunctionCall` to `compute`
# that receives the task itself as its only additional argument.
function get_function_call(
    t::CompTask,
    device::AbstractDevice,
    inSymbols::AbstractVector,
    outSymbol::Symbol,
) where {CompTask <: AbstractComputeTask}
    task_argument = SVector{1, Any}(t)
    call = FunctionCall(compute, inSymbols, task_argument, outSymbol, device)
    return [call]
end
# Function call(s) for a compute task node. The input symbols are derived from the ids of
# the node's children, the output symbol from the node's own id; the actual calls are
# produced by the method for the node's task.
function get_function_call(node::ComputeTaskNode)
    @assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
    @assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"

    child_symbols = Symbol.(to_var_name.(getfield.(children(node), :id)))
    out_symbol = Symbol(to_var_name(node.id))

    # only use an SVector when there are few children
    child_count = length(node.children)
    in_symbols = if child_count <= 50
        SVector{child_count, Symbol}(child_symbols...)
    else
        child_symbols
    end

    return get_function_call(node.task, node.device, in_symbols, out_symbol)
end
# Function call for a data task node: `unpack_identity` applied to the variable of the
# node's only child, stored in the variable of this node, using the child's device.
function get_function_call(node::DataTaskNode)
    @assert length(children(node)) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"

    # TODO: dispatch to device implementations generating the copy commands
    child = first(children(node))
    in_symbols = SVector{1, Symbol}(Symbol(to_var_name(child.id)))
    out_symbol = Symbol(to_var_name(node.id))
    call = FunctionCall(unpack_identity, in_symbols, SVector{0, Any}(), out_symbol, child.device)

    return [call]
end
# Initial function call for an entry data node (a data task node without children):
# `unpack_identity` reads the variable "<node variable name>_in" and stores the result in
# the node's own variable on the given device.
function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
    @assert isempty(children(node)) "Trying to call get_init_expression on a data task node that is not an entry node."

    in_symbols = SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in"))
    out_symbol = Symbol(to_var_name(node.id))

    return FunctionCall(unpack_identity, in_symbols, SVector{0, Any}(), out_symbol, device)
end