Add unit tests, Fix operations and remaining failing tests
This commit is contained in:
@ -6,11 +6,11 @@ export make_node, make_edge, insert_node, insert_edge, is_entry_node, is_exit_no
|
||||
export NodeFusion, NodeReduction, NodeSplit, push_operation!, pop_operation!, can_pop, reset_graph!, get_operations
|
||||
export import_txt
|
||||
|
||||
export ==, in, show, isempty, delete!
|
||||
export ==, in, show, isempty, delete!, length
|
||||
|
||||
export bytes_to_human_readable
|
||||
|
||||
|
||||
import Base.length
|
||||
import Base.show
|
||||
import Base.==
|
||||
import Base.in
|
||||
|
@ -9,6 +9,12 @@ function isempty(operations::PossibleOperations)
|
||||
isempty(operations.nodeSplits)
|
||||
end
|
||||
|
||||
function length(operations::PossibleOperations)
|
||||
return (nodeFusions = length(operations.nodeFusions),
|
||||
nodeReductions = length(operations.nodeReductions),
|
||||
nodeSplits = length(operations.nodeSplits))
|
||||
end
|
||||
|
||||
function delete!(operations::PossibleOperations, op::NodeFusion)
|
||||
delete!(operations.nodeFusions, op)
|
||||
return operations
|
||||
@ -101,7 +107,7 @@ function invalidate_caches!(graph::DAG, operation::NodeSplit)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# for node split there is only one node
|
||||
filter!(!=(operation), operation.input.operations)
|
||||
filter!(x -> x != operation, operation.input.operations)
|
||||
|
||||
return nothing
|
||||
end
|
||||
@ -161,7 +167,6 @@ function remove_node!(graph::DAG, node::Node, track=true)
|
||||
invalidate_caches!(graph, first(node.operations))
|
||||
end
|
||||
delete!(graph.dirtyNodes, node)
|
||||
# no need to invalidate anything else, the node is gone afterwards anyways
|
||||
|
||||
return nothing
|
||||
end
|
||||
@ -184,8 +189,12 @@ function remove_edge!(graph::DAG, edge::Edge, track=true)
|
||||
while !isempty(node2.operations)
|
||||
invalidate_caches!(graph, first(node2.operations))
|
||||
end
|
||||
push!(graph.dirtyNodes, node1)
|
||||
push!(graph.dirtyNodes, node2)
|
||||
if (node1 in graph)
|
||||
push!(graph.dirtyNodes, node1)
|
||||
end
|
||||
if (node2 in graph)
|
||||
push!(graph.dirtyNodes, node2)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
@ -213,6 +222,7 @@ function graph_properties(graph::DAG)
|
||||
result = (data = d,
|
||||
compute_effort = ce,
|
||||
compute_intensity = ci,
|
||||
nodes = length(graph.nodes),
|
||||
edges = ed)
|
||||
return result
|
||||
end
|
||||
@ -256,7 +266,7 @@ function is_valid(graph::DAG)
|
||||
push!(nodeQueue, get_exit_node(graph))
|
||||
seenNodes = Set{Node}()
|
||||
|
||||
while ! isempty(nodeQueue)
|
||||
while !isempty(nodeQueue)
|
||||
current = pop!(nodeQueue)
|
||||
push!(seenNodes, current)
|
||||
|
||||
@ -324,3 +334,20 @@ function show(io::IO, graph::DAG)
|
||||
println(io, " Total Data Transfer: ", properties.data)
|
||||
println(io, " Total Compute Intensity: ", properties.compute_intensity)
|
||||
end
|
||||
|
||||
function show(io::IO, diff::Diff)
|
||||
print(io, "Nodes: ")
|
||||
print(io, length(diff.addedNodes) + length(diff.removedNodes))
|
||||
print(io, " Edges: ")
|
||||
print(io, length(diff.addedEdges) + length(diff.removedEdges))
|
||||
end
|
||||
|
||||
# return a namedtuple of the lengths of the added/removed nodes/edges
|
||||
function length(diff::Diff)
|
||||
return (
|
||||
addedNodes = length(diff.addedNodes),
|
||||
removedNodes = length(diff.removedNodes),
|
||||
addedEdges = length(diff.addedEdges),
|
||||
removedEdges = length(diff.removedEdges)
|
||||
)
|
||||
end
|
||||
|
@ -223,6 +223,10 @@ function node_split!(graph::DAG, n1::Node)
|
||||
for parent in n1_parents
|
||||
remove_edge!(graph, make_edge(n1, parent))
|
||||
end
|
||||
for child in n1_children
|
||||
remove_edge!(graph, make_edge(child, n1))
|
||||
end
|
||||
remove_node!(graph, n1)
|
||||
|
||||
for parent in n1_parents
|
||||
n_copy = copy(n1)
|
||||
@ -240,6 +244,9 @@ end
|
||||
# function to find node fusions involving the given node if it's a data node
|
||||
# pushes the found fusion everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
if !(node in graph)
|
||||
error("wot")
|
||||
end
|
||||
if length(parents(node)) != 1 || length(children(node)) != 1
|
||||
return nothing
|
||||
end
|
||||
@ -247,6 +254,10 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
child_node = first(children(node))
|
||||
parent_node = first(parents(node))
|
||||
|
||||
if !(child_node in graph) || !(parent_node in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end
|
||||
|
||||
nf = NodeFusion((child_node, node, parent_node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(child_node.operations, nf)
|
||||
@ -259,6 +270,9 @@ end
|
||||
# function to find node fusions involving the given node if it's a compute node
|
||||
# pushes the found fusion(s) everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
if !(node in graph)
|
||||
error("wot")
|
||||
end
|
||||
# for loop that always runs once for a scoped block we can break out of
|
||||
for _ in 1:1
|
||||
# assume this node as child of the chain
|
||||
@ -271,6 +285,11 @@ function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
end
|
||||
node3 = first(parents(node2))
|
||||
|
||||
|
||||
if !(node2 in graph) || !(node3 in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end
|
||||
|
||||
nf = NodeFusion((node, node2, node3))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node.operations, nf)
|
||||
@ -289,6 +308,10 @@ function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
end
|
||||
node1 = first(children(node2))
|
||||
|
||||
if !(node2 in graph) || !(node1 in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end
|
||||
|
||||
nf = NodeFusion((node1, node2, node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node1.operations, nf)
|
||||
|
@ -34,6 +34,17 @@ function ==(e1::Edge, e2::Edge)
|
||||
return e1.edge[1] == e2.edge[1] && e1.edge[2] == e2.edge[2]
|
||||
end
|
||||
|
||||
copy(id::Base.UUID) = Base.UUID(id.value)
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), copy(n.id), copy(n.operations))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), copy(n.id), copy(n.operations))
|
||||
function ==(n1::Node, n2::Node)
|
||||
return false
|
||||
end
|
||||
|
||||
function ==(n1::ComputeTaskNode, n2::ComputeTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
||||
function ==(n1::DataTaskNode, n2::DataTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng), copy(n.operations))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng), copy(n.operations))
|
||||
|
@ -1,9 +1,9 @@
|
||||
function bytes_to_human_readable(bytes)
|
||||
function bytes_to_human_readable(bytes::Int64)
|
||||
units = ["B", "KiB", "MiB", "GiB", "TiB"]
|
||||
unit_index = 1
|
||||
while bytes >= 1024 && unit_index < length(units)
|
||||
bytes /= 1024
|
||||
unit_index += 1
|
||||
end
|
||||
return string(round(bytes, digits=4), " ", units[unit_index])
|
||||
return string(round(bytes, sigdigits=4), " ", units[unit_index])
|
||||
end
|
||||
|
Reference in New Issue
Block a user