Tape Machine (#30)
Adds a tape machine way of executing the code. The tape machine is a series of FunctionCall objects, which can either be called one by one, or be used to generate expressions to make up a function. Reviewed-on: Rubydragon/MetagraphOptimization.jl#30 Co-authored-by: Anton Reinhard <anton.reinhard@proton.me> Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
@ -1,89 +1,85 @@
|
||||
using StaticArrays
|
||||
|
||||
"""
|
||||
compute(t::FusedComputeTask, data)
|
||||
|
||||
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
|
||||
"""
|
||||
function compute(t::FusedComputeTask, data)
|
||||
@assert false "This is not implemented and should never be called"
|
||||
function compute(t::FusedComputeTask, data...)
|
||||
inter = compute(t.first_task)
|
||||
return compute(t.second_task, inter, data2...)
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector{String}, outExpr::String)
|
||||
get_function_call(n::Node)
|
||||
get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
|
||||
|
||||
Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing the output on `outExpr`.
|
||||
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
|
||||
For a node or a task together with necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task.
|
||||
|
||||
For ordinary compute or data tasks the vector will contain exactly one element, for a [`FusedComputeTask`](@ref) there can be any number of tasks greater 1.
|
||||
"""
|
||||
function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
|
||||
inExprs1 = Vector()
|
||||
for sym in t.t1_inputs
|
||||
push!(inExprs1, gen_access_expr(device, sym))
|
||||
end
|
||||
|
||||
outExpr1 = gen_access_expr(device, t.t1_output)
|
||||
|
||||
inExprs2 = Vector()
|
||||
for sym in t.t2_inputs
|
||||
push!(inExprs2, gen_access_expr(device, sym))
|
||||
end
|
||||
|
||||
expr1 = get_expression(t.first_task, device, inExprs1, outExpr1)
|
||||
expr2 = get_expression(t.second_task, device, [inExprs2..., outExpr1], outExpr)
|
||||
|
||||
full_expr = Expr(:block, expr1, expr2)
|
||||
|
||||
return full_expr
|
||||
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
|
||||
# sort out the symbols to the correct tasks
|
||||
return [
|
||||
get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
|
||||
get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
|
||||
]
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(node::ComputeTaskNode)
|
||||
function get_function_call(
|
||||
t::CompTask,
|
||||
device::AbstractDevice,
|
||||
inSymbols::AbstractVector,
|
||||
outSymbol::Symbol,
|
||||
) where {CompTask <: AbstractComputeTask}
|
||||
return [FunctionCall(compute, inSymbols, SVector{1, Any}(t), outSymbol, device)]
|
||||
end
|
||||
|
||||
Generate and return code for a given [`ComputeTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
function get_function_call(node::ComputeTaskNode)
|
||||
@assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
|
||||
@assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
|
||||
|
||||
inExprs = Vector()
|
||||
for id in getfield.(children(node), :id)
|
||||
push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
|
||||
if (length(node.children) <= 50)
|
||||
#only use an SVector when there are few children
|
||||
return get_function_call(
|
||||
node.task,
|
||||
node.device,
|
||||
SVector{length(node.children), Symbol}(Symbol.(to_var_name.(getfield.(children(node), :id)))...),
|
||||
Symbol(to_var_name(node.id)),
|
||||
)
|
||||
else
|
||||
return get_function_call(
|
||||
node.task,
|
||||
node.device,
|
||||
Symbol.(to_var_name.(getfield.(children(node), :id))),
|
||||
Symbol(to_var_name(node.id)),
|
||||
)
|
||||
end
|
||||
outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
|
||||
|
||||
return get_expression(task(node), node.device, inExprs, outExpr)
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(node::DataTaskNode)
|
||||
|
||||
Generate and return code for a given [`DataTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::DataTaskNode)
|
||||
function get_function_call(node::DataTaskNode)
|
||||
@assert length(children(node)) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
|
||||
|
||||
# TODO: dispatch to device implementations generating the copy commands
|
||||
|
||||
child = children(node)[1]
|
||||
inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
|
||||
outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
|
||||
dataTransportExp = Meta.parse("$outExpr = $inExpr")
|
||||
|
||||
return dataTransportExp
|
||||
return [
|
||||
FunctionCall(
|
||||
unpack_identity,
|
||||
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
|
||||
SVector{0, Any}(),
|
||||
Symbol(to_var_name(node.id)),
|
||||
first(children(node)).device,
|
||||
),
|
||||
]
|
||||
end
|
||||
|
||||
"""
|
||||
get_init_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
|
||||
Generate and return code for the initial input reading expression for [`DataTaskNode`](@ref)s with 0 children, i.e., entry nodes.
|
||||
|
||||
See also: [`get_entry_nodes`](@ref)
|
||||
"""
|
||||
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
|
||||
@assert isempty(children(node)) "Trying to call get_init_expression on a data task node that is not an entry node."
|
||||
|
||||
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
|
||||
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
|
||||
dataTransportExp = Meta.parse("$outExpr = $inExpr")
|
||||
|
||||
return dataTransportExp
|
||||
return FunctionCall(
|
||||
unpack_identity,
|
||||
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
|
||||
SVector{0, Any}(),
|
||||
Symbol(to_var_name(node.id)),
|
||||
device,
|
||||
)
|
||||
end
|
||||
|
Reference in New Issue
Block a user