Optimizer interface and sample implementation (#19)
Reviewed-on: Rubydragon/MetagraphOptimization.jl#19 Co-authored-by: Anton Reinhard <anton.reinhard@proton.me> Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
@@ -41,16 +41,16 @@ end
|
||||
Generate and return code for a given [`ComputeTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
@assert length(node.children) <= children(node.task) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
|
||||
@assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
|
||||
@assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
|
||||
|
||||
inExprs = Vector()
|
||||
for id in getfield.(node.children, :id)
|
||||
for id in getfield.(children(node), :id)
|
||||
push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
|
||||
end
|
||||
outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
|
||||
|
||||
return get_expression(node.task, node.device, inExprs, outExpr)
|
||||
return get_expression(task(node), node.device, inExprs, outExpr)
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -59,11 +59,11 @@ end
|
||||
Generate and return code for a given [`DataTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::DataTaskNode)
|
||||
@assert length(node.children) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
|
||||
@assert length(children(node)) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
|
||||
|
||||
# TODO: dispatch to device implementations generating the copy commands
|
||||
|
||||
child = node.children[1]
|
||||
child = children(node)[1]
|
||||
inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
|
||||
outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
|
||||
dataTransportExp = Meta.parse("$outExpr = $inExpr")
|
||||
@@ -79,7 +79,7 @@ Generate and return code for the initial input reading expression for [`DataTask
|
||||
See also: [`get_entry_nodes`](@ref)
|
||||
"""
|
||||
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
@assert isempty(node.children) "Trying to call get_init_expression on a data task node that is not an entry node."
|
||||
@assert isempty(children(node)) "Trying to call get_init_expression on a data task node that is not an entry node."
|
||||
|
||||
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
|
||||
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
|
||||
|
@@ -17,15 +17,16 @@ copy(t::AbstractComputeTask) = typeof(t)()
|
||||
|
||||
Return a copy of th egiven [`FusedComputeTask`](@ref).
|
||||
"""
|
||||
function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
|
||||
return FusedComputeTask{T1, T2}(
|
||||
copy(t.first_task),
|
||||
copy(t.second_task),
|
||||
copy(t.t1_inputs),
|
||||
t.t1_output,
|
||||
copy(t.t2_inputs),
|
||||
)
|
||||
function copy(t::FusedComputeTask)
|
||||
return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
|
||||
end
|
||||
|
||||
FusedComputeTask{T1, T2}(t1_inputs::Vector{String}, t1_output::String, t2_inputs::Vector{String}) where {T1, T2} =
|
||||
FusedComputeTask{T1, T2}(T1(), T2(), t1_inputs, t1_output, t2_inputs)
|
||||
function FusedComputeTask(
|
||||
T1::Type{<:AbstractComputeTask},
|
||||
T2::Type{<:AbstractComputeTask},
|
||||
t1_inputs::Vector{String},
|
||||
t1_output::String,
|
||||
t2_inputs::Vector{String},
|
||||
)
|
||||
return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
|
||||
end
|
||||
|
@@ -30,7 +30,7 @@ compute(t::AbstractDataTask; data...) = data
|
||||
|
||||
Fallback implementation of the compute effort of a task, throwing an error.
|
||||
"""
|
||||
function compute_effort(t::AbstractTask)
|
||||
function compute_effort(t::AbstractTask)::Float64
|
||||
# default implementation using compute
|
||||
return error("Need to implement compute_effort()")
|
||||
end
|
||||
@@ -40,7 +40,7 @@ end
|
||||
|
||||
Fallback implementation of the data of a task, throwing an error.
|
||||
"""
|
||||
function data(t::AbstractTask)
|
||||
function data(t::AbstractTask)::Float64
|
||||
return error("Need to implement data()")
|
||||
end
|
||||
|
||||
@@ -49,28 +49,28 @@ end
|
||||
|
||||
Return the compute effort of a data task, always zero, regardless of the specific task.
|
||||
"""
|
||||
compute_effort(t::AbstractDataTask) = 0.0
|
||||
compute_effort(t::AbstractDataTask)::Float64 = 0.0
|
||||
|
||||
"""
|
||||
data(t::AbstractDataTask)
|
||||
|
||||
Return the data of a data task. Given by the task's `.data` field.
|
||||
"""
|
||||
data(t::AbstractDataTask) = getfield(t, :data)
|
||||
data(t::AbstractDataTask)::Float64 = getfield(t, :data)
|
||||
|
||||
"""
|
||||
data(t::AbstractComputeTask)
|
||||
|
||||
Return the data of a compute task, always zero, regardless of the specific task.
|
||||
"""
|
||||
data(t::AbstractComputeTask) = 0.0
|
||||
data(t::AbstractComputeTask)::Float64 = 0.0
|
||||
|
||||
"""
|
||||
compute_effort(t::FusedComputeTask)
|
||||
|
||||
Return the compute effort of a fused compute task.
|
||||
"""
|
||||
function compute_effort(t::FusedComputeTask)
|
||||
function compute_effort(t::FusedComputeTask)::Float64
|
||||
return compute_effort(t.first_task) + compute_effort(t.second_task)
|
||||
end
|
||||
|
||||
@@ -79,4 +79,4 @@ end
|
||||
|
||||
Return a tuple of a the fused compute task's components' types.
|
||||
"""
|
||||
get_types(::FusedComputeTask{T1, T2}) where {T1, T2} = (T1, T2)
|
||||
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))
|
||||
|
@@ -26,9 +26,9 @@ A fused compute task made up of the computation of first `T1` and then `T2`.
|
||||
|
||||
Also see: [`get_types`](@ref).
|
||||
"""
|
||||
struct FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
|
||||
first_task::T1
|
||||
second_task::T2
|
||||
struct FusedComputeTask <: AbstractComputeTask
|
||||
first_task::AbstractComputeTask
|
||||
second_task::AbstractComputeTask
|
||||
# the names of the inputs for T1
|
||||
t1_inputs::Vector{Symbol}
|
||||
# output name of T1
|
||||
|
Reference in New Issue
Block a user