From 4d1dc27f4f2e31439760091095822d39cabeeb22 Mon Sep 17 00:00:00 2001
From: Anton Reinhard <anton.reinhard@proton.me>
Date: Tue, 21 Nov 2023 21:32:40 +0100
Subject: [PATCH] Improve operation/optimization performance

---
 src/graph/compare.jl     | 16 ----------------
 src/graph/mute.jl        | 25 +++++++++++++++++++++----
 src/node/compare.jl      |  4 ++--
 src/operation/utility.jl |  9 ++++++++-
 4 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/src/graph/compare.jl b/src/graph/compare.jl
index 1adfa69..7b4f206 100644
--- a/src/graph/compare.jl
+++ b/src/graph/compare.jl
@@ -19,19 +19,3 @@ function in(edge::Edge, graph::DAG)
 
     return n1 in children(n2)
 end
-
-"""
-    ==(n1::Node, n2::Node, g::DAG)
-
-Check equality of two nodes in a graph.
-"""
-function ==(n1::Node, n2::Node, g::DAG)
-    if typeof(n1) != typeof(n2)
-        return false
-    end
-    if !(n1 in g) || !(n2 in g)
-        return false
-    end
-
-    return n1.task == n2.task && children(n1) == children(n2)
-end
diff --git a/src/graph/mute.jl b/src/graph/mute.jl
index 8d324b6..4c6e0af 100644
--- a/src/graph/mute.jl
+++ b/src/graph/mute.jl
@@ -124,8 +124,19 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
     pre_length2 = length(node2.children)
 
     #TODO: filter is very slow
-    filter!(x -> x != node2, node1.parents)
-    filter!(x -> x != node1, node2.children)
+    for i in eachindex(node1.parents)
+        if (node1.parents[i] == node2)
+            splice!(node1.parents, i)
+            break
+        end
+    end
+
+    for i in eachindex(node2.children)
+        if (node2.children[i] == node1)
+            splice!(node2.children, i)
+            break
+        end
+    end
 
     #=@assert begin
         removed = pre_length1 - length(node1.parents)
@@ -242,8 +253,14 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
 
     # delete the operation from all caches of nodes involved in the operation
     # TODO: filter is very slow
-    filter!(!=(operation), operation.input[1].nodeFusions)
-    filter!(!=(operation), operation.input[3].nodeFusions)
+    for n in [1, 3]
+        for i in eachindex(operation.input[n].nodeFusions)
+            if operation == operation.input[n].nodeFusions[i]
+                splice!(operation.input[n].nodeFusions, i)
+                break
+            end
+        end
+    end
 
     operation.input[2].nodeFusion = missing
 
diff --git a/src/node/compare.jl b/src/node/compare.jl
index d2e2c04..c84e0f4 100644
--- a/src/node/compare.jl
+++ b/src/node/compare.jl
@@ -21,7 +21,7 @@ end
 
 Equality comparison between two [`ComputeTaskNode`](@ref)s.
 """
-function ==(n1::ComputeTaskNode, n2::ComputeTaskNode)
+function ==(n1::ComputeTaskNode{TaskType}, n2::ComputeTaskNode{TaskType}) where {TaskType <: AbstractComputeTask}
     return n1.id == n2.id
 end
 
@@ -30,6 +30,6 @@ end
 
 Equality comparison between two [`DataTaskNode`](@ref)s.
 """
-function ==(n1::DataTaskNode, n2::DataTaskNode)
+function ==(n1::DataTaskNode{TaskType}, n2::DataTaskNode{TaskType}) where {TaskType <: AbstractDataTask}
     return n1.id == n2.id
 end
diff --git a/src/operation/utility.jl b/src/operation/utility.jl
index e07b3e8..0ccafb3 100644
--- a/src/operation/utility.jl
+++ b/src/operation/utility.jl
@@ -141,7 +141,14 @@ end
 
 Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
 """
-function ==(op1::NodeFusion, op2::NodeFusion)
+function ==(
+    op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
+    op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
+) where {
+    ComputeTaskType1 <: AbstractComputeTask,
+    DataTaskType <: AbstractDataTask,
+    ComputeTaskType2 <: AbstractComputeTask,
+}
     # there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
     return op1.input[2] == op2.input[2]
 end