Add split and fuse optimizers similar to reduce optimizer

This commit is contained in:
Anton Reinhard 2024-03-05 13:34:39 +01:00
parent 71219f101e
commit b7f8e4a6b3
7 changed files with 83 additions and 5 deletions

View File

@ -89,7 +89,7 @@ function gpu_worker(compute_func, inputs, chunk_size)
work_start = progress
progress = progress + chunk_size
work_end = min(progress, nInputs)
gpu_chunks = cpu_chunks + 1
gpu_chunks = gpu_chunks + 1
end
end
if quit

View File

@ -86,7 +86,8 @@ export cost_type, graph_cost, operation_effect
export GlobalMetricEstimator, CDCost
# optimization
export AbstractOptimizer, GreedyOptimizer, ReductionOptimizer, RandomWalkOptimizer
export AbstractOptimizer, GreedyOptimizer, RandomWalkOptimizer
export ReductionOptimizer, SplitOptimizer, FusionOptimizer
export optimize_step!, optimize!
export fixpoint_reached, optimize_to_fixpoint!
@ -166,6 +167,8 @@ include("optimization/interface.jl")
include("optimization/greedy.jl")
include("optimization/random_walk.jl")
include("optimization/reduce.jl")
include("optimization/fuse.jl")
include("optimization/split.jl")
include("models/interface.jl")
include("models/print.jl")

View File

@ -3,7 +3,7 @@ using UUIDs
using Base.Threads
# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/)
rng = [Random.MersenneTwister(0) for _ in 1:64]
rng = [Random.MersenneTwister(0) for _ in 1:128]
"""
Node

36
src/optimization/fuse.jl Normal file
View File

@ -0,0 +1,36 @@
"""
FusionOptimizer
An optimizer that simply applies an available [`NodeFusion`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeFusion`](@ref)s in the graph.
See also: [`SplitOptimizer`](@ref), [`ReductionOptimizer`](@ref)
"""
struct FusionOptimizer <: AbstractOptimizer end
function optimize_step!(optimizer::FusionOptimizer, graph::DAG)
# generate all options
operations = get_operations(graph)
if fixpoint_reached(optimizer, graph)
return false
end
push_operation!(graph, first(operations.nodeFusions))
return true
end
function fixpoint_reached(optimizer::FusionOptimizer, graph::DAG)
operations = get_operations(graph)
return isempty(operations.nodeFusions)
end
function optimize_to_fixpoint!(optimizer::FusionOptimizer, graph::DAG)
while !fixpoint_reached(optimizer, graph)
optimize_step!(optimizer, graph)
end
return nothing
end
function String(::FusionOptimizer)
return "fusion_optimizer"
end

View File

@ -2,6 +2,8 @@
ReductionOptimizer
An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
See also: [`FusionOptimizer`](@ref), [`SplitOptimizer`](@ref)
"""
struct ReductionOptimizer <: AbstractOptimizer end

36
src/optimization/split.jl Normal file
View File

@ -0,0 +1,36 @@
"""
SplitOptimizer
An optimizer that simply applies an available [`NodeSplit`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeSplit`](@ref)s in the graph.
See also: [`FusionOptimizer`](@ref), [`ReductionOptimizer`](@ref)
"""
struct SplitOptimizer <: AbstractOptimizer end
function optimize_step!(optimizer::SplitOptimizer, graph::DAG)
# generate all options
operations = get_operations(graph)
if fixpoint_reached(optimizer, graph)
return false
end
push_operation!(graph, first(operations.nodeSplits))
return true
end
function fixpoint_reached(optimizer::SplitOptimizer, graph::DAG)
operations = get_operations(graph)
return isempty(operations.nodeSplits)
end
function optimize_to_fixpoint!(optimizer::SplitOptimizer, graph::DAG)
while !fixpoint_reached(optimizer, graph)
optimize_step!(optimizer, graph)
end
return nothing
end
function String(::SplitOptimizer)
return "split_optimizer"
end

View File

@ -6,7 +6,8 @@ RNG = Random.default_rng()
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
# create the optimizers
FIXPOINT_OPTIMIZERS = [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer()]
FIXPOINT_OPTIMIZERS =
[GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer(), FusionOptimizer()]
NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
@testset "Optimizer $optimizer" for optimizer in vcat(NO_FIXPOINT_OPTIMIZERS, FIXPOINT_OPTIMIZERS)
@ -16,7 +17,7 @@ NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
@test !fixpoint_reached(optimizer, graph)
@test operation_stack_length(graph) == 1
@test optimize!(optimizer, graph, 10)
@test optimize!(optimizer, graph, 2)
@test !fixpoint_reached(optimizer, graph)