Use the scheduling information in the execution

This commit is contained in:
Anton Reinhard 2023-10-12 15:15:36 +02:00
parent 9b28601f18
commit 4dcb616606
5 changed files with 50 additions and 46 deletions

View File

@ -1,4 +1,3 @@
"""
gen_code(graph::DAG)
@ -17,7 +16,12 @@ function gen_code(graph::DAG, machine::Machine)
sizehint!(codeAcc, length(graph.nodes))
for node in sched
push!(codeAcc, get_expression(node, machine.devices[1]))
# TODO: this is kind of ugly, should init nodes be scheduled differently from the rest?
if (node isa DataTaskNode && length(node.children) == 0)
push!(codeAcc, get_init_expression(node, entry_device(machine)))
continue
end
push!(codeAcc, get_expression(node))
end
# get inSymbols
@ -72,9 +76,8 @@ function gen_input_assignment_code(
end
for symbol in symbols
# TODO generate correct access expression
# TODO how to define cahce strategies?
device = machine.devices[1]
# TODO: how to get the "default" cpu device?
device = entry_device(machine)
evalExpr = eval(gen_access_expr(device, symbol))
push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue($p, 1.0)"))
end
@ -94,12 +97,8 @@ function get_compute_function(graph::DAG, process::AbstractProcessDescription, m
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
# TODO generate correct access expression
# TODO how to define cache strategies?
device = machine.devices[1]
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(device, outputSymbol))
resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
)
@ -127,12 +126,9 @@ function execute(graph::DAG, process::AbstractProcessDescription, machine::Machi
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
# TODO generate correct access expression
# TODO how to define cache strategies?
device = machine.devices[1]
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(device, outputSymbol))
resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
)

View File

@ -9,6 +9,15 @@ function device_types()
return DEVICE_TYPES
end
"""
entry_device(machine::Machine)
Return the "entry" device, i.e., the device that starts CPU threads and GPU kernels, and takes input values and returns the output value.
"""
function entry_device(machine::Machine)
return machine.devices[1]
end
"""
strategies(t::Type{T}) where {T <: AbstractDevice}

View File

@ -1,4 +1,3 @@
"""
AbstractDevice

View File

@ -13,6 +13,6 @@ Interface functions that must be implemented for implementations of [`Scheduler`
The function assigns each [`ComputeTaskNode`](@ref) of the [`DAG`](@ref) to one of the devices in the given [`Machine`](@ref) and returns a `Vector{Node}` representing a topological ordering.
[`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`Device`](@ref) of their child to all [`Device`](@ref)s of its parents.
[`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`AbstractDevice`](@ref) of their child to all [`AbstractDevice`](@ref)s of its parents.
"""
function schedule_dag end

View File

@ -15,13 +15,6 @@ Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing th
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
"""
function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
c1 = length(t.t1_inputs)
c2 = length(t.t2_inputs) + 1
expr1 = nothing
expr2 = nothing
cacheStrategy = cache_strategy(device)
inExprs1 = Vector()
for sym in t.t1_inputs
push!(inExprs1, gen_access_expr(device, sym))
@ -43,46 +36,53 @@ function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Ve
end
"""
get_expression(node::ComputeTaskNode, device::AbstractDevice)
get_expression(node::ComputeTaskNode)
Generate and return code for a given [`ComputeTaskNode`](@ref).
"""
function get_expression(node::ComputeTaskNode, device::AbstractDevice)
t = typeof(node.task)
function get_expression(node::ComputeTaskNode)
@assert length(node.children) <= children(node.task) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
# TODO get device from the node
cacheStrategy = cache_strategy(device)
@assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
inExprs = Vector()
for id in getfield.(node.children, :id)
push!(inExprs, gen_access_expr(device, Symbol(to_var_name(id))))
push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
end
outExpr = gen_access_expr(device, Symbol(to_var_name(node.id)))
outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
return get_expression(node.task, device, inExprs, outExpr)
return get_expression(node.task, node.device, inExprs, outExpr)
end
"""
get_expression(node::DataTaskNode, device::AbstractDevice)
get_expression(node::DataTaskNode)
Generate and return code for a given [`DataTaskNode`](@ref).
"""
function get_expression(node::DataTaskNode, device::AbstractDevice)
@assert length(node.children) <= 1
function get_expression(node::DataTaskNode)
@assert length(node.children) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
# TODO: do things to transport data from/to gpu, between numa nodes, etc.
# TODO get device from the node
cacheStrategy = cache_strategy(device)
inExpr = nothing
if (length(node.children) == 1)
inExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.children[1].id))))
else
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
end
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
# TODO: dispatch to device implementations generating the copy commands
child = node.children[1]
inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
dataTransportExp = Meta.parse("$outExpr = $inExpr")
return dataTransportExp
end
"""
get_init_expression(node::DataTaskNode, device::AbstractDevice)
Generate and return code for the initial input reading expression for [`DataTaskNode`](@ref)s with 0 children, i.e., entry nodes.
See also: [`get_entry_nodes`](@ref)
"""
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
@assert isempty(node.children) "Trying to call get_init_expression on a data task node that is not an entry node."
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
dataTransportExp = Meta.parse("$outExpr = $inExpr")
return dataTransportExp