From 6ee444b46ff190d2665ebe751b3cf9c85db12762 Mon Sep 17 00:00:00 2001
From: Anton Reinhard <anton.reinhard@proton.me>
Date: Fri, 18 Aug 2023 11:47:12 +0200
Subject: [PATCH] Remove double edge insertions

---
 examples/Project.toml           |  1 +
 examples/profiling_utilities.jl | 36 ++++++++++++++++++++++++++++++
 src/graph_functions.jl          | 16 ++++++--------
 src/graph_operations.jl         | 39 +++++++++++++++++++++++----------
 4 files changed, 71 insertions(+), 21 deletions(-)
 create mode 100644 examples/profiling_utilities.jl

diff --git a/examples/Project.toml b/examples/Project.toml
index a1e4baf..74721e2 100644
--- a/examples/Project.toml
+++ b/examples/Project.toml
@@ -2,3 +2,4 @@
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"
diff --git a/examples/profiling_utilities.jl b/examples/profiling_utilities.jl
new file mode 100644
index 0000000..e279f76
--- /dev/null
+++ b/examples/profiling_utilities.jl
@@ -0,0 +1,36 @@
+
+function test_random_walk(g::DAG, n::Int64)
+    # the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
+    reset_graph!(g)
+
+    properties = graph_properties(g)
+
+    for i = 1:n
+        # choose push or pop
+        if rand(Bool)
+            # push
+            opt = get_operations(g)
+
+            # choose one of fuse/split/reduce
+            option = rand(1:3)
+            if option == 1 && !isempty(opt.nodeFusions)
+                push_operation!(g, rand(collect(opt.nodeFusions)))
+            elseif option == 2 && !isempty(opt.nodeReductions)
+                push_operation!(g, rand(collect(opt.nodeReductions)))
+            elseif option == 3 && !isempty(opt.nodeSplits)
+                push_operation!(g, rand(collect(opt.nodeSplits)))
+            else
+                i = i - 1
+            end
+        else
+            # pop
+            if (can_pop(g))
+                pop_operation!(g)
+            else
+                i = i - 1
+            end
+        end
+    end
+
+    reset_graph!(g)
+end
\ No newline at end of file
diff --git a/src/graph_functions.jl b/src/graph_functions.jl
index 426d97b..5762ba6 100644
--- a/src/graph_functions.jl
+++ b/src/graph_functions.jl
@@ -135,13 +135,12 @@ function insert_edge!(graph::DAG, edge::Edge, track=true)
    node2 = edge.edge[2]
 
    # 1: mute
-   if (node2 in node1.parents) || (node1 in node2.children)
-      #=if !(node2 in node1.parents && node1 in node2.children)
+   #=if (node2 in node1.parents) || (node1 in node2.children)
+      if !(node2 in node1.parents && node1 in node2.children)
          error("One-sided edge")
-      end=#
-
-      return edge
-   end
+      end
+      error("Edge to insert already exists")
+   end=#
 
    # edge points from child to parent
    push!(node1.parents, node2)
@@ -276,7 +275,7 @@ function can_reduce(n1::Node, n2::Node)
    if (n1.task != n2.task)
       return false
    end
-   return Set(children(n1)) == Set(children(n2))
+   return Set(n1.children) == Set(n2.children)
 end
 
 function can_split(n::Node)
@@ -293,8 +292,7 @@ function is_valid(graph::DAG)
       current = pop!(nodeQueue)
       push!(seenNodes, current)
 
-      childrenNodes = children(current)
-      for child in childrenNodes
+      for child in current.chlidren
          push!(nodeQueue, child)
       end
    end
diff --git a/src/graph_operations.jl b/src/graph_operations.jl
index d660108..496a5ab 100644
--- a/src/graph_operations.jl
+++ b/src/graph_operations.jl
@@ -155,15 +155,22 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
    new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
    insert_node!(graph, new_node)
    
-   # "repoint" children of n1 to the new node
+   # use a set for combined children of n1 and n3 to not get duplicates
+   n1and3_children = Set{Node}()
+
+   # remove edges from n1 children to n1
    for child in n1_children
       remove_edge!(graph, make_edge(child, n1))
-      insert_edge!(graph, make_edge(child, new_node))
+      push!(n1and3_children, child)
    end
 
-   # "repoint" children of n3 to the new node
+   # remove edges from n3 children to n3
    for child in n3_children
       remove_edge!(graph, make_edge(child, n3))
+      push!(n1and3_children, child)
+   end
+
+   for child in n1and3_children
       insert_edge!(graph, make_edge(child, new_node))
    end
 
@@ -190,9 +197,9 @@ function node_reduction!(graph::DAG, n1::Node, n2::Node)
 
    # save n2 parents and children
    n2_children = children(n2)
-   n2_parents = parents(n2)
+   n2_parents = Set(n2.parents)
 
-   #=if Set(n2_children) != Set(children(n1))
+   #=if Set(n2_children) != Set(n1.children)
       error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
    end=#
 
@@ -200,12 +207,22 @@ function node_reduction!(graph::DAG, n1::Node, n2::Node)
    for child in n2_children
       remove_edge!(graph, make_edge(child, n2))
    end
+
+
    for parent in n2_parents
       remove_edge!(graph, make_edge(n2, parent))
+   end
 
-      # add parents of n2 to n1
+   for parent in n1.parents
+      # delete parents in n1 that already exist in n2
+      delete!(n2_parents, parent)
+   end
+
+   for parent in n2_parents
+      # now add parents of n2 to n1 without duplicates
       insert_edge!(graph, make_edge(n1, parent))
    end
+
    remove_node!(graph, n2)
 
    return get_snapshot_diff(graph)
@@ -355,12 +372,10 @@ function find_reductions!(graph::DAG, node::Node)
 end
 
 function find_splits!(graph::DAG, node::Node)
-   for node in graph.nodes
-      if (can_split(node))
-         ns = NodeSplit(node)
-         push!(graph.possibleOperations.nodeSplits, ns)
-         push!(node.operations, ns)
-      end
+   if (can_split(node))
+      ns = NodeSplit(node)
+      push!(graph.possibleOperations.nodeSplits, ns)
+      push!(node.operations, ns)
    end
 
    return nothing