Fix tests and operation cache
This commit is contained in:
@ -6,7 +6,7 @@ export make_node, make_edge, insert_node, insert_edge, is_entry_node, is_exit_no
|
||||
export NodeFusion, NodeReduction, NodeSplit, push_operation!, pop_operation!, can_pop, reset_graph!, get_operations
|
||||
export import_txt
|
||||
|
||||
export ==, in, show
|
||||
export ==, in, show, isempty, delete!
|
||||
|
||||
export bytes_to_human_readable
|
||||
|
||||
@ -15,6 +15,8 @@ import Base.show
|
||||
import Base.==
|
||||
import Base.in
|
||||
import Base.copy
|
||||
import Base.isempty
|
||||
import Base.delete!
|
||||
|
||||
|
||||
include("tasks.jl")
|
||||
|
@ -3,6 +3,25 @@ using DataStructures
|
||||
in(node::Node, graph::DAG) = node in graph.nodes
|
||||
in(edge::Edge, graph::DAG) = edge in graph.edges
|
||||
|
||||
function isempty(operations::PossibleOperations)
|
||||
return isempty(operations.nodeFusions) &&
|
||||
isempty(operations.nodeReductions) &&
|
||||
isempty(operations.nodeSplits)
|
||||
end
|
||||
|
||||
function delete!(operations::PossibleOperations, op::NodeFusion)
|
||||
delete!(operations.nodeFusions, op)
|
||||
return operations
|
||||
end
|
||||
function delete!(operations::PossibleOperations, op::NodeReduction)
|
||||
delete!(operations.nodeReductions, op)
|
||||
return operations
|
||||
end
|
||||
function delete!(operations::PossibleOperations, op::NodeSplit)
|
||||
delete!(operations.nodeSplits, op)
|
||||
return operations
|
||||
end
|
||||
|
||||
function is_parent(potential_parent, node)
|
||||
return potential_parent in node.parents
|
||||
end
|
||||
@ -68,14 +87,25 @@ function invalidate_caches!(graph::DAG, operation::Operation)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# (we can iterate over single values, tuples and vectors just fine)
|
||||
# (we can iterate over tuples and vectors just fine)
|
||||
for node in operation.input
|
||||
delete!(node.operations, operation)
|
||||
filter!(!=(operation), node.operations)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to invalidate the operation caches for a given Node Split specifically
|
||||
function invalidate_caches!(graph::DAG, operation::NodeSplit)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# for node split there is only one node
|
||||
filter!(!=(operation), operation.input.operations)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# for graph mutating functions we need to do a few things
|
||||
# 1: mute the graph (duh)
|
||||
# 2: keep track of what was changed for the diff (if track == true)
|
||||
@ -127,7 +157,7 @@ function remove_node!(graph::DAG, node::Node, track=true)
|
||||
if (track) push!(graph.diff.removedNodes, node) end
|
||||
|
||||
# 3: invalidate caches
|
||||
while !isempty(node)
|
||||
while !isempty(node.operations)
|
||||
invalidate_caches!(graph, first(node.operations))
|
||||
end
|
||||
delete!(graph.dirtyNodes, node)
|
||||
@ -197,7 +227,16 @@ function get_exit_node(graph::DAG)
|
||||
end
|
||||
|
||||
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
#Todo
|
||||
if !is_child(n1, n2) || !is_child(n2, n3)
|
||||
# the checks are redundant but maybe a good sanity check
|
||||
return false
|
||||
end
|
||||
|
||||
if length(parents(n2)) != 1 || length(children(n2)) != 1
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function can_reduce(n1::Node, n2::Node)
|
||||
|
@ -237,9 +237,113 @@ function node_split!(graph::DAG, n1::Node)
|
||||
return get_snapshot_diff(graph)
|
||||
end
|
||||
|
||||
# function to find node fusions involving the given node
|
||||
function find_fusions(graph::DAG, node::Node)
|
||||
|
||||
# function to find node fusions involving the given node if it's a data node
|
||||
# pushes the found fusion everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
if length(parents(node)) != 1 || length(children(node)) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
child_node = first(children(node))
|
||||
parent_node = first(parents(node))
|
||||
|
||||
nf = NodeFusion((child_node, node, parent_node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(child_node.operations, nf)
|
||||
push!(node.operations, nf)
|
||||
push!(parent_node.operations, nf)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to find node fusions involving the given node if it's a compute node
|
||||
# pushes the found fusion(s) everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
# for loop that always runs once for a scoped block we can break out of
|
||||
for _ in 1:1
|
||||
# assume this node as child of the chain
|
||||
if length(parents(node)) < 1
|
||||
break
|
||||
end
|
||||
node2 = first(parents(node))
|
||||
if length(parents(node2)) != 1 || length(children(node2)) != 1
|
||||
break
|
||||
end
|
||||
node3 = first(parents(node2))
|
||||
|
||||
nf = NodeFusion((node, node2, node3))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node.operations, nf)
|
||||
push!(node2.operations, nf)
|
||||
push!(node3.operations, nf)
|
||||
end
|
||||
|
||||
for _ in 1:1
|
||||
# assume this node as parent of the chain
|
||||
if length(children(node)) < 1
|
||||
break
|
||||
end
|
||||
node2 = first(children(node))
|
||||
if length(parents(node2)) != 1 || length(children(node2)) != 1
|
||||
break
|
||||
end
|
||||
node1 = first(children(node2))
|
||||
|
||||
nf = NodeFusion((node1, node2, node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node1.operations, nf)
|
||||
push!(node2.operations, nf)
|
||||
push!(node.operations, nf)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function find_reductions!(graph::DAG, node::Node)
|
||||
reductionVector = nothing
|
||||
# possible reductions are with nodes that are partners, i.e. parents of children
|
||||
for partner in partners(node)
|
||||
if can_reduce(node, partner)
|
||||
if reductionVector === nothing
|
||||
# only when there's at least one reduction partner, insert the vector
|
||||
reductionVector = Vector{Node}()
|
||||
push!(reductionVector, node)
|
||||
end
|
||||
|
||||
push!(reductionVector, partner)
|
||||
end
|
||||
end
|
||||
|
||||
if reductionVector !== nothing
|
||||
nr = NodeReduction(reductionVector)
|
||||
push!(graph.possibleOperations.nodeReductions, nr)
|
||||
for node in reductionVector
|
||||
push!(node.operations, nr)
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function find_splits!(graph::DAG, node::Node)
|
||||
for node in graph.nodes
|
||||
if (can_split(node))
|
||||
ns = NodeSplit(node)
|
||||
push!(graph.possibleOperations.nodeSplits, ns)
|
||||
push!(node.operations, ns)
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# "clean" the operations on a dirty node
|
||||
function clean_node!(graph::DAG, node::Node)
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
find_splits!(graph, node)
|
||||
|
||||
delete!(graph.dirtyNodes, node)
|
||||
end
|
||||
|
||||
# function to generate all possible optmizations on the graph
|
||||
@ -317,15 +421,20 @@ function generate_options(graph::DAG)
|
||||
end
|
||||
end
|
||||
|
||||
options.dirty = false
|
||||
|
||||
graph.possibleOperations = options
|
||||
empty!(graph.dirtyNodes)
|
||||
end
|
||||
|
||||
function get_operations(graph::DAG)
|
||||
if (graph.possibleOperations.dirty)
|
||||
apply_all!(graph)
|
||||
|
||||
if isempty(graph.possibleOperations)
|
||||
generate_options(graph)
|
||||
end
|
||||
|
||||
while !isempty(graph.dirtyNodes)
|
||||
clean_node!(graph, first(graph.dirtyNodes))
|
||||
end
|
||||
|
||||
return graph.possibleOperations
|
||||
end
|
||||
end
|
||||
|
Reference in New Issue
Block a user