Add unit tests, Fix operations and remaining failing tests

This commit is contained in:
2023-08-17 18:46:57 +02:00
parent ae07b4cf80
commit 78f7fb2f05
12 changed files with 428 additions and 49 deletions

View File

@ -6,11 +6,11 @@ export make_node, make_edge, insert_node, insert_edge, is_entry_node, is_exit_no
export NodeFusion, NodeReduction, NodeSplit, push_operation!, pop_operation!, can_pop, reset_graph!, get_operations
export import_txt
export ==, in, show, isempty, delete!
export ==, in, show, isempty, delete!, length
export bytes_to_human_readable
import Base.length
import Base.show
import Base.==
import Base.in

View File

@ -9,6 +9,12 @@ function isempty(operations::PossibleOperations)
isempty(operations.nodeSplits)
end
function length(operations::PossibleOperations)
return (nodeFusions = length(operations.nodeFusions),
nodeReductions = length(operations.nodeReductions),
nodeSplits = length(operations.nodeSplits))
end
function delete!(operations::PossibleOperations, op::NodeFusion)
delete!(operations.nodeFusions, op)
return operations
@ -101,7 +107,7 @@ function invalidate_caches!(graph::DAG, operation::NodeSplit)
# delete the operation from all caches of nodes involved in the operation
# for node split there is only one node
filter!(!=(operation), operation.input.operations)
filter!(x -> x != operation, operation.input.operations)
return nothing
end
@ -161,7 +167,6 @@ function remove_node!(graph::DAG, node::Node, track=true)
invalidate_caches!(graph, first(node.operations))
end
delete!(graph.dirtyNodes, node)
# no need to invalidate anything else, the node is gone afterwards anyways
return nothing
end
@ -184,8 +189,12 @@ function remove_edge!(graph::DAG, edge::Edge, track=true)
while !isempty(node2.operations)
invalidate_caches!(graph, first(node2.operations))
end
push!(graph.dirtyNodes, node1)
push!(graph.dirtyNodes, node2)
if (node1 in graph)
push!(graph.dirtyNodes, node1)
end
if (node2 in graph)
push!(graph.dirtyNodes, node2)
end
return nothing
end
@ -213,6 +222,7 @@ function graph_properties(graph::DAG)
result = (data = d,
compute_effort = ce,
compute_intensity = ci,
nodes = length(graph.nodes),
edges = ed)
return result
end
@ -256,7 +266,7 @@ function is_valid(graph::DAG)
push!(nodeQueue, get_exit_node(graph))
seenNodes = Set{Node}()
while ! isempty(nodeQueue)
while !isempty(nodeQueue)
current = pop!(nodeQueue)
push!(seenNodes, current)
@ -324,3 +334,20 @@ function show(io::IO, graph::DAG)
println(io, " Total Data Transfer: ", properties.data)
println(io, " Total Compute Intensity: ", properties.compute_intensity)
end
function show(io::IO, diff::Diff)
print(io, "Nodes: ")
print(io, length(diff.addedNodes) + length(diff.removedNodes))
print(io, " Edges: ")
print(io, length(diff.addedEdges) + length(diff.removedEdges))
end
# return a namedtuple of the lengths of the added/removed nodes/edges
function length(diff::Diff)
return (
addedNodes = length(diff.addedNodes),
removedNodes = length(diff.removedNodes),
addedEdges = length(diff.addedEdges),
removedEdges = length(diff.removedEdges)
)
end

View File

@ -223,6 +223,10 @@ function node_split!(graph::DAG, n1::Node)
for parent in n1_parents
remove_edge!(graph, make_edge(n1, parent))
end
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
end
remove_node!(graph, n1)
for parent in n1_parents
n_copy = copy(n1)
@ -240,6 +244,9 @@ end
# function to find node fusions involving the given node if it's a data node
# pushes the found fusion everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::DataTaskNode)
if !(node in graph)
error("wot")
end
if length(parents(node)) != 1 || length(children(node)) != 1
return nothing
end
@ -247,6 +254,10 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
child_node = first(children(node))
parent_node = first(parents(node))
if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end
nf = NodeFusion((child_node, node, parent_node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(child_node.operations, nf)
@ -259,6 +270,9 @@ end
# function to find node fusions involving the given node if it's a compute node
# pushes the found fusion(s) everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::ComputeTaskNode)
if !(node in graph)
error("wot")
end
# for loop that always runs once for a scoped block we can break out of
for _ in 1:1
# assume this node as child of the chain
@ -271,6 +285,11 @@ function find_fusions!(graph::DAG, node::ComputeTaskNode)
end
node3 = first(parents(node2))
if !(node2 in graph) || !(node3 in graph)
error("Parents/Children that are not in the graph!!!")
end
nf = NodeFusion((node, node2, node3))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node.operations, nf)
@ -289,6 +308,10 @@ function find_fusions!(graph::DAG, node::ComputeTaskNode)
end
node1 = first(children(node2))
if !(node2 in graph) || !(node1 in graph)
error("Parents/Children that are not in the graph!!!")
end
nf = NodeFusion((node1, node2, node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node1.operations, nf)

View File

@ -34,6 +34,17 @@ function ==(e1::Edge, e2::Edge)
return e1.edge[1] == e2.edge[1] && e1.edge[2] == e2.edge[2]
end
copy(id::Base.UUID) = Base.UUID(id.value)
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), copy(n.id), copy(n.operations))
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), copy(n.id), copy(n.operations))
function ==(n1::Node, n2::Node)
return false
end
function ==(n1::ComputeTaskNode, n2::ComputeTaskNode)
return n1.id == n2.id
end
function ==(n1::DataTaskNode, n2::DataTaskNode)
return n1.id == n2.id
end
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng), copy(n.operations))
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng), copy(n.operations))

View File

@ -1,9 +1,9 @@
function bytes_to_human_readable(bytes)
function bytes_to_human_readable(bytes::Int64)
units = ["B", "KiB", "MiB", "GiB", "TiB"]
unit_index = 1
while bytes >= 1024 && unit_index < length(units)
bytes /= 1024
unit_index += 1
end
return string(round(bytes, digits=4), " ", units[unit_index])
return string(round(bytes, sigdigits=4), " ", units[unit_index])
end