Improve operation/optimization performance
This commit is contained in:
		| @@ -19,19 +19,3 @@ function in(edge::Edge, graph::DAG) | ||||
|  | ||||
|     return n1 in children(n2) | ||||
| end | ||||
|  | ||||
| """ | ||||
|     ==(n1::Node, n2::Node, g::DAG) | ||||
|  | ||||
| Check equality of two nodes in a graph. | ||||
| """ | ||||
| function ==(n1::Node, n2::Node, g::DAG) | ||||
|     if typeof(n1) != typeof(n2) | ||||
|         return false | ||||
|     end | ||||
|     if !(n1 in g) || !(n2 in g) | ||||
|         return false | ||||
|     end | ||||
|  | ||||
|     return n1.task == n2.task && children(n1) == children(n2) | ||||
| end | ||||
|   | ||||
| @@ -124,8 +124,19 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali | ||||
|     pre_length2 = length(node2.children) | ||||
|  | ||||
|     #TODO: filter is very slow | ||||
|     filter!(x -> x != node2, node1.parents) | ||||
|     filter!(x -> x != node1, node2.children) | ||||
|     for i in eachindex(node1.parents) | ||||
|         if (node1.parents[i] == node2) | ||||
|             splice!(node1.parents, i) | ||||
|             break | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     for i in eachindex(node2.children) | ||||
|         if (node2.children[i] == node1) | ||||
|             splice!(node2.children, i) | ||||
|             break | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     #=@assert begin | ||||
|         removed = pre_length1 - length(node1.parents) | ||||
| @@ -242,8 +253,14 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion) | ||||
|  | ||||
|     # delete the operation from all caches of nodes involved in the operation | ||||
|     # TODO: filter is very slow | ||||
|     filter!(!=(operation), operation.input[1].nodeFusions) | ||||
|     filter!(!=(operation), operation.input[3].nodeFusions) | ||||
|     for n in [1, 3] | ||||
|         for i in eachindex(operation.input[n].nodeFusions) | ||||
|             if operation == operation.input[n].nodeFusions[i] | ||||
|                 splice!(operation.input[n].nodeFusions, i) | ||||
|                 break | ||||
|             end | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     operation.input[2].nodeFusion = missing | ||||
|  | ||||
|   | ||||
| @@ -21,7 +21,7 @@ end | ||||
|  | ||||
| Equality comparison between two [`ComputeTaskNode`](@ref)s. | ||||
| """ | ||||
| function ==(n1::ComputeTaskNode, n2::ComputeTaskNode) | ||||
| function ==(n1::ComputeTaskNode{TaskType}, n2::ComputeTaskNode{TaskType}) where {TaskType <: AbstractComputeTask} | ||||
|     return n1.id == n2.id | ||||
| end | ||||
|  | ||||
| @@ -30,6 +30,6 @@ end | ||||
|  | ||||
| Equality comparison between two [`DataTaskNode`](@ref)s. | ||||
| """ | ||||
| function ==(n1::DataTaskNode, n2::DataTaskNode) | ||||
| function ==(n1::DataTaskNode{TaskType}, n2::DataTaskNode{TaskType}) where {TaskType <: AbstractDataTask} | ||||
|     return n1.id == n2.id | ||||
| end | ||||
|   | ||||
| @@ -141,7 +141,14 @@ end | ||||
|  | ||||
| Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs. | ||||
| """ | ||||
| function ==(op1::NodeFusion, op2::NodeFusion) | ||||
| function ==( | ||||
|     op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2}, | ||||
|     op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2}, | ||||
| ) where { | ||||
|     ComputeTaskType1 <: AbstractComputeTask, | ||||
|     DataTaskType <: AbstractDataTask, | ||||
|     ComputeTaskType2 <: AbstractComputeTask, | ||||
| } | ||||
|     # there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same | ||||
|     return op1.input[2] == op2.input[2] | ||||
| end | ||||
|   | ||||
		Reference in New Issue
	
	Block a user