diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl index 28a1fbb..039f233 100644 --- a/src/MetagraphOptimization.jl +++ b/src/MetagraphOptimization.jl @@ -2,6 +2,7 @@ module MetagraphOptimization import Base.show import Base.== import Base.in +import Base.copy include("tasks.jl") include("nodes.jl") @@ -16,7 +17,7 @@ include("utility.jl") export Node, Edge, ComputeTaskNode, DataTaskNode, DAG export AbstractTask, AbstractComputeTask, AbstractDataTask, DataTask, ComputeTaskP, ComputeTaskS1, ComputeTaskS2, ComputeTaskV, ComputeTaskU, ComputeTaskSum, FusedComputeTask export make_node, make_edge, insert_node, insert_edge, is_entry_node, is_exit_node, parents, children, compute, graph_properties, get_exit_node, is_valid -export NodeFusion, NodeReduction, NodeSplit, push_operation!, pop_operation!, generate_options +export NodeFusion, NodeReduction, NodeSplit, push_operation!, pop_operation!, can_pop, reset_graph!, generate_options export import_txt export ==, in, show diff --git a/src/graph_functions.jl b/src/graph_functions.jl index 6e90894..5b853ae 100644 --- a/src/graph_functions.jl +++ b/src/graph_functions.jl @@ -75,7 +75,7 @@ function insert_edge!(graph::DAG, edge::Edge, track=true) push!(edge.edge[1].parents, edge.edge[2]) push!(edge.edge[2].children, edge.edge[1]) - if (track) push!(graph.diff.addedEdges) end + if (track) push!(graph.diff.addedEdges, edge) end return edge end @@ -97,7 +97,6 @@ end # return the graph "difference" since last time this function was called function get_snapshot_diff(graph::DAG) return swapfield!(graph, :diff, Diff()) - return result end function graph_properties(graph::DAG) diff --git a/src/graph_operations.jl b/src/graph_operations.jl index 1f14a4c..9e69ff9 100644 --- a/src/graph_operations.jl +++ b/src/graph_operations.jl @@ -31,6 +31,15 @@ function pop_operation!(graph::DAG) end +can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations) + +# reset the graph to its initial state with no operations applied +function reset_graph!(graph::DAG) + while (can_pop(graph)) + pop_operation!(graph) + end +end + # implementation detail functions, don't export # applies all unapplied operations in the DAG @@ -63,7 +72,7 @@ function apply_operation!(graph::DAG, operation::NodeReduction) end function apply_operation!(graph::DAG, operation::NodeSplit) - diff = node_split!(graph, operation.input[1]) + diff = node_split!(graph, operation.input) return AppliedNodeSplit(operation, diff) end diff --git a/src/node_functions.jl b/src/node_functions.jl index d907722..ac52206 100644 --- a/src/node_functions.jl +++ b/src/node_functions.jl @@ -33,3 +33,7 @@ end function ==(e1::Edge, e2::Edge) return e1.edge[1] == e2.edge[1] && e1.edge[2] == e2.edge[2] end + +copy(id::Base.UUID) = Base.UUID(id.value) +copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), copy(n.id)) +copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), copy(n.id)) diff --git a/src/nodes.jl b/src/nodes.jl index 3509217..1e529db 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -13,6 +13,7 @@ struct DataTaskNode <: Node children::Vector{Node} # need a unique identifier unique to every *constructed* node + # however, it can be copied when splitting a node id::Base.UUID end @@ -24,13 +25,8 @@ struct ComputeTaskNode <: Node id::Base.UUID end -function DataTaskNode(t::AbstractDataTask) - return DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng)) -end - -function ComputeTaskNode(t::AbstractComputeTask) - return ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng)) -end +DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng)) +ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng)) struct Edge # edge points from child to parent diff --git a/src/task_functions.jl b/src/task_functions.jl index 7d94b00..de8c8f0 100644 --- a/src/task_functions.jl +++ b/src/task_functions.jl @@ -67,3 +67,6 @@ end function ==(t1::AbstractDataTask, t2::AbstractDataTask) return data(t1) == data(t2) end + +copy(t::DataTask) = DataTask(t.data) +copy(t::AbstractComputeTask) = typeof(t)() diff --git a/test/known_graphs.jl b/test/known_graphs.jl index c17a27a..ac57dbf 100644 --- a/test/known_graphs.jl +++ b/test/known_graphs.jl @@ -1,6 +1,9 @@ using MetagraphOptimization +using Random using Test +Random.seed!(0) + function test_known_graphs() g_ABAB = import_txt(joinpath(@__DIR__, "..", "examples", "AB->AB.txt")) props = graph_properties(g_ABAB) @@ -12,6 +15,7 @@ function test_known_graphs() @test length(generate_options(g_ABAB).nodeFusions) == 10 test_node_fusion(g_ABAB) + test_random_walk(g_ABAB, 100) g_ABAB3 = import_txt(joinpath(@__DIR__, "..", "examples", "AB->ABBB.txt")) @@ -22,9 +26,10 @@ function test_known_graphs() @test props.data == 828 test_node_fusion(g_ABAB3) + test_random_walk(g_ABAB3, 1000) end -function test_node_fusion(g) +function test_node_fusion(g::DAG) props = graph_properties(g) options = generate_options(g) @@ -50,6 +55,47 @@ function test_node_fusion(g) options = generate_options(g) end - - print(g) +end + +function test_random_walk(g::DAG, n::Int64) + # the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge + reset_graph!(g) + + properties = graph_properties(g) + + println("Random Walking... ") + + for i = 1:n + print("\r", i) + # choose push or pop + if rand(Bool) + # push + opt = generate_options(g) + + # choose one of fuse/split/reduce + option = rand(1:3) + if option == 1 && !isempty(opt.nodeFusions) + push_operation!(g, opt.nodeFusions[rand(1 : length(opt.nodeFusions))]) + elseif option == 2 && !isempty(opt.nodeReductions) + push_operation!(g, opt.nodeReductions[rand(1 : length(opt.nodeReductions))]) + elseif option == 3 && !isempty(opt.nodeSplits) + push_operation!(g, opt.nodeSplits[rand(1 : length(opt.nodeSplits))]) + else + i = i-1 + end + else + # pop + if (can_pop(g)) + pop_operation!(g) + else + i = i-1 + end + end + end + + println("\rDone.") + + reset_graph!(g) + + @test properties == graph_properties(g) end