Optimizer interface and sample implementation (#19)
Reviewed-on: Rubydragon/MetagraphOptimization.jl#19 Co-authored-by: Anton Reinhard <anton.reinhard@proton.me> Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
@@ -21,7 +21,7 @@ end
|
||||
|
||||
Equality comparison between two [`ComputeTaskNode`](@ref)s.
|
||||
"""
|
||||
function ==(n1::ComputeTaskNode, n2::ComputeTaskNode)
|
||||
function ==(n1::ComputeTaskNode{TaskType}, n2::ComputeTaskNode{TaskType}) where {TaskType <: AbstractComputeTask}
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
||||
@@ -30,6 +30,6 @@ end
|
||||
|
||||
Equality comparison between two [`DataTaskNode`](@ref)s.
|
||||
"""
|
||||
function ==(n1::DataTaskNode, n2::DataTaskNode)
|
||||
function ==(n1::DataTaskNode{TaskType}, n2::DataTaskNode{TaskType}) where {TaskType <: AbstractDataTask}
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
@@ -13,8 +13,8 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
)
|
||||
|
||||
copy(m::Missing) = missing
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), n.name)
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(task(n)))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(task(n)), n.name)
|
||||
|
||||
"""
|
||||
make_node(t::AbstractTask)
|
||||
|
@@ -4,7 +4,7 @@
|
||||
Print a short string representation of the node to io.
|
||||
"""
|
||||
function show(io::IO, n::Node)
|
||||
return print(io, "Node(", n.task, ")")
|
||||
return print(io, "Node(", task(n), ")")
|
||||
end
|
||||
|
||||
"""
|
||||
|
@@ -3,25 +3,27 @@
|
||||
|
||||
Return whether this node is an entry node in its graph, i.e., it has no children.
|
||||
"""
|
||||
is_entry_node(node::Node) = length(node.children) == 0
|
||||
is_entry_node(node::Node) = length(children(node)) == 0
|
||||
|
||||
"""
|
||||
is_exit_node(node::Node)
|
||||
|
||||
Return whether this node is an exit node of its graph, i.e., it has no parents.
|
||||
"""
|
||||
is_exit_node(node::Node) = length(node.parents) == 0
|
||||
is_exit_node(node::Node)::Bool = length(parents(node)) == 0
|
||||
|
||||
"""
|
||||
data(edge::Edge)
|
||||
task(node::Node)
|
||||
|
||||
Return the data transfered by this edge, i.e., 0 if the child is a [`ComputeTaskNode`](@ref), otherwise the child's `data()`.
|
||||
Return the node's task.
|
||||
"""
|
||||
function data(edge::Edge)
|
||||
if typeof(edge.edge[1]) <: DataTaskNode
|
||||
return data(edge.edge[1].task)
|
||||
end
|
||||
return 0.0
|
||||
function task(node::DataTaskNode{TaskType})::TaskType where {TaskType <: Union{AbstractDataTask, AbstractComputeTask}}
|
||||
return node.task
|
||||
end
|
||||
function task(
|
||||
node::ComputeTaskNode{TaskType},
|
||||
)::TaskType where {TaskType <: Union{AbstractDataTask, AbstractComputeTask}}
|
||||
return node.task
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -31,8 +33,11 @@ Return a copy of the node's children so it can safely be muted without changing
|
||||
|
||||
A node's children are its prerequisite nodes, nodes that need to execute before the task of this node.
|
||||
"""
|
||||
function children(node::Node)
|
||||
return copy(node.children)
|
||||
function children(node::DataTaskNode)::Vector{ComputeTaskNode}
|
||||
return node.children
|
||||
end
|
||||
function children(node::ComputeTaskNode)::Vector{DataTaskNode}
|
||||
return node.children
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -42,8 +47,11 @@ Return a copy of the node's parents so it can safely be muted without changing t
|
||||
|
||||
A node's parents are its subsequent nodes, nodes that need this node to execute.
|
||||
"""
|
||||
function parents(node::Node)
|
||||
return copy(node.parents)
|
||||
function parents(node::DataTaskNode)::Vector{ComputeTaskNode}
|
||||
return node.parents
|
||||
end
|
||||
function parents(node::ComputeTaskNode)::Vector{DataTaskNode}
|
||||
return node.parents
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -53,11 +61,11 @@ Return a vector of all siblings of this node.
|
||||
|
||||
A node's siblings are all children of any of its parents. The result contains no duplicates and includes the node itself.
|
||||
"""
|
||||
function siblings(node::Node)
|
||||
function siblings(node::Node)::Set{Node}
|
||||
result = Set{Node}()
|
||||
push!(result, node)
|
||||
for parent in node.parents
|
||||
union!(result, parent.children)
|
||||
for parent in parents(node)
|
||||
union!(result, children(parent))
|
||||
end
|
||||
|
||||
return result
|
||||
@@ -73,11 +81,11 @@ A node's partners are all parents of any of its children. The result contains no
|
||||
Note: This is very slow when there are multiple children with many parents.
|
||||
This is less of a problem in [`siblings(node::Node)`](@ref) because (depending on the model) there are no nodes with a large number of children, or only a single one.
|
||||
"""
|
||||
function partners(node::Node)
|
||||
function partners(node::Node)::Set{Node}
|
||||
result = Set{Node}()
|
||||
push!(result, node)
|
||||
for child in node.children
|
||||
union!(result, child.parents)
|
||||
for child in children(node)
|
||||
union!(result, parents(child))
|
||||
end
|
||||
|
||||
return result
|
||||
@@ -90,8 +98,8 @@ Alternative version to [`partners(node::Node)`](@ref), avoiding allocation of a
|
||||
"""
|
||||
function partners(node::Node, set::Set{Node})
|
||||
push!(set, node)
|
||||
for child in node.children
|
||||
union!(set, child.parents)
|
||||
for child in children(node)
|
||||
union!(set, parents(child))
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
@@ -101,8 +109,8 @@ end
|
||||
|
||||
Return whether the `potential_parent` is a parent of `node`.
|
||||
"""
|
||||
function is_parent(potential_parent::Node, node::Node)
|
||||
return potential_parent in node.parents
|
||||
function is_parent(potential_parent::Node, node::Node)::Bool
|
||||
return potential_parent in parents(node)
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -110,6 +118,6 @@ end
|
||||
|
||||
Return whether the `potential_child` is a child of `node`.
|
||||
"""
|
||||
function is_child(potential_child::Node, node::Node)
|
||||
return potential_child in node.children
|
||||
function is_child(potential_child::Node, node::Node)::Bool
|
||||
return potential_child in children(node)
|
||||
end
|
||||
|
@@ -33,8 +33,8 @@ Any node that transfers data and does no computation.
|
||||
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
|
||||
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
|
||||
"""
|
||||
mutable struct DataTaskNode <: Node
|
||||
task::AbstractDataTask
|
||||
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
|
||||
task::TaskType
|
||||
|
||||
# use vectors as sets have way too much memory overhead
|
||||
parents::Vector{Node}
|
||||
@@ -73,8 +73,8 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re
|
||||
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
|
||||
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
|
||||
"""
|
||||
mutable struct ComputeTaskNode <: Node
|
||||
task::AbstractComputeTask
|
||||
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
|
||||
task::TaskType
|
||||
parents::Vector{Node}
|
||||
children::Vector{Node}
|
||||
id::Base.UUID
|
||||
@@ -83,7 +83,7 @@ mutable struct ComputeTaskNode <: Node
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
|
||||
nodeFusions::Vector{Operation}
|
||||
nodeFusions::Vector{<:Operation}
|
||||
|
||||
# the device this node is assigned to execute on
|
||||
device::Union{AbstractDevice, Missing}
|
||||
|
@@ -29,7 +29,7 @@ function is_valid_node(graph::DAG, node::Node)
|
||||
@assert is_valid(graph, node.nodeSplit)
|
||||
end=#
|
||||
|
||||
if !(typeof(node.task) <: FusedComputeTask)
|
||||
if !(typeof(task(node)) <: FusedComputeTask)
|
||||
# the remaining checks are only necessary for fused compute tasks
|
||||
return true
|
||||
end
|
||||
@@ -37,7 +37,7 @@ function is_valid_node(graph::DAG, node::Node)
|
||||
# every child must be in some input of the task
|
||||
for child in node.children
|
||||
str = Symbol(to_var_name(child.id))
|
||||
@assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
|
||||
@assert (str in task(node).t1_inputs) || (str in task(node).t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(task(node).t1_inputs)\nt2_inputs: $(task(node).t2_inputs)"
|
||||
end
|
||||
|
||||
return true
|
||||
|
Reference in New Issue
Block a user