metagraphoptimization.jl/src/graph_functions.jl

327 lines
8.1 KiB
Julia

using DataStructures
# membership test: is `node` one of the nodes stored in `graph.nodes`?
function in(node::Node, graph::DAG)
    return node ∈ graph.nodes
end
# membership test: is `edge` contained in `graph.edges`?
# NOTE(review): insert_edge!/remove_edge! in this file only update the nodes'
# parents/children and never touch `graph.edges` -- confirm that field is maintained elsewhere.
function in(edge::Edge, graph::DAG)
    return edge ∈ graph.edges
end
# true only if none of the three operation collections (fusions, reductions, splits)
# holds an entry
function isempty(operations::PossibleOperations)
    isempty(operations.nodeFusions) || return false
    isempty(operations.nodeReductions) || return false
    return isempty(operations.nodeSplits)
end
# remove the node fusion `op` from the fusion collection; returns `operations`
function delete!(operations::PossibleOperations, op::NodeFusion)
    fusions = operations.nodeFusions
    delete!(fusions, op)
    return operations
end
# remove the node reduction `op` from the reduction collection; returns `operations`
function delete!(operations::PossibleOperations, op::NodeReduction)
    reductions = operations.nodeReductions
    delete!(reductions, op)
    return operations
end
# remove the node split `op` from the split collection; returns `operations`
function delete!(operations::PossibleOperations, op::NodeSplit)
    splits = operations.nodeSplits
    delete!(splits, op)
    return operations
end
# true if `potential_parent` is listed in `node.parents`
is_parent(potential_parent, node) = in(potential_parent, node.parents)
# true if `potential_child` is listed in `node.children`
is_child(potential_child, node) = in(potential_child, node.children)
# three-argument equality of two nodes with respect to the graph `g`:
# both nodes must have the same concrete type, both must be contained in `g`,
# and they must have equal tasks and equal children
function ==(n1::Node, n2::Node, g::DAG)
    typeof(n1) != typeof(n2) && return false
    ((n1 in g) && (n2 in g)) || return false
    return n1.task == n2.task && children(n1) == children(n2)
end
# children = prerequisite nodes, i.e. nodes that have to execute before this node's task;
# edges point from the children into this node
# a copy is returned, so changing the result does not change the node itself
children(node::Node) = copy(node.children)
# parents = subsequent nodes, i.e. nodes that need this node to have executed;
# edges point from this node to its parents
# a copy is returned, so changing the result does not change the node itself
parents(node::Node) = copy(node.parents)
# siblings = the children of all of this node's parents, collected into a set
# (so no duplicates) and excluding the node itself
function siblings(node::Node)
    found = Set{Node}()
    for parent in parents(node), candidate in children(parent)
        candidate != node && push!(found, candidate)
    end
    return found
end
# partners = the parents of all of this node's children, collected into a set
# (so no duplicates) and excluding the node itself
function partners(node::Node)
    found = Set{Node}()
    for child in children(node), candidate in parents(child)
        candidate != node && push!(found, candidate)
    end
    return found
end
# an entry node is a node without any children
function is_entry_node(node::Node)
    return isempty(children(node))
end
# an exit node is a node without any parents
function is_exit_node(node::Node)
    return isempty(parents(node))
end
# invalidate the operation caches for a given operation:
# the operation is dropped from the graph's possible operations and from the
# `operations` cache of every node listed in `operation.input`
function invalidate_caches!(graph::DAG, operation::Operation)
    delete!(graph.possibleOperations, operation)

    # `operation.input` may be a tuple or a vector, both can be iterated the same way
    for involved in operation.input
        filter!(cached -> cached != operation, involved.operations)
    end

    return nothing
end
# invalidate the operation caches for a node split specifically:
# a node split has exactly one input node, so `operation.input` is that node
# and not a collection to iterate over
function invalidate_caches!(graph::DAG, operation::NodeSplit)
    delete!(graph.possibleOperations, operation)

    split_node = operation.input
    filter!(cached -> cached != operation, split_node.operations)

    return nothing
end
# every graph-mutating function has to do three things:
# 1: mutate the graph
# 2: keep track of what was changed for the diff (if track == true)
# 3: invalidate operation caches
# add `node` to the graph and return it
# if `track` is true the insertion is recorded in `graph.diff.addedNodes`
function insert_node!(graph::DAG, node::Node, track=true)
    # 1: mutate
    push!(graph.nodes, node)

    # 2: record the change for the diff
    track && push!(graph.diff.addedNodes, node)

    # 3: mark the new node as dirty
    push!(graph.dirtyNodes, node)

    return node
end
# connect the two nodes of `edge` and return the edge
# `edge.edge[1]` is the child and `edge.edge[2]` the parent (edges point from child to parent)
# if `track` is true the insertion is recorded in `graph.diff.addedEdges`
function insert_edge!(graph::DAG, edge::Edge, track=true)
    child = edge.edge[1]
    parent = edge.edge[2]

    # 1: mutate
    push!(child.parents, parent)
    push!(parent.children, child)

    # 2: record the change for the diff
    track && push!(graph.diff.addedEdges, edge)

    # 3: invalidate caches, first for the child, then for the parent
    # invalidate_caches! removes the operation from the node's cache, so each loop terminates
    for endpoint in (child, parent)
        while !isempty(endpoint.operations)
            invalidate_caches!(graph, first(endpoint.operations))
        end
    end
    push!(graph.dirtyNodes, child)
    push!(graph.dirtyNodes, parent)

    return edge
end
# remove `node` from the graph; returns nothing
# if `track` is true the removal is recorded in `graph.diff.removedNodes`
function remove_node!(graph::DAG, node::Node, track=true)
    # 1: mutate
    delete!(graph.nodes, node)

    # 2: record the change for the diff
    track && push!(graph.diff.removedNodes, node)

    # 3: invalidate every operation cached on the node
    # invalidate_caches! removes the operation from the node's cache, so the loop terminates
    while !isempty(node.operations)
        stale = first(node.operations)
        invalidate_caches!(graph, stale)
    end
    # the node is gone afterwards, so it only has to be dropped from the dirty nodes
    delete!(graph.dirtyNodes, node)

    return nothing
end
# disconnect the two nodes of `edge`; returns nothing
# `edge.edge[1]` is the child and `edge.edge[2]` the parent (edges point from child to parent)
# if `track` is true the removal is recorded in `graph.diff.removedEdges`
function remove_edge!(graph::DAG, edge::Edge, track=true)
    child = edge.edge[1]
    parent = edge.edge[2]

    # 1: mutate
    filter!(!=(parent), child.parents)
    filter!(!=(child), parent.children)

    # 2: record the change for the diff
    track && push!(graph.diff.removedEdges, edge)

    # 3: invalidate caches, first for the child, then for the parent
    # invalidate_caches! removes the operation from the node's cache, so each loop terminates
    for endpoint in (child, parent)
        while !isempty(endpoint.operations)
            invalidate_caches!(graph, first(endpoint.operations))
        end
    end
    push!(graph.dirtyNodes, child)
    push!(graph.dirtyNodes, parent)

    return nothing
end
# return the graph "difference" accumulated since the last call of this function
# the graph's `diff` field is replaced by a fresh `Diff()` and the previous value is returned
function get_snapshot_diff(graph::DAG)
    fresh = Diff()
    previous = swapfield!(graph, :diff, fresh)
    return previous
end
# compute summary properties of the graph, returned as a named tuple with
#   data              - sum over all nodes of data(task) * number of parents
#   compute_effort    - sum over all nodes of compute_effort(task)
#   compute_intensity - compute_effort / data
#   edges             - sum over all nodes of the number of parents
function graph_properties(graph::DAG)
    # make sure the graph is fully generated before measuring it
    apply_all!(graph)

    total_data = 0
    total_effort = 0
    edge_count = 0
    for node in graph.nodes
        outgoing = length(node.parents)
        total_data += data(node.task) * outgoing
        total_effort += compute_effort(node.task)
        edge_count += outgoing
    end

    return (
        data = total_data,
        compute_effort = total_effort,
        compute_intensity = total_effort / total_data,
        edges = edge_count,
    )
end
# return the first node found in `graph.nodes` that has no parents
# raises an error if the graph contains no such node
function get_exit_node(graph::DAG)
    for candidate in graph.nodes
        is_exit_node(candidate) && return candidate
    end
    return error("The given graph has no exit node! It is either empty or not acyclic!")
end
# check whether the chain n1 -> n2 -> n3 can be fused:
# n1 must be a child of n2, n2 a child of n3, and the data node n2 must have
# exactly one parent and exactly one child
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    # these relationship checks are redundant but serve as a sanity check
    is_chain = is_child(n1, n2) && is_child(n2, n3)
    is_chain || return false

    return length(parents(n2)) == 1 && length(children(n2)) == 1
end
# two nodes can be reduced if their tasks are equal and they have the same set of
# children (the order of the children does not matter)
function can_reduce(n1::Node, n2::Node)
    n1.task != n2.task && return false

    children1 = Set(children(n1))
    children2 = Set(children(n2))
    return children1 == children2
end
# a node can be split if it has more than one parent
can_split(n::Node) = length(parents(n)) > 1
# check whether the given graph is connected:
# starting at the exit node returned by get_exit_node, follow the children of every
# node and compare the number of nodes reached with the number of nodes in the graph
# raises an error (from get_exit_node) if the graph has no exit node
function is_valid(graph::DAG)
    nodeQueue = Deque{Node}()
    seenNodes = Set{Node}()

    exitNode = get_exit_node(graph)
    push!(nodeQueue, exitNode)
    push!(seenNodes, exitNode)

    while !isempty(nodeQueue)
        current = pop!(nodeQueue)
        for child in children(current)
            # only enqueue nodes that were not reached before; without this check a node
            # reachable via several paths is expanded once per path, which blows up on
            # graphs with many shared children and never terminates on a cycle
            if !(child in seenNodes)
                push!(seenNodes, child)
                push!(nodeQueue, child)
            end
        end
    end

    return length(seenNodes) == length(graph.nodes)
end
# print all nodes of the graph to `io` as a bracketed, comma separated list: [n1, n2, ...]
function show_nodes(io, graph::DAG)
    print(io, "[")
    # join prints every node with print(io, node) and puts ", " between consecutive nodes
    join(io, graph.nodes, ", ")
    print(io, "]")
end
# print a summary of `graph` to `io`:
#   - the nodes: the full node list for graphs with at most 20 nodes, otherwise the
#     total number of nodes followed by the number of nodes per task type
#   - the number of edges
#   - total compute effort, data transfer and compute intensity from graph_properties
# NOTE(review): graph_properties calls apply_all!(graph) and the edge count is taken
# before that call -- confirm the two cannot disagree.
function show(io::IO, graph::DAG)
    println(io, "Graph:")
    print(io, " Nodes: ")

    # count the nodes per task type and sum up the edges
    # (every edge is counted once, at the node it points away from)
    nodeDict = Dict{Type, Int64}()
    noEdges = 0
    for node in graph.nodes
        taskType = typeof(node.task)
        nodeDict[taskType] = get(nodeDict, taskType, 0) + 1
        noEdges += length(parents(node))
    end

    if length(graph.nodes) <= 20
        show_nodes(io, graph)
    else
        # all output has to go to `io`; printing without it would write to stdout
        # and the text would be missing from the target stream
        print(io, "Total: ", length(graph.nodes), ", ")
        i = 0
        for (type, number) in nodeDict
            i += 1
            if i > 1
                print(io, ", ")
            end
            # line break before every third entry
            if i % 3 == 0
                print(io, "\n ")
            end
            print(io, type, ": ", number)
        end
    end
    println(io)
    println(io, " Edges: ", noEdges)

    properties = graph_properties(graph)
    println(io, " Total Compute Effort: ", properties.compute_effort)
    println(io, " Total Data Transfer: ", properties.data)
    println(io, " Total Compute Intensity: ", properties.compute_intensity)
end