327 lines
8.1 KiB
Julia
327 lines
8.1 KiB
Julia
using DataStructures
|
|
|
|
# `node in graph` holds exactly when `node` is stored in `graph.nodes`
in(node::Node, graph::DAG) = in(node, graph.nodes)
|
|
# `edge in graph` holds exactly when `edge` is stored in `graph.edges`
in(edge::Edge, graph::DAG) = in(edge, graph.edges)
|
|
|
|
"""
    isempty(operations::PossibleOperations)

Return `true` when none of the operation collections of `operations`
(`nodeFusions`, `nodeReductions`, `nodeSplits`) contains an entry.
"""
function isempty(operations::PossibleOperations)
    collections = (
        operations.nodeFusions,
        operations.nodeReductions,
        operations.nodeSplits,
    )
    # `all` stops at the first non-empty collection, like the chained `&&`
    return all(isempty, collections)
end
|
|
|
|
"""
    delete!(operations::PossibleOperations, op::NodeFusion)

Remove the node fusion `op` from `operations.nodeFusions` and return `operations`.
"""
function delete!(operations::PossibleOperations, op::NodeFusion)
    fusions = operations.nodeFusions
    delete!(fusions, op)
    return operations
end
|
|
"""
    delete!(operations::PossibleOperations, op::NodeReduction)

Remove the node reduction `op` from `operations.nodeReductions` and return `operations`.
"""
function delete!(operations::PossibleOperations, op::NodeReduction)
    reductions = operations.nodeReductions
    delete!(reductions, op)
    return operations
end
|
|
"""
    delete!(operations::PossibleOperations, op::NodeSplit)

Remove the node split `op` from `operations.nodeSplits` and return `operations`.
"""
function delete!(operations::PossibleOperations, op::NodeSplit)
    splits = operations.nodeSplits
    delete!(splits, op)
    return operations
end
|
|
|
|
"""
    is_parent(potential_parent, node)

Return whether `potential_parent` occurs in `node.parents`.
"""
is_parent(potential_parent, node) = in(potential_parent, node.parents)
|
|
|
|
"""
    is_child(potential_child, node)

Return whether `potential_child` occurs in `node.children`.
"""
is_child(potential_child, node) = in(potential_child, node.children)
|
|
|
|
"""
    ==(n1::Node, n2::Node, g::DAG)

Check whether `n1` and `n2` are equivalent within graph `g`. This requires that they:
- have the same concrete type,
- are both part of `g`,
- carry equal tasks, and
- have the same set of children.

Children are compared as sets, so the order in which edges were inserted does not
matter. This matches how `can_reduce` compares children.
"""
function ==(n1::Node, n2::Node, g::DAG)
    if typeof(n1) != typeof(n2)
        return false
    end

    if !(n1 in g) || !(n2 in g)
        return false
    end

    # children are an unordered relation; comparing the raw vectors would make the
    # result depend on edge insertion order
    return n1.task == n2.task && Set(n1.children) == Set(n2.children)
end
|
|
|
|
"""
    children(node::Node)

Return a copy of `node.children`. These are the prerequisite nodes that must execute
before this node's task; edges point into this node.

Mutating the returned collection does not modify `node`.
"""
children(node::Node) = copy(node.children)
|
|
|
|
"""
    parents(node::Node)

Return a copy of `node.parents`. These are the subsequent nodes that need this node
to have executed; edges point from this node to them.

Mutating the returned collection does not modify `node`.
"""
parents(node::Node) = copy(node.parents)
|
|
|
|
"""
    siblings(node::Node)

Return a `Set{Node}` of every child of every parent of `node`.
Duplicates collapse in the set, and `node` itself is excluded.
"""
function siblings(node::Node)
    return Set{Node}(
        sibling for parent in parents(node) for sibling in children(parent)
        if sibling != node
    )
end
|
|
|
|
"""
    partners(node::Node)

Return a `Set{Node}` of every parent of every child of `node`.
Duplicates collapse in the set, and `node` itself is excluded.
"""
function partners(node::Node)
    return Set{Node}(
        partner for child in children(node) for partner in parents(child)
        if partner != node
    )
end
|
|
|
|
# an entry node has no children (no prerequisites); checks the field directly instead of
# copying it via `children(node)` just to measure its length
is_entry_node(node::Node) = isempty(node.children)
|
|
# an exit node has no parents (nothing depends on it); checks the field directly instead
# of copying it via `parents(node)` just to measure its length
is_exit_node(node::Node) = isempty(node.parents)
|
|
|
|
"""
    invalidate_caches!(graph::DAG, operation::Operation)

Invalidate `operation` everywhere it is cached. The function:
- removes it from `graph.possibleOperations`, and
- removes every occurrence of it from the `operations` cache of each node in
  `operation.input`.

Returns `nothing`.
"""
function invalidate_caches!(graph::DAG, operation::Operation)
    delete!(graph.possibleOperations, operation)

    # `operation.input` may be a tuple or a vector; both iterate the same way
    foreach(operation.input) do inputNode
        filter!(cached -> cached != operation, inputNode.operations)
    end

    return nothing
end
|
|
|
|
"""
    invalidate_caches!(graph::DAG, operation::NodeSplit)

Invalidate a node split everywhere it is cached. The function removes it from
`graph.possibleOperations` and from the `operations` cache of its single input node
(`operation.input`). Returns `nothing`.
"""
function invalidate_caches!(graph::DAG, operation::NodeSplit)
    delete!(graph.possibleOperations, operation)

    # a node split acts on exactly one node, stored directly in `input`
    splitNode = operation.input
    filter!(cached -> cached != operation, splitNode.operations)

    return nothing
end
|
|
|
|
# for graph mutating functions we need to do a few things
|
|
# 1: mutate the graph
|
|
# 2: keep track of what was changed for the diff (if track == true)
|
|
# 3: invalidate operation caches
|
|
|
|
"""
    insert_node!(graph::DAG, node::Node, track=true)

Add `node` to `graph.nodes` and return `node`. The function also:
- records `node` in `graph.diff.addedNodes` when `track` is true, and
- adds `node` to `graph.dirtyNodes` (the cache-invalidation step).
"""
function insert_node!(graph::DAG, node::Node, track=true)
    push!(graph.nodes, node)                        # 1: mutate
    track && push!(graph.diff.addedNodes, node)     # 2: keep track for the diff
    push!(graph.dirtyNodes, node)                   # 3: invalidate caches
    return node
end
|
|
|
|
"""
    insert_edge!(graph::DAG, edge::Edge, track=true)

Link the two endpoints of `edge` and return `edge`. The edge points from the child
`edge.edge[1]` to the parent `edge.edge[2]`. The function also:
- records `edge` in `graph.diff.addedEdges` when `track` is true,
- invalidates every cached operation of both endpoints, and
- marks both endpoints as dirty.
"""
function insert_edge!(graph::DAG, edge::Edge, track=true)
    childNode, parentNode = edge.edge[1], edge.edge[2]

    # 1: mutate
    push!(childNode.parents, parentNode)
    push!(parentNode.children, childNode)

    # 2: keep track
    if track
        push!(graph.diff.addedEdges, edge)
    end

    # 3: invalidate caches, child first, then parent
    # NOTE(review): each loop relies on invalidate_caches! removing the given operation
    # from this node's `operations`; otherwise it would never terminate
    for endpoint in (childNode, parentNode)
        while !isempty(endpoint.operations)
            invalidate_caches!(graph, first(endpoint.operations))
        end
    end
    push!(graph.dirtyNodes, childNode)
    push!(graph.dirtyNodes, parentNode)

    return edge
end
|
|
|
|
"""
    remove_node!(graph::DAG, node::Node, track=true)

Delete `node` from `graph.nodes`. The function also:
- records it in `graph.diff.removedNodes` when `track` is true,
- invalidates all operations cached on `node`, and
- drops `node` from `graph.dirtyNodes`.

Returns `nothing`. Edges touching `node` are not modified here.
"""
function remove_node!(graph::DAG, node::Node, track=true)
    delete!(graph.nodes, node)                          # 1: mutate
    track && push!(graph.diff.removedNodes, node)       # 2: keep track for the diff

    # 3: invalidate caches; nothing else needs invalidating since the node is gone
    while !isempty(node.operations)
        invalidate_caches!(graph, first(node.operations))
    end
    delete!(graph.dirtyNodes, node)

    return nothing
end
|
|
|
|
"""
    remove_edge!(graph::DAG, edge::Edge, track=true)

Unlink the child `edge.edge[1]` from the parent `edge.edge[2]`. The function also:
- records `edge` in `graph.diff.removedEdges` when `track` is true,
- invalidates every cached operation of both endpoints, and
- marks both endpoints as dirty.

Returns `nothing`.
"""
function remove_edge!(graph::DAG, edge::Edge, track=true)
    childNode, parentNode = edge.edge[1], edge.edge[2]

    # 1: mutate (every occurrence of the link is removed)
    filter!(n -> n != parentNode, childNode.parents)
    filter!(n -> n != childNode, parentNode.children)

    # 2: keep track
    if track
        push!(graph.diff.removedEdges, edge)
    end

    # 3: invalidate caches, child first, then parent
    for endpoint in (childNode, parentNode)
        while !isempty(endpoint.operations)
            invalidate_caches!(graph, first(endpoint.operations))
        end
    end
    push!(graph.dirtyNodes, childNode)
    push!(graph.dirtyNodes, parentNode)

    return nothing
end
|
|
|
|
"""
    get_snapshot_diff(graph::DAG)

Return the `Diff` accumulated in `graph` since this function was last called.
The graph's `diff` field is replaced by a new `Diff()`.
"""
function get_snapshot_diff(graph::DAG)
    snapshot = graph.diff
    graph.diff = Diff()
    return snapshot
end
|
|
|
|
"""
    graph_properties(graph::DAG)

Apply all pending operations to `graph` (via `apply_all!`), then return a named tuple
`(data, compute_effort, compute_intensity, edges)`. The fields are:
- `data`: the sum of `data(node.task)` weighted by each node's number of parents,
- `compute_effort`: the sum of `compute_effort(node.task)`,
- `compute_intensity`: `compute_effort / data`,
- `edges`: the total number of parent links.
"""
function graph_properties(graph::DAG)
    # make sure the graph is fully generated
    apply_all!(graph)

    totalData = 0
    totalEffort = 0
    edgeCount = 0
    for node in graph.nodes
        outDegree = length(node.parents)
        totalData += data(node.task) * outDegree
        totalEffort += compute_effort(node.task)
        edgeCount += outDegree
    end

    # NOTE: for integer totals `/` yields a Float64; with zero data transfer this is
    # Inf (or NaN when the compute effort is zero as well)
    intensity = totalEffort / totalData

    return (
        data = totalData,
        compute_effort = totalEffort,
        compute_intensity = intensity,
        edges = edgeCount,
    )
end
|
|
|
|
"""
    get_exit_node(graph::DAG)

Return the first node, in iteration order of `graph.nodes`, that has no parents.
Raise an error if no such node exists.
"""
function get_exit_node(graph::DAG)
    for candidate in graph.nodes
        is_exit_node(candidate) && return candidate
    end
    error("The given graph has no exit node! It is either empty or not acyclic!")
end
|
|
|
|
"""
    can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Return whether the chain `n1 -> n2 -> n3` may be fused. This requires that:
- `n1` is a child of `n2`,
- `n2` is a child of `n3`, and
- `n2` has exactly one parent and exactly one child.
"""
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    # the chain checks are redundant but maybe a good sanity check
    if !(is_child(n1, n2) && is_child(n2, n3))
        return false
    end

    return length(parents(n2)) == 1 && length(children(n2)) == 1
end
|
|
|
|
"""
    can_reduce(n1::Node, n2::Node)

Return whether `n1` and `n2` can be reduced into one node. This requires equal tasks
and the same set of children; child order is ignored.
"""
function can_reduce(n1::Node, n2::Node)
    n1.task == n2.task || return false
    return Set(children(n1)) == Set(children(n2))
end
|
|
|
|
"""
    can_split(n::Node)

Return whether `n` can be split, i.e. whether it has at least two parents.
"""
can_split(n::Node) = length(parents(n)) >= 2
|
|
|
|
"""
    is_valid(graph::DAG)

Check whether `graph` is connected. Every node must be reachable from the exit node
(as returned by `get_exit_node`) by repeatedly following `children` links. Raises an
error (from `get_exit_node`) if the graph has no exit node.
"""
function is_valid(graph::DAG)
    nodeStack = Node[get_exit_node(graph)]
    seenNodes = Set{Node}()

    while !isempty(nodeStack)
        current = pop!(nodeStack)

        # A node can be reachable along several paths. Expand each node only once: this
        # keeps the traversal linear in the graph size instead of in the number of
        # paths, and guarantees termination even if the graph contains a cycle.
        current in seenNodes && continue
        push!(seenNodes, current)

        for child in current.children
            child in seenNodes || push!(nodeStack, child)
        end
    end

    return length(seenNodes) == length(graph.nodes)
end
|
|
|
|
"""
    show_nodes(io, graph::DAG)

Print every node of `graph` to `io` as a comma-separated list enclosed in square
brackets, using `print` for each node.
"""
function show_nodes(io, graph::DAG)
    print(io, "[")
    join(io, graph.nodes, ", ")
    print(io, "]")
end
|
|
|
|
"""
    show(io::IO, graph::DAG)

Print a summary of `graph` to `io` containing:
- the nodes: listed individually for at most 20 nodes, otherwise the total plus a
  count per task type,
- the number of edges,
- the total compute effort, data transfer and compute intensity reported by
  `graph_properties`.

All output goes to `io`.

NOTE(review): `graph_properties` calls `apply_all!(graph)`, so displaying a graph
may modify it.
"""
function show(io::IO, graph::DAG)
    println(io, "Graph:")
    print(io, " Nodes: ")

    # count nodes per task type and edges (each edge is one entry in a node's parents)
    nodeDict = Dict{Type, Int64}()
    noEdges = 0
    for node in graph.nodes
        taskType = typeof(node.task)
        nodeDict[taskType] = get(nodeDict, taskType, 0) + 1
        noEdges += length(node.parents)
    end

    if length(graph.nodes) <= 20
        show_nodes(io, graph)
    else
        print(io, "Total: ", length(graph.nodes), ", ")
        for (i, (taskType, amount)) in enumerate(nodeDict)
            if i > 1
                print(io, ", ")
            end
            # wrap before every third entry to keep the line short
            if i % 3 == 0
                print(io, "\n ")
            end
            print(io, taskType, ": ", amount)
        end
    end
    println(io)
    println(io, " Edges: ", noEdges)

    properties = graph_properties(graph)
    println(io, " Total Compute Effort: ", properties.compute_effort)
    println(io, " Total Data Transfer: ", properties.data)
    println(io, " Total Compute Intensity: ", properties.compute_intensity)
end
|