Add scheduling, machine info, caching strategies and devices (#9)
Reviewed-on: Rubydragon/MetagraphOptimization.jl#9 Co-authored-by: Anton Reinhard <anton.reinhard@proton.me> Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
@@ -38,8 +38,7 @@ end
|
||||
|
||||
Return `true` if [`pop_operation!`](@ref) is possible, `false` otherwise.
|
||||
"""
|
||||
can_pop(graph::DAG) =
|
||||
!isempty(graph.operationsToApply) || !isempty(graph.appliedOperations)
|
||||
can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations)
|
||||
|
||||
"""
|
||||
reset_graph!(graph::DAG)
|
||||
|
@@ -15,12 +15,7 @@ Insert the node into the graph.
|
||||
|
||||
See also: [`remove_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function insert_node!(
|
||||
graph::DAG,
|
||||
node::Node,
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
function insert_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
|
||||
# 1: mute
|
||||
push!(graph.nodes, node)
|
||||
|
||||
@@ -50,14 +45,8 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
|
||||
|
||||
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function insert_edge!(
|
||||
graph::DAG,
|
||||
node1::Node,
|
||||
node2::Node,
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
# @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
|
||||
function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
|
||||
@assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
|
||||
|
||||
# 1: mute
|
||||
# edge points from child to parent
|
||||
@@ -95,13 +84,8 @@ Remove the node from the graph.
|
||||
|
||||
See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function remove_node!(
|
||||
graph::DAG,
|
||||
node::Node,
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
# @assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
|
||||
@assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
|
||||
# 1: mute
|
||||
delete!(graph.nodes, node)
|
||||
@@ -134,13 +118,7 @@ Remove the edge between node1 (child) and node2 (parent) into the graph.
|
||||
|
||||
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`insert_edge!`](@ref)
|
||||
"""
|
||||
function remove_edge!(
|
||||
graph::DAG,
|
||||
node1::Node,
|
||||
node2::Node,
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
|
||||
# 1: mute
|
||||
pre_length1 = length(node1.parents)
|
||||
pre_length2 = length(node2.children)
|
||||
@@ -149,15 +127,15 @@ function remove_edge!(
|
||||
filter!(x -> x != node2, node1.parents)
|
||||
filter!(x -> x != node1, node2.children)
|
||||
|
||||
#=@assert begin
|
||||
removed = pre_length1 - length(node1.parents)
|
||||
removed <= 1
|
||||
end "removed more than one node from node1's parents"=#
|
||||
@assert begin
|
||||
removed = pre_length1 - length(node1.parents)
|
||||
removed <= 1
|
||||
end "removed more than one node from node1's parents"
|
||||
|
||||
#=@assert begin
|
||||
removed = pre_length2 - length(node2.children)
|
||||
removed <= 1
|
||||
end "removed more than one node from node2's children"=#
|
||||
@assert begin
|
||||
removed = pre_length2 - length(node2.children)
|
||||
removed <= 1
|
||||
end "removed more than one node from node2's children"
|
||||
|
||||
# 2: keep track
|
||||
if (track)
|
||||
@@ -181,6 +159,66 @@ function remove_edge!(
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::FusedComputeTask, before, after)
|
||||
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
|
||||
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
|
||||
|
||||
@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
|
||||
replace!(task.t1_inputs, before => after)
|
||||
replace!(task.t2_inputs, before => after)
|
||||
|
||||
# recursively descend down the tree, but only in the tasks where we're replacing things
|
||||
if replacedIn1 > 0
|
||||
replace_children!(task.first_task, before, after)
|
||||
end
|
||||
if replacedIn2 > 0
|
||||
replace_children!(task.second_task, before, after)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::AbstractTask, before, after)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
|
||||
# only need to update fused compute tasks
|
||||
if !(typeof(n.task) <: FusedComputeTask)
|
||||
return nothing
|
||||
end
|
||||
|
||||
taskBefore = copy(n.task)
|
||||
|
||||
if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
|
||||
println("------------------ Nothing to replace!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in n.children
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end
|
||||
|
||||
replace_children!(n.task, child_before, child_after)
|
||||
|
||||
if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
|
||||
println("------------------ Did not replace anything!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in n.children
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end
|
||||
|
||||
# keep track
|
||||
if (track)
|
||||
push!(graph.diff.updatedChildren, (n, taskBefore))
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
get_snapshot_diff(graph::DAG)
|
||||
|
||||
|
@@ -62,9 +62,5 @@ function show(io::IO, graph::DAG)
|
||||
properties = get_properties(graph)
|
||||
println(io, " Total Compute Effort: ", properties.computeEffort)
|
||||
println(io, " Total Data Transfer: ", properties.data)
|
||||
return println(
|
||||
io,
|
||||
" Total Compute Intensity: ",
|
||||
properties.computeIntensity,
|
||||
)
|
||||
return println(io, " Total Compute Intensity: ", properties.computeIntensity)
|
||||
end
|
||||
|
@@ -34,6 +34,7 @@ end
|
||||
Return a vector of the graph's entry nodes.
|
||||
"""
|
||||
function get_entry_nodes(graph::DAG)
|
||||
apply_all!(graph)
|
||||
result = Vector{Node}()
|
||||
for node in graph.nodes
|
||||
if (is_entry_node(node))
|
||||
|
@@ -17,7 +17,7 @@ end
|
||||
|
||||
The representation of the graph as a set of [`Node`](@ref)s.
|
||||
|
||||
A DAG can be loaded using the appropriate parse function, e.g. [`parse_abc`](@ref).
|
||||
A DAG can be loaded using the appropriate parse_dag function, e.g. [`parse_dag`](@ref).
|
||||
|
||||
[`Operation`](@ref)s can be applied on it using [`push_operation!`](@ref) and reverted using [`pop_operation!`](@ref) like a stack.
|
||||
To get the set of possible operations, use [`get_operations`](@ref).
|
||||
@@ -52,11 +52,7 @@ end
|
||||
Construct and return an empty [`PossibleOperations`](@ref) object.
|
||||
"""
|
||||
function PossibleOperations()
|
||||
return PossibleOperations(
|
||||
Set{NodeFusion}(),
|
||||
Set{NodeReduction}(),
|
||||
Set{NodeSplit}(),
|
||||
)
|
||||
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
|
||||
end
|
||||
|
||||
"""
|
||||
|
@@ -59,3 +59,19 @@ function is_valid(graph::DAG)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
is_scheduled(graph::DAG)
|
||||
|
||||
Validate that the entire graph has been scheduled, i.e., every [`ComputeTaskNode`](@ref) has its `.device` set.
|
||||
"""
|
||||
function is_scheduled(graph::DAG)
|
||||
for node in graph.nodes
|
||||
if (node isa DataTaskNode)
|
||||
continue
|
||||
end
|
||||
@assert !ismissing(node.device)
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
Reference in New Issue
Block a user