Optimizer interface and sample implementation (#19)
Reviewed-on: Rubydragon/MetagraphOptimization.jl#19 Co-authored-by: Anton Reinhard <anton.reinhard@proton.me> Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
@@ -2,7 +2,72 @@
|
||||
GreedyOptimizer
|
||||
|
||||
An implementation of the greedy optimization algorithm, simply choosing the best next option evaluated with the given estimator.
|
||||
|
||||
The fixpoint is reached when any leftover operation would increase the graph's total cost according to the given estimator.
|
||||
"""
|
||||
struct GreedyOptimizer
|
||||
estimator::AbstractEstimator
|
||||
struct GreedyOptimizer{EstimatorType <: AbstractEstimator} <: AbstractOptimizer
|
||||
estimator::EstimatorType
|
||||
end
|
||||
|
||||
function optimize_step!(optimizer::GreedyOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if isempty(operations)
|
||||
return false
|
||||
end
|
||||
|
||||
result = nothing
|
||||
|
||||
lowestCost = reduce(
|
||||
(acc, op) -> begin
|
||||
op_cost = operation_effect(optimizer.estimator, graph, op)
|
||||
if op_cost < acc
|
||||
result = op
|
||||
return op_cost
|
||||
end
|
||||
return acc
|
||||
end,
|
||||
operations;
|
||||
init = typemax(cost_type(optimizer.estimator)),
|
||||
)
|
||||
|
||||
if lowestCost > zero(cost_type(optimizer.estimator))
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, result)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::GreedyOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if isempty(operations)
|
||||
return true
|
||||
end
|
||||
|
||||
lowestCost = reduce(
|
||||
(acc, op) -> begin
|
||||
op_cost = operation_effect(optimizer.estimator, graph, op)
|
||||
if op_cost < acc
|
||||
return op_cost
|
||||
end
|
||||
return acc
|
||||
end,
|
||||
operations;
|
||||
init = typemax(cost_type(optimizer.estimator)),
|
||||
)
|
||||
|
||||
if lowestCost > zero(cost_type(optimizer.estimator))
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::GreedyOptimizer, graph::DAG)
|
||||
while optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
60
src/optimization/interface.jl
Normal file
60
src/optimization/interface.jl
Normal file
@@ -0,0 +1,60 @@
|
||||
|
||||
"""
|
||||
AbstractOptimizer
|
||||
|
||||
Abstract base type for optimizer implementations.
|
||||
"""
|
||||
abstract type AbstractOptimizer end
|
||||
|
||||
"""
|
||||
optimize_step!(optimizer::AbstractOptimizer, graph::DAG)
|
||||
|
||||
Interface function that must be implemented by implementations of [`AbstractOptimizer`](@ref). Returns `true` if an operations has been applied, `false` if not, usually when a fixpoint of the algorithm has been reached.
|
||||
|
||||
It should do one smallest logical step on the given [`DAG`](@ref), muting the graph and, if necessary, the optimizer's state.
|
||||
"""
|
||||
function optimize_step! end
|
||||
|
||||
"""
|
||||
optimize!(optimizer::AbstractOptimizer, graph::DAG, n::Int)
|
||||
|
||||
Function calling the given optimizer `n` times, muting the graph. Returns `true` if the requested number of operations has been applied, `false` if not, usually when a fixpoint of the algorithm has been reached.
|
||||
|
||||
If a more efficient method exists, this can be overloaded for a specific optimizer.
|
||||
"""
|
||||
function optimize!(optimizer::AbstractOptimizer, graph::DAG, n::Int)
|
||||
for i in 1:n
|
||||
if !optimize_step!(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
fixpoint_reached(optimizer::AbstractOptimizer, graph::DAG)
|
||||
|
||||
Interface function that can be implemented by optimization algorithms that can reach a fixpoint, returning as a `Bool` whether it has been reached. The default implementation returns `false`.
|
||||
|
||||
See also: [`optimize_to_fixpoint!`](@ref)
|
||||
"""
|
||||
function fixpoint_reached(optimizer::AbstractOptimizer, graph::DAG)
|
||||
return false
|
||||
end
|
||||
|
||||
"""
|
||||
optimize_to_fixpoint!(optimizer::AbstractOptimizer, graph::DAG)
|
||||
|
||||
Interface function that can be implemented by optimization algorithms that can reach a fixpoint. The algorithm will be run until that fixpoint is reached, at which point [`fixpoint_reached`](@ref) should return true.
|
||||
|
||||
A usual implementation might look like this:
|
||||
```julia
|
||||
function optimize_to_fixpoint!(optimizer::MyOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
```
|
||||
"""
|
||||
function optimize_to_fixpoint! end
|
49
src/optimization/random_walk.jl
Normal file
49
src/optimization/random_walk.jl
Normal file
@@ -0,0 +1,49 @@
|
||||
using Random
|
||||
|
||||
"""
|
||||
RandomWalkOptimizer
|
||||
|
||||
An optimizer that randomly pushes or pops operations. It doesn't optimize in any direction and is useful mainly for testing purposes.
|
||||
|
||||
This algorithm never reaches a fixpoint, so it does not implement [`optimize_to_fixpoint`](@ref).
|
||||
"""
|
||||
struct RandomWalkOptimizer <: AbstractOptimizer
|
||||
rng::AbstractRNG
|
||||
end
|
||||
|
||||
function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
|
||||
if sum(length(operations)) == 0 && length(graph.appliedOperations) + length(graph.operationsToApply) == 0
|
||||
# in case there are zero operations possible at all on the graph
|
||||
return false
|
||||
end
|
||||
|
||||
r = optimizer.rng
|
||||
# try until something was applied or popped
|
||||
while true
|
||||
# choose push or pop
|
||||
if rand(r, Bool)
|
||||
# push
|
||||
|
||||
# choose one of fuse/split/reduce
|
||||
option = rand(r, 1:3)
|
||||
if option == 1 && !isempty(operations.nodeFusions)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeFusions)))
|
||||
return true
|
||||
elseif option == 2 && !isempty(operations.nodeReductions)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeReductions)))
|
||||
return true
|
||||
elseif option == 3 && !isempty(operations.nodeSplits)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeSplits)))
|
||||
return true
|
||||
end
|
||||
else
|
||||
# pop
|
||||
if (can_pop(graph))
|
||||
pop_operation!(graph)
|
||||
return true
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
30
src/optimization/reduce.jl
Normal file
30
src/optimization/reduce.jl
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
ReductionOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
|
||||
"""
|
||||
struct ReductionOptimizer <: AbstractOptimizer end
|
||||
|
||||
function optimize_step!(optimizer::ReductionOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if fixpoint_reached(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, first(operations.nodeReductions))
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::ReductionOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
return isempty(operations.nodeReductions)
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::ReductionOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
Reference in New Issue
Block a user