Add split and fuse optimizers similar to reduce optimizer
This commit is contained in:
parent
71219f101e
commit
b7f8e4a6b3
@ -89,7 +89,7 @@ function gpu_worker(compute_func, inputs, chunk_size)
|
||||
work_start = progress
|
||||
progress = progress + chunk_size
|
||||
work_end = min(progress, nInputs)
|
||||
gpu_chunks = cpu_chunks + 1
|
||||
gpu_chunks = gpu_chunks + 1
|
||||
end
|
||||
end
|
||||
if quit
|
||||
|
@ -86,7 +86,8 @@ export cost_type, graph_cost, operation_effect
|
||||
export GlobalMetricEstimator, CDCost
|
||||
|
||||
# optimization
|
||||
export AbstractOptimizer, GreedyOptimizer, ReductionOptimizer, RandomWalkOptimizer
|
||||
export AbstractOptimizer, GreedyOptimizer, RandomWalkOptimizer
|
||||
export ReductionOptimizer, SplitOptimizer, FusionOptimizer
|
||||
export optimize_step!, optimize!
|
||||
export fixpoint_reached, optimize_to_fixpoint!
|
||||
|
||||
@ -166,6 +167,8 @@ include("optimization/interface.jl")
|
||||
include("optimization/greedy.jl")
|
||||
include("optimization/random_walk.jl")
|
||||
include("optimization/reduce.jl")
|
||||
include("optimization/fuse.jl")
|
||||
include("optimization/split.jl")
|
||||
|
||||
include("models/interface.jl")
|
||||
include("models/print.jl")
|
||||
|
@ -3,7 +3,7 @@ using UUIDs
|
||||
using Base.Threads
|
||||
|
||||
# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/)
|
||||
rng = [Random.MersenneTwister(0) for _ in 1:64]
|
||||
rng = [Random.MersenneTwister(0) for _ in 1:128]
|
||||
|
||||
"""
|
||||
Node
|
||||
|
36
src/optimization/fuse.jl
Normal file
36
src/optimization/fuse.jl
Normal file
@ -0,0 +1,36 @@
|
||||
"""
|
||||
FusionOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeFusion`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeFusion`](@ref)s in the graph.
|
||||
|
||||
See also: [`SplitOptimizer`](@ref), [`ReductionOptimizer`](@ref)
|
||||
"""
|
||||
struct FusionOptimizer <: AbstractOptimizer end
|
||||
|
||||
function optimize_step!(optimizer::FusionOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if fixpoint_reached(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, first(operations.nodeFusions))
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::FusionOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
return isempty(operations.nodeFusions)
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::FusionOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function String(::FusionOptimizer)
|
||||
return "fusion_optimizer"
|
||||
end
|
@ -2,6 +2,8 @@
|
||||
ReductionOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
|
||||
|
||||
See also: [`FusionOptimizer`](@ref), [`SplitOptimizer`](@ref)
|
||||
"""
|
||||
struct ReductionOptimizer <: AbstractOptimizer end
|
||||
|
||||
|
36
src/optimization/split.jl
Normal file
36
src/optimization/split.jl
Normal file
@ -0,0 +1,36 @@
|
||||
"""
|
||||
SplitOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeSplit`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeSplit`](@ref)s in the graph.
|
||||
|
||||
See also: [`FusionOptimizer`](@ref), [`ReductionOptimizer`](@ref)
|
||||
"""
|
||||
struct SplitOptimizer <: AbstractOptimizer end
|
||||
|
||||
function optimize_step!(optimizer::SplitOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if fixpoint_reached(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, first(operations.nodeSplits))
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::SplitOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
return isempty(operations.nodeSplits)
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::SplitOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function String(::SplitOptimizer)
|
||||
return "split_optimizer"
|
||||
end
|
@ -6,7 +6,8 @@ RNG = Random.default_rng()
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
|
||||
|
||||
# create the optimizers
|
||||
FIXPOINT_OPTIMIZERS = [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer()]
|
||||
FIXPOINT_OPTIMIZERS =
|
||||
[GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer(), FusionOptimizer()]
|
||||
NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
|
||||
|
||||
@testset "Optimizer $optimizer" for optimizer in vcat(NO_FIXPOINT_OPTIMIZERS, FIXPOINT_OPTIMIZERS)
|
||||
@ -16,7 +17,7 @@ NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
|
||||
@test !fixpoint_reached(optimizer, graph)
|
||||
@test operation_stack_length(graph) == 1
|
||||
|
||||
@test optimize!(optimizer, graph, 10)
|
||||
@test optimize!(optimizer, graph, 2)
|
||||
|
||||
@test !fixpoint_reached(optimizer, graph)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user