diff --git a/examples/import_bench.jl b/examples/import_bench.jl index d199d28..b9ccd80 100644 --- a/examples/import_bench.jl +++ b/examples/import_bench.jl @@ -16,7 +16,7 @@ function bench_txt(filepath::String, bench::Bool = true) println(name, ":") g = parse_abc(filepath) print(g) - #println(" Graph size in memory: ", bytes_to_human_readable(Base.summarysize(g))) + println(" Graph size in memory: ", bytes_to_human_readable(MetagraphOptimization.mem(g))) if (bench) @btime parse_abc($filepath) diff --git a/src/graph_functions.jl b/src/graph_functions.jl index 58aac5b..b95ed93 100644 --- a/src/graph_functions.jl +++ b/src/graph_functions.jl @@ -263,7 +263,7 @@ function get_exit_node(graph::DAG) end # check whether the given graph is connected -function is_valid(graph::DAG) +function is_connected(graph::DAG) nodeQueue = Deque{Node}() push!(nodeQueue, get_exit_node(graph)) seenNodes = Set{Node}() @@ -272,7 +272,7 @@ function is_valid(graph::DAG) current = pop!(nodeQueue) push!(seenNodes, current) - for child in current.chlidren + for child in current.children push!(nodeQueue, child) end end @@ -352,3 +352,38 @@ function length(diff::Diff) removedEdges = length(diff.removedEdges) ) end + +function is_valid(graph::DAG) + for node in graph.nodes + @assert is_valid(graph, node) + end + + for op in graph.operationsToApply + @assert is_valid(graph, op) + end + + for nr in graph.possibleOperations.nodeReductions + @assert is_valid(graph, nr) + end + for ns in graph.possibleOperations.nodeSplits + @assert is_valid(graph, ns) + end + for nf in graph.possibleOperations.nodeFusions + @assert is_valid(graph, nf) + end + + for node in graph.dirtyNodes + @assert node in graph "Dirty Node is not part of the graph!" + @assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!" + @assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!" + if (typeof(node) <: DataTaskNode) + @assert ismissing(node.nodeFusion) "Dirty DataTaskNode has a Node Fusion!" + elseif (typeof(node) <: ComputeTaskNode) + @assert isempty(node.nodeFusions) "Dirty ComputeTaskNode has Node Fusions!" + end + end + + @assert is_connected(graph) "Graph is not connected!" + + return true +end diff --git a/src/node_functions.jl b/src/node_functions.jl index 1d53f54..cccde82 100644 --- a/src/node_functions.jl +++ b/src/node_functions.jl @@ -46,6 +46,50 @@ function ==(n1::DataTaskNode, n2::DataTaskNode) return n1.id == n2.id end +function is_valid_node(graph::DAG, node::Node) + @assert node in graph "Node is not part of the given graph!" + + for parent in node.parents + @assert typeof(parent) != typeof(node) "Node's type is the same as its parent's!" + @assert parent in graph "Node's parent is not in the same graph!" + @assert node in parent.children "Node is not a child of its parent!" + end + + for child in node.children + @assert typeof(child) != typeof(node) "Node's type is the same as its child's!" + @assert child in graph "Node's child is not in the same graph!" + @assert node in child.parents "Node is not a parent of its child!" + end + + if !ismissing(node.nodeReduction) + @assert is_valid(graph, node.nodeReduction) + end + if !ismissing(node.nodeSplit) + @assert is_valid(graph, node.nodeSplit) + end + return true +end + +# call with @assert +function is_valid(graph::DAG, node::ComputeTaskNode) + @assert is_valid_node(graph, node) + + for nf in node.nodeFusions + @assert is_valid(graph, nf) + end + return true +end + +# call with @assert +function is_valid(graph::DAG, node::DataTaskNode) + @assert is_valid_node(graph, node) + + if !ismissing(node.nodeFusion) + @assert is_valid(graph, node.nodeFusion) + end + return true +end + copy(m::Missing) = missing copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusions)) copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusion)) diff --git a/src/operations/validate.jl b/src/operations/validate.jl index 95f9929..eaed924 100644 --- a/src/operations/validate.jl +++ b/src/operations/validate.jl @@ -59,3 +59,21 @@ function is_valid_node_split_input(graph::DAG, n1::Node) return true end + +function is_valid(graph::DAG, nr::NodeReduction) + @assert is_valid_node_reduction_input(graph, nr.input) + @assert nr in graph.possibleOperations.nodeReductions "NodeReduction is not part of the graph's possible operations!" + return true +end + +function is_valid(graph::DAG, ns::NodeSplit) + @assert is_valid_node_split_input(graph, ns.input) + @assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!" + return true +end + +function is_valid(graph::DAG, nf::NodeFusion) + @assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3]) + @assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!" + return true +end diff --git a/src/utility.jl b/src/utility.jl index a163361..afa5cf8 100644 --- a/src/utility.jl +++ b/src/utility.jl @@ -1,4 +1,4 @@ -function bytes_to_human_readable(bytes::Int64) +function bytes_to_human_readable(bytes) units = ["B", "KiB", "MiB", "GiB", "TiB"] unit_index = 1 while bytes >= 1024 && unit_index < length(units) @@ -16,3 +16,38 @@ function sort_node!(node::Node) sort!(node.children, lt=lt_nodes) sort!(node.parents, lt=lt_nodes) end + +function mem(graph::DAG) + size = 0 + size += Base.summarysize(graph.nodes, exclude=Union{Node}) + for n in graph.nodes + size += mem(n) + end + + size += sizeof(graph.appliedOperations) + size += sizeof(graph.operationsToApply) + + size += sizeof(graph.possibleOperations) + for op in graph.possibleOperations.nodeFusions + size += mem(op) + end + for op in graph.possibleOperations.nodeReductions + size += mem(op) + end + for op in graph.possibleOperations.nodeSplits + size += mem(op) + end + + size += Base.summarysize(graph.dirtyNodes, exclude=Union{Node}) + size += sizeof(diff) +end + +# calculate the size of this operation in Byte +function mem(op::Operation) + return Base.summarysize(op, exclude=Union{Node}) +end + +# calculate the size of this node in Byte +function mem(node::Node) + return Base.summarysize(node, exclude=Union{Node, Operation}) +end diff --git a/test/known_graphs.jl b/test/known_graphs.jl index b4cb310..336556f 100644 --- a/test/known_graphs.jl +++ b/test/known_graphs.jl @@ -47,6 +47,8 @@ function test_random_walk(g::DAG, n::Int64) # the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge reset_graph!(g) + @test is_valid(g) + properties = graph_properties(g) for i = 1:n @@ -78,6 +80,8 @@ function test_random_walk(g::DAG, n::Int64) reset_graph!(g) + @test is_valid(g) + @test properties == graph_properties(g) end end diff --git a/test/node_reduction.jl b/test/node_reduction.jl index 646035f..7c493bf 100644 --- a/test/node_reduction.jl +++ b/test/node_reduction.jl @@ -51,6 +51,8 @@ import MetagraphOptimization.make_node insert_edge!(graph, BD, B1C_2, false) insert_edge!(graph, CD, C1C, false) + @test is_valid(graph) + @test is_exit_node(d_exit) @test is_entry_node(AD) @test is_entry_node(BD) @@ -74,6 +76,8 @@ import MetagraphOptimization.make_node @test Set(nr.input) == Set([B1D_1, B1D_2]) push_operation!(graph, nr) opt = get_operations(graph) + + @test is_valid(graph) @test length(opt) == (nodeFusions = 4, nodeReductions = 0, nodeSplits = 1) #println("After 2 Node Reductions:\n", opt) @@ -89,5 +93,7 @@ import MetagraphOptimization.make_node opt = get_operations(graph) @test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1) #println("After reverting to the initial state:\n", opt) + + @test is_valid(graph) end println("Node Reduction Unit Tests Complete!") diff --git a/test/unit_tests_graph.jl b/test/unit_tests_graph.jl index 20dd5fa..784284e 100644 --- a/test/unit_tests_graph.jl +++ b/test/unit_tests_graph.jl @@ -107,6 +107,8 @@ import MetagraphOptimization.partners @test length(graph.dirtyNodes) == 26 @test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0) + @test is_valid(graph) + @test is_entry_node(d_PB) @test is_entry_node(d_PA) @test is_entry_node(d_PBp) @@ -204,5 +206,7 @@ import MetagraphOptimization.partners operations = get_operations(graph) @test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0) + + @test is_valid(graph) end println("Graph Unit Tests Complete!")