Anton Reinhard b7560685d4 Optimizer interface and sample implementation (#19)
Reviewed-on: Rubydragon/MetagraphOptimization.jl#19
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
2023-11-22 13:51:54 +01:00

90 lines
3.4 KiB
Julia

"""
compute(t::FusedComputeTask, data)
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
"""
function compute(t::FusedComputeTask, data)
@assert false "This is not implemented and should never be called"
end
"""
get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector{String}, outExpr::String)
Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing the output on `outExpr`.
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
"""
function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
inExprs1 = Vector()
for sym in t.t1_inputs
push!(inExprs1, gen_access_expr(device, sym))
end
outExpr1 = gen_access_expr(device, t.t1_output)
inExprs2 = Vector()
for sym in t.t2_inputs
push!(inExprs2, gen_access_expr(device, sym))
end
expr1 = get_expression(t.first_task, device, inExprs1, outExpr1)
expr2 = get_expression(t.second_task, device, [inExprs2..., outExpr1], outExpr)
full_expr = Expr(:block, expr1, expr2)
return full_expr
end
"""
get_expression(node::ComputeTaskNode)
Generate and return code for a given [`ComputeTaskNode`](@ref).
"""
function get_expression(node::ComputeTaskNode)
@assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
@assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
inExprs = Vector()
for id in getfield.(children(node), :id)
push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
end
outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
return get_expression(task(node), node.device, inExprs, outExpr)
end
"""
get_expression(node::DataTaskNode)
Generate and return code for a given [`DataTaskNode`](@ref).
"""
function get_expression(node::DataTaskNode)
@assert length(children(node)) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
# TODO: dispatch to device implementations generating the copy commands
child = children(node)[1]
inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
dataTransportExp = Meta.parse("$outExpr = $inExpr")
return dataTransportExp
end
"""
get_init_expression(node::DataTaskNode, device::AbstractDevice)
Generate and return code for the initial input reading expression for [`DataTaskNode`](@ref)s with 0 children, i.e., entry nodes.
See also: [`get_entry_nodes`](@ref)
"""
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
@assert isempty(children(node)) "Trying to call get_init_expression on a data task node that is not an entry node."
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
dataTransportExp = Meta.parse("$outExpr = $inExpr")
return dataTransportExp
end