Add scheduling, machine info, caching strategies and devices (#9)
Some checks failed
MetagraphOptimization_CI / prepare (push) Has been cancelled
MetagraphOptimization_CI / test (push) Has been cancelled
MetagraphOptimization_CI / docs (push) Has been cancelled

Reviewed-on: Rubydragon/MetagraphOptimization.jl#9
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
2023-10-12 17:51:03 +02:00
committed by Anton Reinhard
parent bd6c54c1ae
commit 5a30f57e1f
72 changed files with 3397 additions and 987 deletions

View File

@ -34,12 +34,7 @@ Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node
Return an [`AppliedNodeFusion`](@ref) object generated from the graph's [`Diff`](@ref).
"""
function apply_operation!(graph::DAG, operation::NodeFusion)
diff = node_fusion!(
graph,
operation.input[1],
operation.input[2],
operation.input[3],
)
diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
graph.properties += GraphProperties(diff)
@ -124,17 +119,24 @@ function revert_diff!(graph::DAG, diff::Diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
remove_edge!(graph, edge.edge[1], edge.edge[2], false)
remove_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for node in diff.addedNodes
remove_node!(graph, node, false)
remove_node!(graph, node, track = false)
end
for node in diff.removedNodes
insert_node!(graph, node, false)
insert_node!(graph, node, track = false)
end
for edge in diff.removedEdges
insert_edge!(graph, edge.edge[1], edge.edge[2], false)
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for (node, task) in diff.updatedChildren
# node must be fused compute task at this point
@assert typeof(node.task) <: FusedComputeTask
node.task = task
end
graph.properties -= GraphProperties(diff)
@ -149,21 +151,24 @@ Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference
For details see [`NodeFusion`](@ref).
"""
function node_fusion!(
graph::DAG,
n1::ComputeTaskNode,
n2::DataTaskNode,
n3::ComputeTaskNode,
)
# @assert is_valid_node_fusion_input(graph, n1, n2, n3)
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
@assert is_valid_node_fusion_input(graph, n1, n2, n3)
# clear snapshot
get_snapshot_diff(graph)
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
n3_children = children(n3)
n1Children = children(n1)
n3Parents = parents(n3)
n1Task = copy(n1.task)
n3Task = copy(n3.task)
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
n1Inputs = Vector{Symbol}()
for child in n1Children
push!(n1Inputs, Symbol(to_var_name(child.id)))
end
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, n1, n2)
@ -172,29 +177,38 @@ function node_fusion!(
remove_node!(graph, n2)
# get n3's children now so it automatically excludes n2
n3_children = children(n3)
n3Children = children(n3)
n3Inputs = Vector{Symbol}()
for child in n3Children
push!(n3Inputs, Symbol(to_var_name(child.id)))
end
remove_node!(graph, n3)
# create new node with the fused compute task
new_node =
ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
insert_node!(graph, new_node)
newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
insert_node!(graph, newNode)
for child in n1_children
for child in n1Children
remove_edge!(graph, child, n1)
insert_edge!(graph, child, new_node)
insert_edge!(graph, child, newNode)
end
for child in n3_children
for child in n3Children
remove_edge!(graph, child, n3)
if !(child in n1_children)
insert_edge!(graph, child, new_node)
if !(child in n1Children)
insert_edge!(graph, child, newNode)
end
end
for parent in n3_parents
for parent in n3Parents
remove_edge!(graph, n3, parent)
insert_edge!(graph, new_node, parent)
insert_edge!(graph, newNode, parent)
# important! update the parent node's child names in case they are fused compute tasks
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(newNode.id)))
end
return get_snapshot_diff(graph)
@ -208,21 +222,26 @@ Reduce the given nodes together into one node, return the applied difference to
For details see [`NodeReduction`](@ref).
"""
function node_reduction!(graph::DAG, nodes::Vector{Node})
# @assert is_valid_node_reduction_input(graph, nodes)
@assert is_valid_node_reduction_input(graph, nodes)
# clear snapshot
get_snapshot_diff(graph)
n1 = nodes[1]
n1_children = children(n1)
n1Children = children(n1)
n1_parents = Set(n1.parents)
new_parents = Set{Node}()
n1Parents = Set(n1.parents)
# set of the new parents of n1
newParents = Set{Node}()
# names of the previous children that n1 now replaces per parent
newParentsChildNames = Dict{Node, Symbol}()
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
n = nodes[i]
for child in n1_children
for child in n1Children
remove_edge!(graph, child, n)
end
@ -230,17 +249,23 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
remove_edge!(graph, n, parent)
# collect all parents
push!(new_parents, parent)
push!(newParents, parent)
newParentsChildNames[parent] = Symbol(to_var_name(n.id))
end
remove_node!(graph, n)
end
setdiff!(new_parents, n1_parents)
for parent in new_parents
for parent in newParents
# now add parents of all input nodes to n1 without duplicates
insert_edge!(graph, n1, parent)
if !(parent in n1Parents)
# don't double insert edges
insert_edge!(graph, n1, parent)
end
# this has to be done for all parents, even the ones of n1 because they can be duplicate
prevChild = newParentsChildNames[parent]
update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
end
return get_snapshot_diff(graph)
@ -254,30 +279,33 @@ Split the given node into one node per parent, return the applied difference to
For details see [`NodeSplit`](@ref).
"""
function node_split!(graph::DAG, n1::Node)
# @assert is_valid_node_split_input(graph, n1)
@assert is_valid_node_split_input(graph, n1)
# clear snapshot
get_snapshot_diff(graph)
n1_parents = parents(n1)
n1_children = children(n1)
n1Parents = parents(n1)
n1Children = children(n1)
for parent in n1_parents
for parent in n1Parents
remove_edge!(graph, n1, parent)
end
for child in n1_children
for child in n1Children
remove_edge!(graph, child, n1)
end
remove_node!(graph, n1)
for parent in n1_parents
n_copy = copy(n1)
insert_node!(graph, n_copy)
insert_edge!(graph, n_copy, parent)
for parent in n1Parents
nCopy = copy(n1)
for child in n1_children
insert_edge!(graph, child, n_copy)
insert_node!(graph, nCopy)
insert_edge!(graph, nCopy, parent)
for child in n1Children
insert_edge!(graph, child, nCopy)
end
update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(nCopy.id)))
end
return get_snapshot_diff(graph)

View File

@ -7,10 +7,7 @@ using Base.Threads
Insert the given node fusion into its input nodes' operation caches. For the compute nodes, locking via the given `locks` is employed to have safe multi-threading. For a large set of nodes, contention on the locks should be very small.
"""
function insert_operation!(
nf::NodeFusion,
locks::Dict{ComputeTaskNode, SpinLock},
)
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
n1 = nf.input[1]
n2 = nf.input[2]
n3 = nf.input[3]
@ -52,10 +49,7 @@ end
Insert the node reductions into the graph and the nodes' caches. Employs multithreading for speedup.
"""
function nr_insertion!(
operations::PossibleOperations,
nodeReductions::Vector{Vector{NodeReduction}},
)
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
total_len = 0
for vec in nodeReductions
total_len += length(vec)
@ -83,11 +77,7 @@ end
Insert the node fusions into the graph and the nodes' caches. Employs multithreading for speedup.
"""
function nf_insertion!(
graph::DAG,
operations::PossibleOperations,
nodeFusions::Vector{Vector{NodeFusion}},
)
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
total_len = 0
for vec in nodeFusions
total_len += length(vec)
@ -122,10 +112,7 @@ end
Insert the node splits into the graph and the nodes' caches. Employs multithreading for speedup.
"""
function ns_insertion!(
operations::PossibleOperations,
nodeSplits::Vector{Vector{NodeSplit}},
)
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
total_len = 0
for vec in nodeSplits
total_len += length(vec)
@ -231,16 +218,12 @@ function generate_operations(graph::DAG)
continue
end
push!(
generatedFusions[threadid()],
NodeFusion((child_node, node, parent_node)),
)
push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
end
end
# launch thread for node fusion insertion
nf_task =
@task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
schedule(nf_task)
# find possible node splits

View File

@ -4,9 +4,7 @@
Return whether `operations` is empty, i.e. all of its fields are empty.
"""
function isempty(operations::PossibleOperations)
return isempty(operations.nodeFusions) &&
isempty(operations.nodeReductions) &&
isempty(operations.nodeSplits)
return isempty(operations.nodeFusions) && isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
end
"""
@ -63,9 +61,7 @@ function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
return false
end
if length(n2.parents) != 1 ||
length(n2.children) != 1 ||
length(n1.parents) != 1
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
return false
end

View File

@ -9,24 +9,12 @@ Assert for a gven node fusion input whether the nodes can be fused. For the requ
Intended for use with `@assert` or `@test`.
"""
function is_valid_node_fusion_input(
graph::DAG,
n1::ComputeTaskNode,
n2::DataTaskNode,
n3::ComputeTaskNode,
)
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
throw(
AssertionError(
"[Node Fusion] The given nodes are not part of the given graph",
),
)
throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
end
if !is_child(n1, n2) ||
!is_child(n2, n3) ||
!is_parent(n3, n2) ||
!is_parent(n2, n1)
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
throw(
AssertionError(
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
@ -35,27 +23,19 @@ function is_valid_node_fusion_input(
end
if length(n2.parents) > 1
throw(
AssertionError(
"[Node Fusion] The given data node has more than one parent",
),
)
throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
end
if length(n2.children) > 1
throw(
AssertionError(
"[Node Fusion] The given data node has more than one child",
),
)
throw(AssertionError("[Node Fusion] The given data node has more than one child"))
end
if length(n1.parents) > 1
throw(
AssertionError(
"[Node Fusion] The given n1 has more than one parent",
),
)
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
end
@assert is_valid(graph, n1)
@assert is_valid(graph, n2)
@assert is_valid(graph, n3)
return true
end
@ -69,22 +49,21 @@ Intended for use with `@assert` or `@test`.
function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
for n in nodes
if n graph
throw(
AssertionError(
"[Node Reduction] The given nodes are not part of the given graph",
),
)
throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
end
@assert is_valid(graph, n)
end
t = typeof(nodes[1].task)
for n in nodes
if typeof(n.task) != t
throw(
AssertionError(
"[Node Reduction] The given nodes are not of the same type",
),
)
throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
end
if (typeof(n) <: DataTaskNode)
if (n.name != nodes[1].name)
throw(AssertionError("[Node Reduction] The given nodes do not have the same name"))
end
end
end
@ -111,11 +90,7 @@ Intended for use with `@assert` or `@test`.
"""
function is_valid_node_split_input(graph::DAG, n1::Node)
if n1 graph
throw(
AssertionError(
"[Node Split] The given node is not part of the given graph",
),
)
throw(AssertionError("[Node Split] The given node is not part of the given graph"))
end
if length(n1.parents) <= 1
@ -126,6 +101,8 @@ function is_valid_node_split_input(graph::DAG, n1::Node)
)
end
@assert is_valid(graph, n1)
return true
end
@ -163,12 +140,7 @@ Assert for a given [`NodeFusion`](@ref) whether it is a valid operation in the g
Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, nf::NodeFusion)
@assert is_valid_node_fusion_input(
graph,
nf.input[1],
nf.input[2],
nf.input[3],
)
@assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
return true
end