Reviewed-on: Rubydragon/MetagraphOptimization.jl#19 Co-authored-by: Anton Reinhard <anton.reinhard@proton.me> Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
90 lines
3.4 KiB
Julia
90 lines
3.4 KiB
Julia
|
|
"""
    compute(t::FusedComputeTask, data)

Placeholder for directly computing a [`FusedComputeTask`](@ref). It must never be reached.
Fused compute tasks emit their code through their two constituent tasks instead, so calling
this method fails an assertion.
"""
compute(t::FusedComputeTask, data) =
    @assert(false, "This is not implemented and should never be called")
|
|
|
|
"""
    get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)

Generate code evaluating a [`FusedComputeTask`](@ref), storing the final result in `outExpr`.

The fused task is expanded into the code of its two constituent tasks:

1. `t.first_task` reads the symbols in `t.t1_inputs` and writes the intermediate symbol
   `t.t1_output`.
2. `t.second_task` reads the symbols in `t.t2_inputs`, followed by that intermediate value as
   its last input, and writes `outExpr`.

All access expressions are generated with `gen_access_expr` for `device`.

Note: `inExprs` is accepted to match the call shape of the other `get_expression` methods, but
it is not used. The input symbols are read from `t.t1_inputs` and `t.t2_inputs` instead.

Returns a `:block` `Expr` containing the first task's expression followed by the second's.
"""
function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
    # `Any[...]` keeps the element type `Any`, as the former `Vector()` + `push!` loop did.
    inExprs1 = Any[gen_access_expr(device, sym) for sym in t.t1_inputs]
    outExpr1 = gen_access_expr(device, t.t1_output)
    inExprs2 = Any[gen_access_expr(device, sym) for sym in t.t2_inputs]

    expr1 = get_expression(t.first_task, device, inExprs1, outExpr1)
    # The first task's intermediate result becomes the last input of the second task.
    expr2 = get_expression(t.second_task, device, [inExprs2..., outExpr1], outExpr)

    return Expr(:block, expr1, expr2)
end
|
|
|
|
"""
    get_expression(node::ComputeTaskNode)

Build the code for a scheduled [`ComputeTaskNode`](@ref).

Each child of `node` becomes an input access expression on `node.device`, named via
`to_var_name` of the child's `id`. The node's own `id` names the output. The code itself is
produced by the `get_expression` method of the node's task.

Fails an assertion in two cases:
- the node has more children than its task accepts;
- `node.device` is `missing`, i.e. the node has not been scheduled.
"""
function get_expression(node::ComputeTaskNode)
    @assert(
        length(children(node)) <= children(task(node)),
        "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
    )
    @assert(
        !ismissing(node.device),
        "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
    )

    dev = node.device
    access(id) = gen_access_expr(dev, Symbol(to_var_name(id)))

    inputExprs = Any[access(childId) for childId in getfield.(children(node), :id)]
    outputExpr = access(node.id)

    return get_expression(task(node), dev, inputExprs, outputExpr)
end
|
|
|
|
"""
    get_expression(node::DataTaskNode)

Build the assignment that copies the value of a [`DataTaskNode`](@ref)'s single child into the
node's own variable.

Both access expressions are generated for the *child's* device:
- the source is named after the child's `id`;
- the destination is named after `node.id`.

Each expression returned by `gen_access_expr` is `eval`uated. The resulting values are spliced
into the string `"<out> = <in>"`, which is parsed back with `Meta.parse`.

Fails an assertion unless `node` has exactly one child.
"""
function get_expression(node::DataTaskNode)
    @assert(
        length(children(node)) == 1,
        "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
    )

    # TODO: dispatch to device implementations generating the copy commands
    src = children(node)[1]
    dev = src.device

    # NOTE(review): `eval` evaluates in the enclosing module's global scope on every call.
    # This presumably relies on `gen_access_expr` returning something evaluable there
    # (e.g. a quoted name) — confirm.
    source = eval(gen_access_expr(dev, Symbol(to_var_name(src.id))))
    target = eval(gen_access_expr(dev, Symbol(to_var_name(node.id))))

    return Meta.parse(string(target, " = ", source))
end
|
|
|
|
"""
    get_init_expression(node::DataTaskNode, device::AbstractDevice)

Build the input-reading assignment for an entry [`DataTaskNode`](@ref), i.e. one with no
children.

Let `<var>` be `to_var_name(node.id)`. The value accessed under the name `"<var>_in"` is
assigned to the name `"<var>"`. Both access expressions are generated for `device` and
`eval`uated, then spliced into the string `"<out> = <in>"`, which is parsed with `Meta.parse`.

Fails an assertion if `node` has any children.

See also: [`get_entry_nodes`](@ref)
"""
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
    @assert(
        isempty(children(node)),
        "Trying to call get_init_expression on a data task node that is not an entry node."
    )

    varName = to_var_name(node.id)
    # NOTE(review): same `eval` round-trip as the DataTaskNode `get_expression`; confirm that
    # `gen_access_expr` results are meant to be evaluated at module scope.
    source = eval(gen_access_expr(device, Symbol("$(varName)_in")))
    target = eval(gen_access_expr(device, Symbol(varName)))

    return Meta.parse(string(target, " = ", source))
end
|