Add scheduler interface
This commit is contained in:
parent
a86901e425
commit
140a954d01
15
docs/src/lib/internals/scheduler.md
Normal file
15
docs/src/lib/internals/scheduler.md
Normal file
@ -0,0 +1,15 @@
|
||||
# Scheduler
|
||||
|
||||
## Interface
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["scheduler/interface.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
||||
|
||||
## Greedy
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["scheduler/greedy.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
@ -77,13 +77,13 @@ import Base.insert!
|
||||
import Base.collect
|
||||
|
||||
|
||||
include("devices/interface.jl")
|
||||
include("task/type.jl")
|
||||
include("node/type.jl")
|
||||
include("diff/type.jl")
|
||||
include("properties/type.jl")
|
||||
include("operation/type.jl")
|
||||
include("graph/type.jl")
|
||||
include("devices/interface.jl")
|
||||
|
||||
include("trie.jl")
|
||||
include("utility.jl")
|
||||
@ -143,6 +143,9 @@ include("devices/cuda/impl.jl")
|
||||
# oneapi seems also broken for now
|
||||
# include("devices/oneapi/impl.jl")
|
||||
|
||||
include("scheduler/interface.jl")
|
||||
include("scheduler/greedy.jl")
|
||||
|
||||
include("code_gen/main.jl")
|
||||
|
||||
end # module MetagraphOptimization
|
||||
|
@ -1,4 +1,3 @@
|
||||
using DataStructures
|
||||
|
||||
"""
|
||||
gen_code(graph::DAG)
|
||||
@ -12,12 +11,18 @@ Generate the code for a given graph. The return value is a named tuple of:
|
||||
See also: [`execute`](@ref)
|
||||
"""
|
||||
function gen_code(graph::DAG, machine::Machine)
|
||||
nodeQueue = PriorityQueue{Node, Int}()
|
||||
inputSyms = Dict{String, Vector{Symbol}}()
|
||||
sched = schedule_dag(GreedyScheduler(), graph, machine)
|
||||
|
||||
# use a priority equal to the number of unseen children -> 0 are nodes that can be added
|
||||
codeAcc = Vector{Expr}()
|
||||
sizehint!(codeAcc, length(graph.nodes))
|
||||
|
||||
for node in sched
|
||||
push!(codeAcc, get_expression(node, machine.devices[1]))
|
||||
end
|
||||
|
||||
# get inSymbols
|
||||
inputSyms = Dict{String, Vector{Symbol}}()
|
||||
for node in get_entry_nodes(graph)
|
||||
enqueue!(nodeQueue, node => 0)
|
||||
if !haskey(inputSyms, node.name)
|
||||
inputSyms[node.name] = Vector{Symbol}()
|
||||
end
|
||||
@ -25,35 +30,8 @@ function gen_code(graph::DAG, machine::Machine)
|
||||
push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
|
||||
end
|
||||
|
||||
schedule = Vector{Node}()
|
||||
sizehint!(schedule, length(graph.nodes))
|
||||
|
||||
# "scheduling"
|
||||
node = nothing
|
||||
while !isempty(nodeQueue)
|
||||
@assert peek(nodeQueue)[2] == 0
|
||||
node = dequeue!(nodeQueue)
|
||||
|
||||
push!(schedule, node)
|
||||
for parent in node.parents
|
||||
# reduce the priority of all parents by one
|
||||
if (!haskey(nodeQueue, parent))
|
||||
enqueue!(nodeQueue, parent => length(parent.children) - 1)
|
||||
else
|
||||
nodeQueue[parent] = nodeQueue[parent] - 1
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
codeAcc = Vector{Expr}()
|
||||
sizehint!(codeAcc, length(graph.nodes))
|
||||
|
||||
for node in schedule
|
||||
push!(codeAcc, get_expression(node, machine.devices[1]))
|
||||
end
|
||||
|
||||
# node is now the last node we looked at -> the output node
|
||||
outSym = Symbol(to_var_name(node.id))
|
||||
# get outSymbol
|
||||
outSym = Symbol(to_var_name(get_exit_node(graph).id))
|
||||
|
||||
return (code = Expr(:block, codeAcc...), inputSymbols = inputSyms, outputSymbol = outSym)
|
||||
end
|
||||
@ -166,7 +144,16 @@ function execute(graph::DAG, process::AbstractProcessDescription, machine::Machi
|
||||
catch e
|
||||
println("Error while evaluating: $e")
|
||||
|
||||
println("Function:\n$expr")
|
||||
# if we find a uuid in the exception we can color it in so it's easier to spot
|
||||
uuidRegex = r"[0-9a-f]{8}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{12}"
|
||||
m = match(uuidRegex, string(e))
|
||||
|
||||
functionStr = string(expr)
|
||||
if (isa(m, RegexMatch))
|
||||
functionStr = replace(functionStr, m.match => "\033[31m$(m.match)\033[0m")
|
||||
end
|
||||
|
||||
println("Function:\n$functionStr")
|
||||
@assert false
|
||||
end
|
||||
|
||||
|
@ -2,13 +2,14 @@
|
||||
DataTaskNode(t::AbstractDataTask, name = "") =
|
||||
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
|
||||
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
t,
|
||||
Vector{Node}(),
|
||||
Vector{Node}(),
|
||||
UUIDs.uuid1(rng[threadid()]),
|
||||
missing,
|
||||
missing,
|
||||
Vector{NodeFusion}(),
|
||||
t, # task
|
||||
Vector{Node}(), # parents
|
||||
Vector{Node}(), # children
|
||||
UUIDs.uuid1(rng[threadid()]), # id
|
||||
missing, # node reduction
|
||||
missing, # node split
|
||||
Vector{NodeFusion}(), # node fusions
|
||||
missing, # device
|
||||
)
|
||||
|
||||
copy(m::Missing) = missing
|
||||
|
@ -24,13 +24,14 @@ abstract type Operation end
|
||||
Any node that transfers data and does no computation.
|
||||
|
||||
# Fields
|
||||
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
|
||||
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
|
||||
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
|
||||
`.id`: The node's id. Improves the speed of comparisons.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.
|
||||
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
|
||||
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
|
||||
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
|
||||
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
|
||||
"""
|
||||
mutable struct DataTaskNode <: Node
|
||||
task::AbstractDataTask
|
||||
@ -60,16 +61,17 @@ end
|
||||
"""
|
||||
ComputeTaskNode <: Node
|
||||
|
||||
Any node that transfers data and does no computation.
|
||||
Any node that computes a result from inputs using an [`AbstractComputeTask`](@ref).
|
||||
|
||||
# Fields
|
||||
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
|
||||
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
|
||||
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
|
||||
`.id`: The node's id. Improves the speed of comparisons.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusion`: A vector of this node's [`NodeFusion`](@ref)s. For a ComputeTaskNode there can be any number of these, unlike the DataTaskNodes.
|
||||
`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\
|
||||
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
|
||||
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
|
||||
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
|
||||
"""
|
||||
mutable struct ComputeTaskNode <: Node
|
||||
task::AbstractComputeTask
|
||||
@ -82,6 +84,9 @@ mutable struct ComputeTaskNode <: Node
|
||||
|
||||
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
|
||||
nodeFusions::Vector{Operation}
|
||||
|
||||
# the device this node is assigned to execute on
|
||||
device::Union{AbstractDevice, Missing}
|
||||
end
|
||||
|
||||
"""
|
||||
|
43
src/scheduler/greedy.jl
Normal file
43
src/scheduler/greedy.jl
Normal file
@ -0,0 +1,43 @@
|
||||
|
||||
"""
|
||||
GreedyScheduler
|
||||
|
||||
A greedy implementation of a scheduler, creating a topological ordering of nodes and naively balancing them onto the different devices.
|
||||
"""
|
||||
struct GreedyScheduler end
|
||||
|
||||
function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
|
||||
nodeQueue = PriorityQueue{Node, Int}()
|
||||
|
||||
# use a priority equal to the number of unseen children -> 0 are nodes that can be added
|
||||
for node in get_entry_nodes(graph)
|
||||
enqueue!(nodeQueue, node => 0)
|
||||
end
|
||||
|
||||
schedule = Vector{Node}()
|
||||
sizehint!(schedule, length(graph.nodes))
|
||||
|
||||
# keep an accumulated cost of things scheduled to this device so far
|
||||
deviceAccCost = Dict{AbstractDevice, Int}()
|
||||
for device in machine.devices
|
||||
deviceAccCost[device] = 0
|
||||
end
|
||||
|
||||
node = nothing
|
||||
while !isempty(nodeQueue)
|
||||
@assert peek(nodeQueue)[2] == 0
|
||||
node = dequeue!(nodeQueue)
|
||||
|
||||
push!(schedule, node)
|
||||
for parent in node.parents
|
||||
# reduce the priority of all parents by one
|
||||
if (!haskey(nodeQueue, parent))
|
||||
enqueue!(nodeQueue, parent => length(parent.children) - 1)
|
||||
else
|
||||
nodeQueue[parent] = nodeQueue[parent] - 1
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return schedule
|
||||
end
|
18
src/scheduler/interface.jl
Normal file
18
src/scheduler/interface.jl
Normal file
@ -0,0 +1,18 @@
|
||||
|
||||
"""
|
||||
Scheduler
|
||||
|
||||
Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
|
||||
"""
|
||||
abstract type Scheduler end
|
||||
|
||||
"""
|
||||
schedule_dag(::Scheduler, ::DAG, ::Machine)
|
||||
|
||||
Interface functions that must be implemented for implementations of [`Scheduler`](@ref).
|
||||
|
||||
The function assigns each [`ComputeTaskNode`](@ref) of the [`DAG`](@ref) to one of the devices in the given [`Machine`](@ref) and returns a `Vector{Node}` representing a topological ordering.
|
||||
|
||||
[`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`Device`](@ref) of their child to all [`Device`](@ref)s of its parents.
|
||||
"""
|
||||
function schedule_dag end
|
Loading…
x
Reference in New Issue
Block a user