Improve operation/optimization performance

This commit is contained in:
Anton Reinhard 2023-11-21 21:32:40 +01:00
parent 968f6856de
commit 4d1dc27f4f
4 changed files with 31 additions and 23 deletions

View File

@ -19,19 +19,3 @@ function in(edge::Edge, graph::DAG)
return n1 in children(n2)
end
"""
==(n1::Node, n2::Node, g::DAG)
Check equality of two nodes in a graph.
"""
function ==(n1::Node, n2::Node, g::DAG)
if typeof(n1) != typeof(n2)
return false
end
if !(n1 in g) || !(n2 in g)
return false
end
return n1.task == n2.task && children(n1) == children(n2)
end

View File

@ -124,8 +124,19 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
pre_length2 = length(node2.children)
#TODO: filter is very slow
filter!(x -> x != node2, node1.parents)
filter!(x -> x != node1, node2.children)
for i in eachindex(node1.parents)
if (node1.parents[i] == node2)
splice!(node1.parents, i)
break
end
end
for i in eachindex(node2.children)
if (node2.children[i] == node1)
splice!(node2.children, i)
break
end
end
#=@assert begin
removed = pre_length1 - length(node1.parents)
@ -242,8 +253,14 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
# delete the operation from all caches of nodes involved in the operation
# TODO: filter is very slow
filter!(!=(operation), operation.input[1].nodeFusions)
filter!(!=(operation), operation.input[3].nodeFusions)
for n in [1, 3]
for i in eachindex(operation.input[n].nodeFusions)
if operation == operation.input[n].nodeFusions[i]
splice!(operation.input[n].nodeFusions, i)
break
end
end
end
operation.input[2].nodeFusion = missing

View File

@ -21,7 +21,7 @@ end
Equality comparison between two [`ComputeTaskNode`](@ref)s.
"""
function ==(n1::ComputeTaskNode, n2::ComputeTaskNode)
function ==(n1::ComputeTaskNode{TaskType}, n2::ComputeTaskNode{TaskType}) where {TaskType <: AbstractComputeTask}
return n1.id == n2.id
end
@ -30,6 +30,6 @@ end
Equality comparison between two [`DataTaskNode`](@ref)s.
"""
function ==(n1::DataTaskNode, n2::DataTaskNode)
function ==(n1::DataTaskNode{TaskType}, n2::DataTaskNode{TaskType}) where {TaskType <: AbstractDataTask}
return n1.id == n2.id
end

View File

@ -141,7 +141,14 @@ end
Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
"""
function ==(op1::NodeFusion, op2::NodeFusion)
function ==(
op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
) where {
ComputeTaskType1 <: AbstractComputeTask,
DataTaskType <: AbstractDataTask,
ComputeTaskType2 <: AbstractComputeTask,
}
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
return op1.input[2] == op2.input[2]
end