163 lines
4.6 KiB
Julia
Raw Normal View History

# Every graph-mutating function below follows the same three steps:
# 1: mutate the graph
# 2: record what was changed in the diff (if track == true)
# 3: invalidate the affected operation caches (if invalidate_cache == true)
"""
    insert_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)

Add `node` to `graph.nodes` and return `node`.

When `track` is true, `node` is also appended to `graph.diff.addedNodes`.
When `invalidate_cache` is true, `node` is added to `graph.dirtyNodes`.
"""
function insert_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
    # 1: mutate
    push!(graph.nodes, node)

    # 2: keep track
    track && push!(graph.diff.addedNodes, node)

    # 3: invalidate caches
    if invalidate_cache
        push!(graph.dirtyNodes, node)
    end
    return node
end
"""
    insert_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)

Connect `node1` and `node2`. The edge points from child `node1` to parent `node2`:
`node2` is appended to `node1.parents` and `node1` to `node2.children`.

When `track` is true, `make_edge(node1, node2)` is appended to `graph.diff.addedEdges`.
When `invalidate_cache` is true, the cached operations of both endpoints are
invalidated and both endpoints are added to `graph.dirtyNodes`. Returns `nothing`.

No check is made that the edge is absent before inserting it.
"""
function insert_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
    child, parent = node1, node2

    # 1: mutate
    push!(child.parents, parent)
    push!(parent.children, child)

    # 2: keep track
    track && push!(graph.diff.addedEdges, make_edge(child, parent))

    # 3: invalidate caches
    invalidate_cache || return nothing
    for endpoint in (child, parent)
        invalidate_operation_caches!(graph, endpoint)
    end
    push!(graph.dirtyNodes, child)
    push!(graph.dirtyNodes, parent)
    return nothing
end
"""
    remove_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)

Delete `node` from `graph.nodes`. This function does not touch the edges of `node`.

When `track` is true, `node` is appended to `graph.diff.removedNodes`.
When `invalidate_cache` is true, the cached operations of `node` are invalidated and
`node` is deleted from `graph.dirtyNodes`. With `invalidate_cache=false`, a node that
was dirty stays in `graph.dirtyNodes`. Returns `nothing`.
"""
function remove_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
    # 1: mutate
    delete!(graph.nodes, node)

    # 2: keep track
    if track
        push!(graph.diff.removedNodes, node)
    end

    # 3: invalidate caches
    if invalidate_cache
        invalidate_operation_caches!(graph, node)
        delete!(graph.dirtyNodes, node)
    end
    return nothing
end
"""
    remove_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)

Disconnect child `node1` from parent `node2`. Every occurrence of `node2` is removed
from `node1.parents`, and every occurrence of `node1` from `node2.children`.

When `track` is true, `make_edge(node1, node2)` is appended to
`graph.diff.removedEdges`. When `invalidate_cache` is true, the cached operations of
both endpoints are invalidated. Each endpoint that is still a member of `graph` is
then added to `graph.dirtyNodes`. Returns `nothing`.
"""
function remove_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
    # 1: mutate
    # The edge is expected to appear at most once in each list; nothing here checks it.
    filter!(!=(node2), node1.parents)
    filter!(!=(node1), node2.children)

    # 2: keep track
    if track
        push!(graph.diff.removedEdges, make_edge(node1, node2))
    end

    # 3: invalidate caches
    if !invalidate_cache
        return nothing
    end
    invalidate_operation_caches!(graph, node1)
    invalidate_operation_caches!(graph, node2)
    # Only endpoints still in the graph are marked dirty. Presumably an endpoint may
    # already have been removed (e.g. via remove_node!) — confirm against callers.
    if node1 in graph
        push!(graph.dirtyNodes, node1)
    end
    if node2 in graph
        push!(graph.dirtyNodes, node2)
    end
    return nothing
end
"""
    get_snapshot_diff(graph::DAG)

Return the `Diff` currently stored in `graph.diff` and replace it with a fresh
`Diff()`. The returned value therefore holds what was recorded since the previous call.
"""
function get_snapshot_diff(graph::DAG)
    fresh = Diff()
    # Raw field access (no setproperty!/convert), matching the original swapfield! call.
    previous = getfield(graph, :diff)
    setfield!(graph, :diff, fresh)
    return previous
end
"""
    invalidate_caches!(graph::DAG, operation::NodeFusion)

Drop the fusion `operation` from every cache that refers to it. It is deleted from
`graph.possibleOperations` and filtered out of the `nodeFusions` lists of
`operation.input[1]` and `operation.input[3]`. The single `nodeFusion` slot of
`operation.input[2]` is reset to `missing`. Returns `nothing`.
"""
function invalidate_caches!(graph::DAG, operation::NodeFusion)
    delete!(graph.possibleOperations, operation)

    inputs = operation.input
    keep = !=(operation)
    filter!(keep, inputs[1].nodeFusions)
    filter!(keep, inputs[3].nodeFusions)
    inputs[2].nodeFusion = missing

    return nothing
end
"""
    invalidate_caches!(graph::DAG, operation::NodeReduction)

Delete the reduction `operation` from `graph.possibleOperations`. Then reset the
`nodeReduction` field of every node in `operation.input` to `missing`.
Returns `nothing`.
"""
function invalidate_caches!(graph::DAG, operation::NodeReduction)
    delete!(graph.possibleOperations, operation)
    foreach(operation.input) do reduced_node
        reduced_node.nodeReduction = missing
    end
    return nothing
end
"""
    invalidate_caches!(graph::DAG, operation::NodeSplit)

Delete the split `operation` from `graph.possibleOperations`. Then reset the
`nodeSplit` field of its only input node, `operation.input`, to `missing`.
Returns `nothing`.
"""
function invalidate_caches!(graph::DAG, operation::NodeSplit)
    split_node = operation.input  # a node split has exactly one input node
    delete!(graph.possibleOperations, operation)
    split_node.nodeSplit = missing
    return nothing
end
"""
    invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)

Invalidate every cached operation attached to a compute node. This covers its
`nodeReduction` and `nodeSplit` (when not `missing`) and each entry of its
`nodeFusions` list, which is emptied in the process. Returns `nothing`.
"""
function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
    ismissing(node.nodeReduction) || invalidate_caches!(graph, node.nodeReduction)
    ismissing(node.nodeSplit) || invalidate_caches!(graph, node.nodeSplit)

    # Pop each fusion before invalidating it. The NodeFusion invalidator also filters
    # the fusion out of its inputs' nodeFusions lists, so it must no longer be here.
    while !isempty(node.nodeFusions)
        fusion = pop!(node.nodeFusions)
        invalidate_caches!(graph, fusion)
    end
    return nothing
end
"""
    invalidate_operation_caches!(graph::DAG, node::DataTaskNode)

Invalidate every cached operation attached to a data node. It checks, in order, the
`nodeReduction`, `nodeSplit` and the single `nodeFusion` slot, and invalidates each
one that is not `missing`. Returns `nothing`.
"""
function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
    # Each field is read only after the previous invalidation has run.
    reduction = node.nodeReduction
    ismissing(reduction) || invalidate_caches!(graph, reduction)

    pending_split = node.nodeSplit
    ismissing(pending_split) || invalidate_caches!(graph, pending_split)

    fusion = node.nodeFusion
    ismissing(fusion) || invalidate_caches!(graph, fusion)

    return nothing
end