From 4d1dc27f4f2e31439760091095822d39cabeeb22 Mon Sep 17 00:00:00 2001 From: Anton Reinhard <anton.reinhard@proton.me> Date: Tue, 21 Nov 2023 21:32:40 +0100 Subject: [PATCH] Improve operation/optimization performance --- src/graph/compare.jl | 16 ---------------- src/graph/mute.jl | 25 +++++++++++++++++++++---- src/node/compare.jl | 4 ++-- src/operation/utility.jl | 9 ++++++++- 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/graph/compare.jl b/src/graph/compare.jl index 1adfa69..7b4f206 100644 --- a/src/graph/compare.jl +++ b/src/graph/compare.jl @@ -19,19 +19,3 @@ function in(edge::Edge, graph::DAG) return n1 in children(n2) end - -""" - ==(n1::Node, n2::Node, g::DAG) - -Check equality of two nodes in a graph. -""" -function ==(n1::Node, n2::Node, g::DAG) - if typeof(n1) != typeof(n2) - return false - end - if !(n1 in g) || !(n2 in g) - return false - end - - return n1.task == n2.task && children(n1) == children(n2) -end diff --git a/src/graph/mute.jl b/src/graph/mute.jl index 8d324b6..4c6e0af 100644 --- a/src/graph/mute.jl +++ b/src/graph/mute.jl @@ -124,8 +124,19 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali pre_length2 = length(node2.children) #TODO: filter is very slow - filter!(x -> x != node2, node1.parents) - filter!(x -> x != node1, node2.children) + for i in eachindex(node1.parents) + if (node1.parents[i] == node2) + splice!(node1.parents, i) + break + end + end + + for i in eachindex(node2.children) + if (node2.children[i] == node1) + splice!(node2.children, i) + break + end + end #=@assert begin removed = pre_length1 - length(node1.parents) @@ -242,8 +253,14 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion) # delete the operation from all caches of nodes involved in the operation # TODO: filter is very slow - filter!(!=(operation), operation.input[1].nodeFusions) - filter!(!=(operation), operation.input[3].nodeFusions) + for n in [1, 3] + for i in eachindex(operation.input[n].nodeFusions) + if operation == operation.input[n].nodeFusions[i] + splice!(operation.input[n].nodeFusions, i) + break + end + end + end operation.input[2].nodeFusion = missing diff --git a/src/node/compare.jl b/src/node/compare.jl index d2e2c04..c84e0f4 100644 --- a/src/node/compare.jl +++ b/src/node/compare.jl @@ -21,7 +21,7 @@ end Equality comparison between two [`ComputeTaskNode`](@ref)s. """ -function ==(n1::ComputeTaskNode, n2::ComputeTaskNode) +function ==(n1::ComputeTaskNode{TaskType}, n2::ComputeTaskNode{TaskType}) where {TaskType <: AbstractComputeTask} return n1.id == n2.id end @@ -30,6 +30,6 @@ end Equality comparison between two [`DataTaskNode`](@ref)s. """ -function ==(n1::DataTaskNode, n2::DataTaskNode) +function ==(n1::DataTaskNode{TaskType}, n2::DataTaskNode{TaskType}) where {TaskType <: AbstractDataTask} return n1.id == n2.id end diff --git a/src/operation/utility.jl b/src/operation/utility.jl index e07b3e8..0ccafb3 100644 --- a/src/operation/utility.jl +++ b/src/operation/utility.jl @@ -141,7 +141,14 @@ end Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs. """ -function ==(op1::NodeFusion, op2::NodeFusion) +function ==( + op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2}, + op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2}, +) where { + ComputeTaskType1 <: AbstractComputeTask, + DataTaskType <: AbstractDataTask, + ComputeTaskType2 <: AbstractComputeTask, +} # there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same return op1.input[2] == op2.input[2] end