Fix tests and operation cache

This commit is contained in:
2023-08-17 14:15:02 +02:00
parent 8a081ba93c
commit ae07b4cf80
8 changed files with 197 additions and 16 deletions

View File

@ -6,7 +6,7 @@ export make_node, make_edge, insert_node, insert_edge, is_entry_node, is_exit_no
export NodeFusion, NodeReduction, NodeSplit, push_operation!, pop_operation!, can_pop, reset_graph!, get_operations
export import_txt
export ==, in, show
export ==, in, show, isempty, delete!
export bytes_to_human_readable
@ -15,6 +15,8 @@ import Base.show
import Base.==
import Base.in
import Base.copy
import Base.isempty
import Base.delete!
include("tasks.jl")

View File

@ -3,6 +3,25 @@ using DataStructures
in(node::Node, graph::DAG) = node in graph.nodes
in(edge::Edge, graph::DAG) = edge in graph.edges
function isempty(operations::PossibleOperations)
return isempty(operations.nodeFusions) &&
isempty(operations.nodeReductions) &&
isempty(operations.nodeSplits)
end
function delete!(operations::PossibleOperations, op::NodeFusion)
delete!(operations.nodeFusions, op)
return operations
end
function delete!(operations::PossibleOperations, op::NodeReduction)
delete!(operations.nodeReductions, op)
return operations
end
function delete!(operations::PossibleOperations, op::NodeSplit)
delete!(operations.nodeSplits, op)
return operations
end
function is_parent(potential_parent, node)
return potential_parent in node.parents
end
@ -68,14 +87,25 @@ function invalidate_caches!(graph::DAG, operation::Operation)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
# (we can iterate over single values, tuples and vectors just fine)
# (we can iterate over tuples and vectors just fine)
for node in operation.input
delete!(node.operations, operation)
filter!(!=(operation), node.operations)
end
return nothing
end
# function to invalidate the operation caches for a given Node Split specifically
function invalidate_caches!(graph::DAG, operation::NodeSplit)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
# for node split there is only one node
filter!(!=(operation), operation.input.operations)
return nothing
end
# for graph mutating functions we need to do a few things
# 1: mute the graph (duh)
# 2: keep track of what was changed for the diff (if track == true)
@ -127,7 +157,7 @@ function remove_node!(graph::DAG, node::Node, track=true)
if (track) push!(graph.diff.removedNodes, node) end
# 3: invalidate caches
while !isempty(node)
while !isempty(node.operations)
invalidate_caches!(graph, first(node.operations))
end
delete!(graph.dirtyNodes, node)
@ -197,7 +227,16 @@ function get_exit_node(graph::DAG)
end
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
#Todo
if !is_child(n1, n2) || !is_child(n2, n3)
# the checks are redundant but maybe a good sanity check
return false
end
if length(parents(n2)) != 1 || length(children(n2)) != 1
return false
end
return true
end
function can_reduce(n1::Node, n2::Node)

View File

@ -237,9 +237,113 @@ function node_split!(graph::DAG, n1::Node)
return get_snapshot_diff(graph)
end
# function to find node fusions involving the given node
function find_fusions(graph::DAG, node::Node)
# function to find node fusions involving the given node if it's a data node
# pushes the found fusion everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::DataTaskNode)
if length(parents(node)) != 1 || length(children(node)) != 1
return nothing
end
child_node = first(children(node))
parent_node = first(parents(node))
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.operations, nf)
push!(node.operations, nf)
push!(parent_node.operations, nf)
return nothing
end
# function to find node fusions involving the given node if it's a compute node
# pushes the found fusion(s) everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# for loop that always runs once for a scoped block we can break out of
for _ in 1:1
# assume this node as child of the chain
if length(parents(node)) < 1
break
end
node2 = first(parents(node))
if length(parents(node2)) != 1 || length(children(node2)) != 1
break
end
node3 = first(parents(node2))
nf = NodeFusion((node, node2, node3))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node.operations, nf)
push!(node2.operations, nf)
push!(node3.operations, nf)
end
for _ in 1:1
# assume this node as parent of the chain
if length(children(node)) < 1
break
end
node2 = first(children(node))
if length(parents(node2)) != 1 || length(children(node2)) != 1
break
end
node1 = first(children(node2))
nf = NodeFusion((node1, node2, node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node1.operations, nf)
push!(node2.operations, nf)
push!(node.operations, nf)
end
return nothing
end
function find_reductions!(graph::DAG, node::Node)
reductionVector = nothing
# possible reductions are with nodes that are partners, i.e. parents of children
for partner in partners(node)
if can_reduce(node, partner)
if reductionVector === nothing
# only when there's at least one reduction partner, insert the vector
reductionVector = Vector{Node}()
push!(reductionVector, node)
end
push!(reductionVector, partner)
end
end
if reductionVector !== nothing
nr = NodeReduction(reductionVector)
push!(graph.possibleOperations.nodeReductions, nr)
for node in reductionVector
push!(node.operations, nr)
end
end
return nothing
end
function find_splits!(graph::DAG, node::Node)
for node in graph.nodes
if (can_split(node))
ns = NodeSplit(node)
push!(graph.possibleOperations.nodeSplits, ns)
push!(node.operations, ns)
end
end
return nothing
end
# "clean" the operations on a dirty node
function clean_node!(graph::DAG, node::Node)
find_fusions!(graph, node)
find_reductions!(graph, node)
find_splits!(graph, node)
delete!(graph.dirtyNodes, node)
end
# function to generate all possible optmizations on the graph
@ -317,15 +421,20 @@ function generate_options(graph::DAG)
end
end
options.dirty = false
graph.possibleOperations = options
empty!(graph.dirtyNodes)
end
function get_operations(graph::DAG)
if (graph.possibleOperations.dirty)
apply_all!(graph)
if isempty(graph.possibleOperations)
generate_options(graph)
end
while !isempty(graph.dirtyNodes)
clean_node!(graph, first(graph.dirtyNodes))
end
return graph.possibleOperations
end
end