Use the scheduling information in the execution
This commit is contained in:
parent
9b28601f18
commit
4dcb616606
@ -1,4 +1,3 @@
|
||||
|
||||
"""
|
||||
gen_code(graph::DAG)
|
||||
|
||||
@ -17,7 +16,12 @@ function gen_code(graph::DAG, machine::Machine)
|
||||
sizehint!(codeAcc, length(graph.nodes))
|
||||
|
||||
for node in sched
|
||||
push!(codeAcc, get_expression(node, machine.devices[1]))
|
||||
# TODO: this is kind of ugly, should init nodes be scheduled differently from the rest?
|
||||
if (node isa DataTaskNode && length(node.children) == 0)
|
||||
push!(codeAcc, get_init_expression(node, entry_device(machine)))
|
||||
continue
|
||||
end
|
||||
push!(codeAcc, get_expression(node))
|
||||
end
|
||||
|
||||
# get inSymbols
|
||||
@ -72,9 +76,8 @@ function gen_input_assignment_code(
|
||||
end
|
||||
|
||||
for symbol in symbols
|
||||
# TODO generate correct access expression
|
||||
# TODO how to define cahce strategies?
|
||||
device = machine.devices[1]
|
||||
# TODO: how to get the "default" cpu device?
|
||||
device = entry_device(machine)
|
||||
evalExpr = eval(gen_access_expr(device, symbol))
|
||||
push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue($p, 1.0)"))
|
||||
end
|
||||
@ -94,12 +97,8 @@ function get_compute_function(graph::DAG, process::AbstractProcessDescription, m
|
||||
initCaches = gen_cache_init_code(machine)
|
||||
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
|
||||
|
||||
# TODO generate correct access expression
|
||||
# TODO how to define cache strategies?
|
||||
device = machine.devices[1]
|
||||
|
||||
functionId = to_var_name(UUIDs.uuid1(rng[1]))
|
||||
resSym = eval(gen_access_expr(device, outputSymbol))
|
||||
resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
|
||||
expr = Meta.parse(
|
||||
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
|
||||
)
|
||||
@ -127,12 +126,9 @@ function execute(graph::DAG, process::AbstractProcessDescription, machine::Machi
|
||||
initCaches = gen_cache_init_code(machine)
|
||||
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
|
||||
|
||||
# TODO generate correct access expression
|
||||
# TODO how to define cache strategies?
|
||||
device = machine.devices[1]
|
||||
|
||||
functionId = to_var_name(UUIDs.uuid1(rng[1]))
|
||||
resSym = eval(gen_access_expr(device, outputSymbol))
|
||||
resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
|
||||
expr = Meta.parse(
|
||||
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
|
||||
)
|
||||
|
@ -9,6 +9,15 @@ function device_types()
|
||||
return DEVICE_TYPES
|
||||
end
|
||||
|
||||
"""
|
||||
entry_device(machine::Machine)
|
||||
|
||||
Return the "entry" device, i.e., the device that starts CPU threads and GPU kernels, and takes input values and returns the output value.
|
||||
"""
|
||||
function entry_device(machine::Machine)
|
||||
return machine.devices[1]
|
||||
end
|
||||
|
||||
"""
|
||||
strategies(t::Type{T}) where {T <: AbstractDevice}
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
"""
|
||||
AbstractDevice
|
||||
|
||||
|
@ -13,6 +13,6 @@ Interface functions that must be implemented for implementations of [`Scheduler`
|
||||
|
||||
The function assigns each [`ComputeTaskNode`](@ref) of the [`DAG`](@ref) to one of the devices in the given [`Machine`](@ref) and returns a `Vector{Node}` representing a topological ordering.
|
||||
|
||||
[`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`Device`](@ref) of their child to all [`Device`](@ref)s of its parents.
|
||||
[`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`AbstractDevice`](@ref) of their child to all [`AbstractDevice`](@ref)s of its parents.
|
||||
"""
|
||||
function schedule_dag end
|
||||
|
@ -15,13 +15,6 @@ Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing th
|
||||
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
|
||||
"""
|
||||
function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
|
||||
c1 = length(t.t1_inputs)
|
||||
c2 = length(t.t2_inputs) + 1
|
||||
expr1 = nothing
|
||||
expr2 = nothing
|
||||
|
||||
cacheStrategy = cache_strategy(device)
|
||||
|
||||
inExprs1 = Vector()
|
||||
for sym in t.t1_inputs
|
||||
push!(inExprs1, gen_access_expr(device, sym))
|
||||
@ -43,46 +36,53 @@ function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Ve
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(node::ComputeTaskNode, device::AbstractDevice)
|
||||
get_expression(node::ComputeTaskNode)
|
||||
|
||||
Generate and return code for a given [`ComputeTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::ComputeTaskNode, device::AbstractDevice)
|
||||
t = typeof(node.task)
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
@assert length(node.children) <= children(node.task) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
|
||||
|
||||
# TODO get device from the node
|
||||
cacheStrategy = cache_strategy(device)
|
||||
@assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
|
||||
|
||||
inExprs = Vector()
|
||||
for id in getfield.(node.children, :id)
|
||||
push!(inExprs, gen_access_expr(device, Symbol(to_var_name(id))))
|
||||
push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
|
||||
end
|
||||
outExpr = gen_access_expr(device, Symbol(to_var_name(node.id)))
|
||||
outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
|
||||
|
||||
return get_expression(node.task, device, inExprs, outExpr)
|
||||
return get_expression(node.task, node.device, inExprs, outExpr)
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
get_expression(node::DataTaskNode)
|
||||
|
||||
Generate and return code for a given [`DataTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
@assert length(node.children) <= 1
|
||||
function get_expression(node::DataTaskNode)
|
||||
@assert length(node.children) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
|
||||
|
||||
# TODO: do things to transport data from/to gpu, between numa nodes, etc.
|
||||
# TODO get device from the node
|
||||
|
||||
cacheStrategy = cache_strategy(device)
|
||||
inExpr = nothing
|
||||
if (length(node.children) == 1)
|
||||
inExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.children[1].id))))
|
||||
else
|
||||
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
|
||||
end
|
||||
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
|
||||
# TODO: dispatch to device implementations generating the copy commands
|
||||
|
||||
child = node.children[1]
|
||||
inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
|
||||
outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
|
||||
dataTransportExp = Meta.parse("$outExpr = $inExpr")
|
||||
|
||||
return dataTransportExp
|
||||
end
|
||||
|
||||
"""
|
||||
get_init_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
|
||||
Generate and return code for the initial input reading expression for [`DataTaskNode`](@ref)s with 0 children, i.e., entry nodes.
|
||||
|
||||
See also: [`get_entry_nodes`](@ref)
|
||||
"""
|
||||
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
@assert isempty(node.children) "Trying to call get_init_expression on a data task node that is not an entry node."
|
||||
|
||||
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
|
||||
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
|
||||
dataTransportExp = Meta.parse("$outExpr = $inExpr")
|
||||
|
||||
return dataTransportExp
|
||||
|
Loading…
x
Reference in New Issue
Block a user