Fun with type stability

This commit is contained in:
2023-11-21 01:48:59 +01:00
parent 9d947a49ce
commit 705bfb30fe
25 changed files with 164 additions and 133 deletions

View File

@ -132,11 +132,11 @@ function revert_diff!(graph::DAG, diff::Diff)
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for (node, task) in diff.updatedChildren
for (node, t) in diff.updatedChildren
# node must be fused compute task at this point
@assert typeof(node.task) <: FusedComputeTask
@assert typeof(task(node)) <: FusedComputeTask
node.task = task
node.task = t
end
graph.properties -= GraphProperties(diff)
@ -158,11 +158,11 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
get_snapshot_diff(graph)
# save children and parents
n1Children = children(n1)
n3Parents = parents(n3)
n1Children = copy(children(n1))
n3Parents = copy(parents(n3))
n1Task = copy(n1.task)
n3Task = copy(n3.task)
n1Task = copy(task(n1))
n3Task = copy(task(n3))
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
n1Inputs = Vector{Symbol}()
@ -177,7 +177,7 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
remove_node!(graph, n2)
# get n3's children now so it automatically excludes n2
n3Children = children(n3)
n3Children = copy(children(n3))
n3Inputs = Vector{Symbol}()
for child in n3Children
@ -228,7 +228,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
get_snapshot_diff(graph)
n1 = nodes[1]
n1Children = children(n1)
n1Children = copy(children(n1))
n1Parents = Set(n1.parents)
@ -245,7 +245,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
remove_edge!(graph, child, n)
end
for parent in parents(n)
for parent in copy(parents(n))
remove_edge!(graph, n, parent)
# collect all parents
@ -278,14 +278,17 @@ Split the given node into one node per parent, return the applied difference to
For details see [`NodeSplit`](@ref).
"""
function node_split!(graph::DAG, n1::Node)
function node_split!(
graph::DAG,
n1::Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}},
) where {TaskType <: AbstractTask}
@assert is_valid_node_split_input(graph, n1)
# clear snapshot
get_snapshot_diff(graph)
n1Parents = parents(n1)
n1Children = children(n1)
n1Parents = copy(parents(n1))
n1Children = copy(children(n1))
for parent in n1Parents
remove_edge!(graph, n1, parent)

View File

@ -13,18 +13,18 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
return nothing
end
if length(node.parents) != 1 || length(node.children) != 1
if length(parents(node)) != 1 || length(children(node)) != 1
return nothing
end
child_node = first(node.children)
parent_node = first(node.parents)
child_node = first(children(node))
parent_node = first(parents(node))
if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end
if length(child_node.parents) != 1
if length(parents(child_node)) != 1
return nothing
end
@ -44,11 +44,11 @@ Find node fusions involving the given compute node. The function pushes the foun
"""
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# just find fusions in neighbouring DataTaskNodes
for child in node.children
for child in children(node)
find_fusions!(graph, child)
end
for parent in node.parents
for parent in parents(node)
find_fusions!(graph, parent)
end
@ -123,7 +123,10 @@ end
Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way.
"""
function clean_node!(graph::DAG, node::Node)
function clean_node!(
graph::DAG,
node::Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}},
) where {TaskType <: AbstractTask}
sort_node!(node)
find_fusions!(graph, node)

View File

@ -203,18 +203,18 @@ function generate_operations(graph::DAG)
# --- find possible node fusions ---
@threads for node in nodeArray
if (typeof(node) <: DataTaskNode)
if length(node.parents) != 1
if length(parents(node)) != 1
# data node can only have a single parent
continue
end
parent_node = first(node.parents)
parent_node = first(parents(node))
if length(node.children) != 1
if length(children(node)) != 1
# this node is an entry node or has multiple children which should not be possible
continue
end
child_node = first(node.children)
if (length(child_node.parents) != 1)
child_node = first(children(node))
if (length(parents(child_node)) != 1)
continue
end

View File

@ -14,9 +14,7 @@ function get_operations(graph::DAG)
generate_operations(graph)
end
for node in graph.dirtyNodes
clean_node!(graph, node)
end
clean_node!.(Ref(graph), graph.dirtyNodes)
empty!(graph.dirtyNodes)
return graph.possibleOperations

View File

@ -30,7 +30,7 @@ function show(io::IO, op::NodeReduction)
print(io, "NR: ")
print(io, length(op.input))
print(io, "x")
return print(io, op.input[1].task)
return print(io, task(op.input[1]))
end
"""
@ -40,7 +40,7 @@ Print a string representation of the node split to io.
"""
function show(io::IO, op::NodeSplit)
print(io, "NS: ")
return print(io, op.input.task)
return print(io, task(op.input))
end
"""
@ -50,9 +50,9 @@ Print a string representation of the node fusion to io.
"""
function show(io::IO, op::NodeFusion)
print(io, "NF: ")
print(io, op.input[1].task)
print(io, task(op.input[1]))
print(io, "->")
print(io, op.input[2].task)
print(io, task(op.input[2]))
print(io, "->")
return print(io, op.input[3].task)
return print(io, task(op.input[3]))
end

View File

@ -40,8 +40,9 @@ A chain of (n1, n2, n3) can be fused if:
See also: [`can_fuse`](@ref)
"""
struct NodeFusion <: Operation
input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
Operation
input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
end
"""
@ -49,8 +50,12 @@ end
The applied version of the [`NodeFusion`](@ref).
"""
struct AppliedNodeFusion <: AppliedOperation
operation::NodeFusion
struct AppliedNodeFusion{
TaskType1 <: AbstractComputeTask,
TaskType2 <: AbstractDataTask,
TaskType3 <: AbstractComputeTask,
} <: AppliedOperation
operation::NodeFusion{TaskType1, TaskType2, TaskType3}
diff::Diff
end
@ -73,8 +78,8 @@ A vector of nodes can be reduced if:
See also: [`can_reduce`](@ref)
"""
struct NodeReduction <: Operation
input::Vector{Node}
struct NodeReduction{NodeType <: Node} <: Operation
input::Vector{NodeType}
end
"""
@ -82,8 +87,8 @@ end
The applied version of the [`NodeReduction`](@ref).
"""
struct AppliedNodeReduction <: AppliedOperation
operation::NodeReduction
struct AppliedNodeReduction{NodeType <: Node} <: AppliedOperation
operation::NodeReduction{NodeType}
diff::Diff
end
@ -102,8 +107,8 @@ A node can be split if:
See also: [`can_split`](@ref)
"""
struct NodeSplit <: Operation
input::Node
struct NodeSplit{NodeType <: Node} <: Operation
input::NodeType
end
"""
@ -111,7 +116,7 @@ end
The applied version of the [`NodeSplit`](@ref).
"""
struct AppliedNodeSplit <: AppliedOperation
operation::NodeSplit
struct AppliedNodeSplit{NodeType <: Node} <: AppliedOperation
operation::NodeSplit{NodeType}
diff::Diff
end

View File

@ -61,7 +61,7 @@ function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
return false
end
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
return false
end
@ -74,12 +74,15 @@ end
Return whether the given two nodes can be reduced. See [`NodeReduction`](@ref) for the requirements.
"""
function can_reduce(n1::Node, n2::Node)
if (n1.task != n2.task)
return false
end
return false
end
n1_length = length(n1.children)
n2_length = length(n2.children)
function can_reduce(
n1::NodeType,
n2::NodeType,
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
n1_length = length(children(n1))
n2_length = length(children(n2))
if (n1_length != n2_length)
return false
@ -88,19 +91,19 @@ function can_reduce(n1::Node, n2::Node)
# this seems to be the most common case so do this first
# doing it manually is a lot faster than using the sets for a general solution
if (n1_length == 2)
if (n1.children[1] != n2.children[1])
if (n1.children[1] != n2.children[2])
if (children(n1)[1] != children(n2)[1])
if (children(n1)[1] != children(n2)[2])
return false
end
# 1_1 == 2_2
if (n1.children[2] != n2.children[1])
if (children(n1)[2] != children(n2)[1])
return false
end
return true
end
# 1_1 == 2_1
if (n1.children[2] != n2.children[2])
if (children(n1)[2] != children(n2)[2])
return false
end
return true
@ -108,11 +111,11 @@ function can_reduce(n1::Node, n2::Node)
# this is simple
if (n1_length == 1)
return n1.children[1] == n2.children[1]
return children(n1)[1] == children(n2)[1]
end
# this takes a long time
return Set(n1.children) == Set(n2.children)
return Set(children(n1)) == Set(children(n2))
end
"""

View File

@ -54,9 +54,9 @@ function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
@assert is_valid(graph, n)
end
t = typeof(nodes[1].task)
t = typeof(task(nodes[1]))
for n in nodes
if typeof(n.task) != t
if typeof(task(n)) != t
throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
end