Optimizer interface and sample implementation (#19)

Reviewed-on: Rubydragon/MetagraphOptimization.jl#19
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
2023-11-22 13:51:54 +01:00
committed by Anton Reinhard
parent 16274919e4
commit b7560685d4
53 changed files with 639 additions and 331 deletions

View File

@ -61,7 +61,7 @@ function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
return false
end
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
return false
end
@ -74,12 +74,15 @@ end
Return whether the given two nodes can be reduced. See [`NodeReduction`](@ref) for the requirements.
"""
function can_reduce(n1::Node, n2::Node)
if (n1.task != n2.task)
return false
end
return false
end
n1_length = length(n1.children)
n2_length = length(n2.children)
function can_reduce(
n1::NodeType,
n2::NodeType,
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
n1_length = length(children(n1))
n2_length = length(children(n2))
if (n1_length != n2_length)
return false
@ -88,19 +91,19 @@ function can_reduce(n1::Node, n2::Node)
# this seems to be the most common case so do this first
# doing it manually is a lot faster than using the sets for a general solution
if (n1_length == 2)
if (n1.children[1] != n2.children[1])
if (n1.children[1] != n2.children[2])
if (children(n1)[1] != children(n2)[1])
if (children(n1)[1] != children(n2)[2])
return false
end
# 1_1 == 2_2
if (n1.children[2] != n2.children[1])
if (children(n1)[2] != children(n2)[1])
return false
end
return true
end
# 1_1 == 2_1
if (n1.children[2] != n2.children[2])
if (children(n1)[2] != children(n2)[2])
return false
end
return true
@ -108,11 +111,11 @@ function can_reduce(n1::Node, n2::Node)
# this is simple
if (n1_length == 1)
return n1.children[1] == n2.children[1]
return children(n1)[1] == children(n2)[1]
end
# this takes a long time
return Set(n1.children) == Set(n2.children)
return Set(children(n1)) == Set(children(n2))
end
"""
@ -138,7 +141,14 @@ end
Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
"""
function ==(op1::NodeFusion, op2::NodeFusion)
function ==(
op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
) where {
ComputeTaskType1 <: AbstractComputeTask,
DataTaskType <: AbstractDataTask,
ComputeTaskType2 <: AbstractComputeTask,
}
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
return op1.input[2] == op2.input[2]
end