Add scheduling, machine info, caching strategies and devices (#9)
Some checks failed
MetagraphOptimization_CI / prepare (push) Has been cancelled
MetagraphOptimization_CI / test (push) Has been cancelled
MetagraphOptimization_CI / docs (push) Has been cancelled

Reviewed-on: Rubydragon/MetagraphOptimization.jl#9
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
2023-10-12 17:51:03 +02:00
committed by Anton Reinhard
parent bd6c54c1ae
commit 5a30f57e1f
72 changed files with 3397 additions and 987 deletions

View File

@ -1,44 +1,20 @@
DataTaskNode(t::AbstractDataTask, name = "") = DataTaskNode(
t,
Vector{Node}(),
Vector{Node}(),
UUIDs.uuid1(rng[threadid()]),
missing,
missing,
missing,
name,
)
DataTaskNode(t::AbstractDataTask, name = "") =
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
t,
Vector{Node}(),
Vector{Node}(),
UUIDs.uuid1(rng[threadid()]),
missing,
missing,
Vector{NodeFusion}(),
t, # task
Vector{Node}(), # parents
Vector{Node}(), # children
UUIDs.uuid1(rng[threadid()]), # id
missing, # node reduction
missing, # node split
Vector{NodeFusion}(), # node fusions
missing, # device
)
copy(m::Missing) = missing
copy(n::ComputeTaskNode) = ComputeTaskNode(
copy(n.task),
copy(n.parents),
copy(n.children),
UUIDs.uuid1(rng[threadid()]),
copy(n.nodeReduction),
copy(n.nodeSplit),
copy(n.nodeFusions),
)
copy(n::DataTaskNode) = DataTaskNode(
copy(n.task),
copy(n.parents),
copy(n.children),
UUIDs.uuid1(rng[threadid()]),
copy(n.nodeReduction),
copy(n.nodeSplit),
copy(n.nodeFusion),
n.name,
)
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task))
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), n.name)
"""
make_node(t::AbstractTask)

View File

@ -22,5 +22,6 @@ end
Return the uuid as a string usable as a variable name in code generation.
"""
function to_var_name(id::UUID)
return replace(string(id), "-" => "_")
str = "_" * replace(string(id), "-" => "_")
return str
end

View File

@ -24,13 +24,14 @@ abstract type Operation end
Any node that transfers data and does no computation.
# Fields
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
"""
mutable struct DataTaskNode <: Node
task::AbstractDataTask
@ -60,16 +61,17 @@ end
"""
ComputeTaskNode <: Node
Any node that transfers data and does no computation.
Any node that computes a result from inputs using an [`AbstractComputeTask`](@ref).
# Fields
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: A vector of this node's [`NodeFusion`](@ref)s. For a ComputeTaskNode there can be any number of these, unlike the DataTaskNodes.
`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
"""
mutable struct ComputeTaskNode <: Node
task::AbstractComputeTask
@ -82,6 +84,9 @@ mutable struct ComputeTaskNode <: Node
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
nodeFusions::Vector{Operation}
# the device this node is assigned to execute on
device::Union{AbstractDevice, Missing}
end
"""
@ -95,8 +100,5 @@ The child is the prerequisite node of the parent.
"""
struct Edge
# edge points from child to parent
edge::Union{
Tuple{DataTaskNode, ComputeTaskNode},
Tuple{ComputeTaskNode, DataTaskNode},
}
edge::Union{Tuple{DataTaskNode, ComputeTaskNode}, Tuple{ComputeTaskNode, DataTaskNode}}
end

View File

@ -22,12 +22,24 @@ function is_valid_node(graph::DAG, node::Node)
@assert node in child.parents "Node is not a parent of its child!"
end
if !ismissing(node.nodeReduction)
#=if !ismissing(node.nodeReduction)
@assert is_valid(graph, node.nodeReduction)
end
if !ismissing(node.nodeSplit)
@assert is_valid(graph, node.nodeSplit)
end=#
if !(typeof(node.task) <: FusedComputeTask)
# the remaining checks are only necessary for fused compute tasks
return true
end
# every child must be in some input of the task
for child in node.children
str = Symbol(to_var_name(child.id))
@assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
end
return true
end
@ -41,9 +53,9 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::ComputeTaskNode)
@assert is_valid_node(graph, node)
for nf in node.nodeFusions
#=for nf in node.nodeFusions
@assert is_valid(graph, nf)
end
end=#
return true
end
@ -57,8 +69,8 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::DataTaskNode)
@assert is_valid_node(graph, node)
if !ismissing(node.nodeFusion)
#=if !ismissing(node.nodeFusion)
@assert is_valid(graph, node.nodeFusion)
end
end=#
return true
end