using StaticArrays

"""
    compute(t::FusedComputeTask, data...)

Compute a [`FusedComputeTask`](@ref). This method must never be called and always throws an
`AssertionError`. Fused compute tasks generate their function calls directly through their
two component tasks instead (see [`get_function_call`](@ref)).
"""
function compute(t::FusedComputeTask, data...)
    # An explicit throw instead of `@assert false`: `@assert` may be disabled at some
    # optimization levels, in which case this call would silently return `nothing`.
    throw(AssertionError("compute must not be called on a FusedComputeTask"))
end

"""
    get_function_call(n::Node)
    get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)

For a node, or for a task together with the information necessary to call it, return a
vector of [`FunctionCall`](@ref)s for the computation of the node or task.

For ordinary compute or data tasks the vector contains exactly one element; for a
[`FusedComputeTask`](@ref) it can contain any number of calls greater than 1.
"""
function get_function_call(
    t::FusedComputeTask,
    device::AbstractDevice,
    inSymbols::AbstractVector,
    outSymbol::Symbol,
)
    # `inSymbols` is intentionally unused: the fused task itself records which symbols belong
    # to each of its parts (`t1_inputs`, `t2_inputs`). The first part writes to `t1_output`,
    # which is appended to the second part's inputs; only the second part writes `outSymbol`.
    return [
        get_function_call(t.first_func, device, t.t1_inputs, t.t1_output)...,
        get_function_call(t.second_func, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
    ]
end

# Generic fallback for any non-fused compute task: a single call of `compute` with the task
# itself stored as the (one-element) argument vector.
function get_function_call(
    t::CompTask,
    device::AbstractDevice,
    inSymbols::AbstractVector,
    outSymbol::Symbol,
) where {CompTask <: AbstractComputeTask}
    return [FunctionCall(compute, inSymbols, SVector{1, Any}(t), outSymbol, device)]
end

# Build the function calls for a scheduled compute task node. The inputs are the variable
# names of the node's children (in child order), and the output is the node's own variable
# name. Dispatches to the task-level `get_function_call` with the node's task and device.
function get_function_call(node::ComputeTaskNode)
    # NOTE(review): presumably `children(task)` returns the number of inputs the task
    # expects (it is compared against a child count) — TODO confirm.
    @assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
    @assert !ismissing(node.device) "Trying to get function call for an unscheduled ComputeTaskNode\nNode: $(node)"

    in_symbols = Symbol.(to_var_name.(getfield.(children(node), :id)))
    out_symbol = Symbol(to_var_name(node.id))

    # Only use an SVector when there are few children. The vector length becomes a type
    # parameter, presumably limited to keep specialization cheap — TODO confirm the rationale.
    max_static_children = 50
    args = if length(in_symbols) <= max_static_children
        SVector{length(in_symbols), Symbol}(in_symbols...)
    else
        in_symbols
    end

    return get_function_call(node.task, node.device, args, out_symbol)
end

# Build the function call for a data task node with exactly one child: `unpack_identity`
# reads the child's variable and writes the node's variable on the child's device.
function get_function_call(node::DataTaskNode)
    @assert length(children(node)) == 1 "Trying to call get_function_call on a data task node that has $(length(node.children)) children instead of 1"

    child = first(children(node))

    # TODO: dispatch to device implementations generating the copy commands
    return [
        FunctionCall(
            unpack_identity,
            SVector{1, Symbol}(Symbol(to_var_name(child.id))),
            SVector{0, Any}(),
            Symbol(to_var_name(node.id)),
            child.device,
        ),
    ]
end

# Build the initializing function call for an entry data task node (one without children):
# `unpack_identity` reads the variable `<node var name>_in` and writes the node's own
# variable on the given `device`.
function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
    @assert isempty(children(node)) "Trying to call get_init_function_call on a data task node that is not an entry node."

    return FunctionCall(
        unpack_identity,
        SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
        SVector{0, Any}(),
        Symbol(to_var_name(node.id)),
        device,
    )
end