This commit is contained in:
2023-08-15 18:48:18 +02:00
parent f086411720
commit 8a081ba93c
5 changed files with 124 additions and 34 deletions

View File

@ -54,15 +54,13 @@ mutable struct PossibleOperations
nodeFusions::Set{NodeFusion}
nodeReductions::Set{NodeReduction}
nodeSplits::Set{NodeSplit}
dirty::Bool
end
function PossibleOperations()
return PossibleOperations(
Set{NodeFusion}(),
Set{NodeReduction}(),
Set{NodeSplit}(),
true
Set{NodeSplit}()
)
end
@ -81,11 +79,14 @@ mutable struct DAG
# The possible operations at the current state of the DAG
possibleOperations::PossibleOperations
# The set of nodes whose possible operations need to be reevaluated
dirtyNodes::Set{Node}
# "snapshot" system: keep track of added/removed nodes/edges since last snapshot
# these are muted in insert_node! etc.
diff::Diff
end
function DAG()
return DAG(Set{Node}(), Stack{AppliedOperation}(), Deque{Operation}(), PossibleOperations(), Diff())
return DAG(Set{Node}(), Stack{AppliedOperation}(), Deque{Operation}(), PossibleOperations(), Set{Node}(), Diff())
end

View File

@ -63,38 +63,100 @@ end
is_entry_node(node::Node) = length(children(node)) == 0
is_exit_node(node::Node) = length(parents(node)) == 0
# function to invalidate the operation caches for a given operation
function invalidate_caches!(graph::DAG, operation::Operation)
delete!(graph.possibleOperations, operation)
# delete the operation from all caches of nodes involved in the operation
# (we can iterate over single values, tuples and vectors just fine)
for node in operation.input
delete!(node.operations, operation)
end
return nothing
end
# for graph mutating functions we need to do a few things
# 1: mute the graph (duh)
# 2: keep track of what was changed for the diff (if track == true)
# 3: invalidate operation caches
function insert_node!(graph::DAG, node::Node, track=true)
# 1: mute
push!(graph.nodes, node)
# 2: keep track
if (track) push!(graph.diff.addedNodes, node) end
graph.possibleOperations.dirty = true
# 3: invalidate caches
push!(graph.dirtyNodes, node)
return node
end
function insert_edge!(graph::DAG, edge::Edge, track=true)
# edge points from child to parent
push!(edge.edge[1].parents, edge.edge[2])
push!(edge.edge[2].children, edge.edge[1])
node1 = edge.edge[1]
node2 = edge.edge[2]
# 1: mute
# edge points from child to parent
push!(node1.parents, node2)
push!(node2.children, node1)
# 2: keep track
if (track) push!(graph.diff.addedEdges, edge) end
graph.possibleOperations.dirty = true
# 3: invalidate caches
while !isempty(node1.operations)
invalidate_caches!(graph, first(node1.operations))
end
while !isempty(node2.operations)
invalidate_caches!(graph, first(node2.operations))
end
push!(graph.dirtyNodes, node1)
push!(graph.dirtyNodes, node2)
return edge
end
function remove_node!(graph::DAG, node::Node, track=true)
# 1: mute
delete!(graph.nodes, node)
# 2: keep track
if (track) push!(graph.diff.removedNodes, node) end
graph.possibleOperations.dirty = true
# 3: invalidate caches
while !isempty(node)
invalidate_caches!(graph, first(node.operations))
end
delete!(graph.dirtyNodes, node)
# no need to invalidate anything else, the node is gone afterwards anyways
return nothing
end
function remove_edge!(graph::DAG, edge::Edge, track=true)
filter!(x -> x != edge.edge[2], edge.edge[1].parents)
filter!(x -> x != edge.edge[1], edge.edge[2].children)
node1 = edge.edge[1]
node2 = edge.edge[2]
# 1: mute
filter!(x -> x != node2, node1.parents)
filter!(x -> x != node1, node2.children)
# 2: keep track
if (track) push!(graph.diff.removedEdges, edge) end
graph.possibleOperations.dirty = true
# 3: invalidate caches
while !isempty(node1.operations)
invalidate_caches!(graph, first(node1.operations))
end
while !isempty(node2.operations)
invalidate_caches!(graph, first(node2.operations))
end
push!(graph.dirtyNodes, node1)
push!(graph.dirtyNodes, node2)
return nothing
end
@ -134,6 +196,10 @@ function get_exit_node(graph::DAG)
error("The given graph has no exit node! It is either empty or not acyclic!")
end
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
#Todo
end
function can_reduce(n1::Node, n2::Node)
if (n1.task != n2.task)
return false
@ -198,12 +264,17 @@ function show(io::IO, graph::DAG)
else
print("Total: ", length(graph.nodes), ", ")
first = true
i = 0
for (type, number) in zip(keys(nodeDict), values(nodeDict))
i += 1
if first
first = false
else
print(", ")
end
if (i % 3 == 0)
print("\n ")
end
print(type, ": ", number)
end
end

View File

@ -5,12 +5,7 @@ function push_operation!(graph::DAG, operation::Operation)
# 1.: Add the operation to the DAG
push!(graph.operationsToApply, operation)
# 2.: Apply all operations in the chain
apply_all!(graph)
# 3.: Regenerate properties, possible operations from here
graph.possibleOperations.dirty = true
return nothing
end
# reverts the latest applied operation, essentially like a ctrl+z for
@ -24,13 +19,7 @@ function pop_operation!(graph::DAG)
else
error("No more operations to pop!")
end
# 2.: Apply all (remaining) operations in the chain
apply_all!(graph)
# 3.: Regenerate properties, possible operations from here
graph.possibleOperations.dirty = true
return nothing
end
can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations)
@ -40,6 +29,8 @@ function reset_graph!(graph::DAG)
while (can_pop(graph))
pop_operation!(graph)
end
return nothing
end
# implementation detail functions, don't export
@ -56,6 +47,7 @@ function apply_all!(graph::DAG)
# push to the end of the appliedOperations deque
push!(graph.appliedOperations, appliedOp)
end
return nothing
end
@ -245,6 +237,11 @@ function node_split!(graph::DAG, n1::Node)
return get_snapshot_diff(graph)
end
# function to find node fusions involving the given node
function find_fusions(graph::DAG, node::Node)
end
# function to generate all possible optmizations on the graph
function generate_options(graph::DAG)
options = PossibleOperations()

View File

@ -20,7 +20,7 @@ function parse_edges(input::AbstractString)
return output
end
function import_txt(filename::String, verbose::Bool = isinteractive())
function import_txt(filename::String, verbose::Bool = false)
file = open(filename, "r")
if (verbose) println("Opened file") end