2024-01-03 16:38:32 +01:00
using StaticArrays
2023-10-12 17:51:03 +02:00
"""
    compute(t::FusedComputeTask, data...)

Compute a [`FusedComputeTask`](@ref). This method must never be called: fused compute tasks
generate their function calls through their constituent tasks instead (see
[`get_function_call`](@ref)), so reaching this method indicates an internal error.

Always throws an `AssertionError`.
"""
function compute(t::FusedComputeTask, data...)
    # Explicit throw instead of `@assert false`: `@assert` may be disabled at some
    # optimization levels, in which case this method would silently return `nothing`.
    throw(
        AssertionError(
            "compute must not be called on a FusedComputeTask; fused tasks generate their function calls through their constituent tasks",
        ),
    )
end
"""
    get_function_call(n::Node)
    get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)

Return a vector of [`FunctionCall`](@ref)s that computes the given node, or the given task on
`device` with input symbols `inSymbols` and output symbol `outSymbol`.

For ordinary compute or data tasks the vector contains exactly one element. For a
[`FusedComputeTask`](@ref) it contains the calls of both constituent tasks, i.e. two or more.
"""
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
    # A fused task stores the symbols of its constituent tasks itself, so `inSymbols` is not
    # consulted here. The first task writes to `t.t1_output`, which is then appended as the
    # last input of the second task; only the second task writes to `outSymbol`.
    first_calls = get_function_call(t.first_func, device, t.t1_inputs, t.t1_output)
    second_inputs = [t.t2_inputs..., t.t1_output]
    second_calls = get_function_call(t.second_func, device, second_inputs, outSymbol)

    return [first_calls..., second_calls...]
end
2024-01-03 16:38:32 +01:00
function get_function_call(
    t::CompTask,
    device::AbstractDevice,
    inSymbols::AbstractVector,
    outSymbol::Symbol,
) where {CompTask <: AbstractComputeTask}
    # A plain compute task maps onto exactly one `FunctionCall` of `compute`. The task itself is
    # the single entry of the third (`SVector{1, Any}`) argument — presumably the constant
    # arguments forwarded to `compute`; confirm against the `FunctionCall` definition.
    task_args = SVector{1, Any}(t)
    call = FunctionCall(compute, inSymbols, task_args, outSymbol, device)
    return [call]
end
2023-10-12 17:51:03 +02:00
2024-01-03 16:38:32 +01:00
# Build the function call(s) for a scheduled compute node. The ids of the node's children become
# the input symbols and the node's own id becomes the output symbol.
function get_function_call(node::ComputeTaskNode)
    @assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
    @assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"

    n_inputs = length(node.children)
    child_symbols = Symbol.(to_var_name.(getfield.(children(node), :id)))

    # Few children: wrap the input symbols in a statically sized SVector. Many children: keep the
    # plain Vector (presumably to avoid specializing on large static lengths).
    input_symbols = if n_inputs <= 50
        SVector{n_inputs, Symbol}(child_symbols...)
    else
        child_symbols
    end

    return get_function_call(node.task, node.device, input_symbols, Symbol(to_var_name(node.id)))
end
2024-01-03 16:38:32 +01:00
# Build the single function call for a (non-entry) data task node: its only child's value is
# passed through `unpack_identity` into the symbol named after this node's id. The call is
# assigned to the child's device.
function get_function_call(node::DataTaskNode)
    @assert length(children(node)) == 1 "Trying to call get_function_call on a data task node that has $(length(children(node))) children instead of 1"

    # TODO: dispatch to device implementations generating the copy commands
    child = first(children(node))
    return [
        FunctionCall(
            unpack_identity,
            SVector{1, Symbol}(Symbol(to_var_name(child.id))),
            SVector{0, Any}(),
            Symbol(to_var_name(node.id)),
            child.device,
        ),
    ]
end
2024-01-03 16:38:32 +01:00
# Build the initial function call for an entry data task node (one without children). It copies
# the input symbol `<var_name>_in` through `unpack_identity` into the symbol named after this
# node's id, on the given `device`.
function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
    @assert isempty(children(node)) "Trying to call get_init_function_call on a data task node that is not an entry node."

    var_name = to_var_name(node.id)
    return FunctionCall(
        unpack_identity,
        SVector{1, Symbol}(Symbol("$(var_name)_in")),
        SVector{0, Any}(),
        Symbol(var_name),
        device,
    )
end