experiments (#1)
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me> Reviewed-on: #1
This commit is contained in:
@@ -78,7 +78,7 @@ export gen_graph
|
||||
export execute
|
||||
export parse_dag, parse_process
|
||||
export gen_process_input
|
||||
export get_compute_function
|
||||
export get_compute_function, get_cuda_kernel
|
||||
export gen_tape, execute_tape
|
||||
|
||||
# estimator
|
||||
@@ -86,7 +86,8 @@ export cost_type, graph_cost, operation_effect
|
||||
export GlobalMetricEstimator, CDCost
|
||||
|
||||
# optimization
|
||||
export AbstractOptimizer, GreedyOptimizer, ReductionOptimizer, RandomWalkOptimizer
|
||||
export AbstractOptimizer, GreedyOptimizer, RandomWalkOptimizer
|
||||
export ReductionOptimizer, SplitOptimizer, FusionOptimizer
|
||||
export optimize_step!, optimize!
|
||||
export fixpoint_reached, optimize_to_fixpoint!
|
||||
|
||||
@@ -166,6 +167,8 @@ include("optimization/interface.jl")
|
||||
include("optimization/greedy.jl")
|
||||
include("optimization/random_walk.jl")
|
||||
include("optimization/reduce.jl")
|
||||
include("optimization/fuse.jl")
|
||||
include("optimization/split.jl")
|
||||
|
||||
include("models/interface.jl")
|
||||
include("models/print.jl")
|
||||
|
@@ -21,6 +21,38 @@ function get_compute_function(graph::DAG, process::AbstractProcessDescription, m
|
||||
return func
|
||||
end
|
||||
|
||||
"""
|
||||
get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
|
||||
Return a function of signature `compute_<id>(input::CuVector, output::CuVector, n::Int64)`, which will return the result of the DAG computation of the input on the given output variable.
|
||||
"""
|
||||
function get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
tape = gen_tape(graph, process, machine)
|
||||
|
||||
initCaches = Expr(:block, tape.initCachesCode...)
|
||||
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
|
||||
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
|
||||
|
||||
functionId = to_var_name(UUIDs.uuid1(rng[1]))
|
||||
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
|
||||
expr = Meta.parse("function compute_$(functionId)(input_vector, output_vector, n::Int64)
|
||||
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
|
||||
if (id > n)
|
||||
return
|
||||
end
|
||||
@inline data_input = input_vector[id]
|
||||
$(initCaches)
|
||||
$(assignInputs)
|
||||
$code
|
||||
@inline output_vector[id] = $resSym
|
||||
return nothing
|
||||
end")
|
||||
|
||||
func = eval(expr)
|
||||
|
||||
return func
|
||||
end
|
||||
|
||||
"""
|
||||
execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
|
||||
|
||||
|
@@ -1,4 +1,3 @@
|
||||
|
||||
"""
|
||||
CDCost
|
||||
|
||||
@@ -34,7 +33,7 @@ function isless(cost1::CDCost, cost2::CDCost)::Bool
|
||||
end
|
||||
|
||||
function zero(type::Type{CDCost})
|
||||
return (data = 0.0, computeEffort = 00.0, computeIntensity = 0.0)::CDCost
|
||||
return (data = 0.0, computeEffort = 0.0, computeIntensity = 0.0)::CDCost
|
||||
end
|
||||
|
||||
function typemax(type::Type{CDCost})
|
||||
|
@@ -7,7 +7,8 @@ function get_properties(graph::DAG)
|
||||
# make sure the graph is fully generated
|
||||
apply_all!(graph)
|
||||
|
||||
if (graph.properties.computeEffort == 0.0)
|
||||
# TODO: tests stop working without the if condition, which means there is probably a bug in the lazy evaluation and in the tests
|
||||
if (graph.properties.computeEffort <= 0.0)
|
||||
graph.properties = GraphProperties(graph)
|
||||
end
|
||||
|
||||
|
@@ -84,9 +84,17 @@ Compute a sum over the vector. Use an algorithm that accounts for accumulated er
|
||||
Linearly many FLOP with growing data.
|
||||
"""
|
||||
function compute(::ComputeTaskABC_Sum, data...)::Float64
|
||||
return sum(data)
|
||||
s = 0.0im
|
||||
for d in data
|
||||
s += d
|
||||
end
|
||||
return s
|
||||
end
|
||||
|
||||
function compute(::ComputeTaskABC_Sum, data::AbstractArray)::Float64
|
||||
return sum(data)
|
||||
s = 0.0im
|
||||
for d in data
|
||||
s += d
|
||||
end
|
||||
return s
|
||||
end
|
||||
|
@@ -72,9 +72,9 @@ function compute(
|
||||
|
||||
# inner edge is just a "scalar", data1 and data2 are bispinor/adjointbispinnor, need to keep correct order
|
||||
if typeof(data1.v) <: BiSpinor
|
||||
return data2.v * inner * data1.v
|
||||
return (data2.v)::AdjointBiSpinor * inner * (data1.v)::BiSpinor
|
||||
else
|
||||
return data1.v * inner * data2.v
|
||||
return (data1.v)::AdjointBiSpinor * inner * (data2.v)::BiSpinor
|
||||
end
|
||||
end
|
||||
|
||||
@@ -115,10 +115,18 @@ Linearly many FLOP with growing data.
|
||||
"""
|
||||
function compute(::ComputeTaskQED_Sum, data...)::ComplexF64
|
||||
# TODO: want to use sum_kbn here but it doesn't seem to support ComplexF64, do it element-wise?
|
||||
return sum(data)
|
||||
s = 0.0im
|
||||
for d in data
|
||||
s += d
|
||||
end
|
||||
return s
|
||||
end
|
||||
|
||||
function compute(::ComputeTaskQED_Sum, data::AbstractArray)::ComplexF64
|
||||
# TODO: want to use sum_kbn here but it doesn't seem to support ComplexF64, do it element-wise?
|
||||
return sum(data)
|
||||
s = 0.0im
|
||||
for d in data
|
||||
s += d
|
||||
end
|
||||
return s
|
||||
end
|
||||
|
@@ -114,12 +114,8 @@ function gen_graph(process_description::QEDProcessDescription)
|
||||
dataOutNodes[String(particle)] = data_out
|
||||
end
|
||||
|
||||
#dataOutBackup = copy(dataOutNodes)
|
||||
|
||||
# TODO: this should be parallelizable somewhat easily
|
||||
for diagram in diagrams
|
||||
# the intermediate (virtual) particles change across
|
||||
#dataOutNodes = copy(dataOutBackup)
|
||||
|
||||
tie = diagram.tie[]
|
||||
|
||||
# handle the vertices
|
||||
|
@@ -1,3 +1,4 @@
|
||||
using Combinatorics
|
||||
|
||||
import Base.copy
|
||||
import Base.hash
|
||||
@@ -265,11 +266,12 @@ function add_vertex!(fd::FeynmanDiagram, vertex::FeynmanVertex)
|
||||
end
|
||||
|
||||
if !can_apply_vertex(get_particles(fd), vertex)
|
||||
#@assert false "Can't add vertex $vertex to diagram"
|
||||
@assert false "Can't add vertex $vertex to diagram $(get_particles(fd))"
|
||||
end
|
||||
|
||||
push!(fd.vertices, Set{FeynmanVertex}())
|
||||
push!(fd.vertices[end], vertex)
|
||||
|
||||
fd.type_ids[vertex.out.particle] += 1
|
||||
|
||||
return nothing
|
||||
@@ -437,12 +439,196 @@ function remove_duplicates(compare_set::Set{FeynmanDiagram})
|
||||
return result
|
||||
end
|
||||
|
||||
"""
|
||||
is_compton(fd::FeynmanDiagram)
|
||||
|
||||
Returns true iff the given feynman diagram is an (empty) diagram of a compton process like ke->k^ne
|
||||
"""
|
||||
function is_compton(fd::FeynmanDiagram)
|
||||
return fd.type_ids[FermionStateful{Incoming, SpinUp}] == 1 &&
|
||||
fd.type_ids[FermionStateful{Outgoing, SpinUp}] == 1 &&
|
||||
fd.type_ids[AntiFermionStateful{Incoming, SpinUp}] == 0 &&
|
||||
fd.type_ids[AntiFermionStateful{Outgoing, SpinUp}] == 0 &&
|
||||
fd.type_ids[PhotonStateful{Incoming, PolX}] >= 1 &&
|
||||
fd.type_ids[PhotonStateful{Outgoing, PolX}] >= 1
|
||||
end
|
||||
|
||||
"""
|
||||
gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
|
||||
|
||||
Helper function for [`gen_compton_diagrams`](@Ref). Generates a single diagram for the given order and n input and m output photons.
|
||||
"""
|
||||
function gen_compton_diagram_from_order(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
|
||||
photons = vcat(
|
||||
[FeynmanParticle(PhotonStateful{Incoming, PolX}, i) for i in 1:n],
|
||||
[FeynmanParticle(PhotonStateful{Outgoing, PolX}, i) for i in 1:m],
|
||||
)
|
||||
|
||||
new_diagram = FeynmanDiagram(
|
||||
[],
|
||||
missing,
|
||||
[inFerm, outFerm, photons...],
|
||||
Dict{Type, Int64}(
|
||||
FermionStateful{Incoming, SpinUp} => 1,
|
||||
FermionStateful{Outgoing, SpinUp} => 1,
|
||||
PhotonStateful{Incoming, PolX} => n,
|
||||
PhotonStateful{Outgoing, PolX} => m,
|
||||
),
|
||||
)
|
||||
|
||||
left_index = 1
|
||||
right_index = length(order)
|
||||
|
||||
iterations = 1
|
||||
|
||||
while left_index <= right_index
|
||||
# left side
|
||||
v_left = FeynmanVertex(
|
||||
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations),
|
||||
photons[order[left_index]],
|
||||
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations + 1),
|
||||
)
|
||||
left_index += 1
|
||||
add_vertex!(new_diagram, v_left)
|
||||
|
||||
if (left_index > right_index)
|
||||
break
|
||||
end
|
||||
|
||||
# right side
|
||||
v_right = FeynmanVertex(
|
||||
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations),
|
||||
photons[order[right_index]],
|
||||
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations + 1),
|
||||
)
|
||||
right_index -= 1
|
||||
add_vertex!(new_diagram, v_right)
|
||||
|
||||
iterations += 1
|
||||
end
|
||||
|
||||
@assert possible_tie(new_diagram) !== missing
|
||||
add_tie!(new_diagram, possible_tie(new_diagram))
|
||||
return new_diagram
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
|
||||
|
||||
Helper function for [`gen_compton_diagrams`](@Ref). Generates a single diagram for the given order and n input and m output photons.
|
||||
"""
|
||||
function gen_compton_diagram_from_order_one_side(order::Vector{Int}, inFerm, outFerm, n::Int, m::Int)
|
||||
photons = vcat(
|
||||
[FeynmanParticle(PhotonStateful{Incoming, PolX}, i) for i in 1:n],
|
||||
[FeynmanParticle(PhotonStateful{Outgoing, PolX}, i) for i in 1:m],
|
||||
)
|
||||
|
||||
new_diagram = FeynmanDiagram(
|
||||
[],
|
||||
missing,
|
||||
[inFerm, outFerm, photons...],
|
||||
Dict{Type, Int64}(
|
||||
FermionStateful{Incoming, SpinUp} => 1,
|
||||
FermionStateful{Outgoing, SpinUp} => 1,
|
||||
PhotonStateful{Incoming, PolX} => n,
|
||||
PhotonStateful{Outgoing, PolX} => m,
|
||||
),
|
||||
)
|
||||
|
||||
left_index = 1
|
||||
right_index = length(order)
|
||||
|
||||
iterations = 1
|
||||
|
||||
while left_index <= right_index
|
||||
# left side
|
||||
v_left = FeynmanVertex(
|
||||
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations),
|
||||
photons[order[left_index]],
|
||||
FeynmanParticle(FermionStateful{Incoming, SpinUp}, iterations + 1),
|
||||
)
|
||||
left_index += 1
|
||||
add_vertex!(new_diagram, v_left)
|
||||
|
||||
if (left_index > right_index)
|
||||
break
|
||||
end
|
||||
|
||||
# only once on the right side
|
||||
if (iterations == 1)
|
||||
# right side
|
||||
v_right = FeynmanVertex(
|
||||
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations),
|
||||
photons[order[right_index]],
|
||||
FeynmanParticle(FermionStateful{Outgoing, SpinUp}, iterations + 1),
|
||||
)
|
||||
right_index -= 1
|
||||
add_vertex!(new_diagram, v_right)
|
||||
end
|
||||
|
||||
iterations += 1
|
||||
end
|
||||
|
||||
ps = get_particles(new_diagram)
|
||||
@assert length(ps) == 2
|
||||
add_tie!(new_diagram, FeynmanTie(ps[1], ps[2]))
|
||||
return new_diagram
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
gen_compton_diagrams(n::Int, m::Int)
|
||||
|
||||
Special case diagram generation for Compton processes, i.e., processes of the form k^ne->k^me
|
||||
"""
|
||||
function gen_compton_diagrams(n::Int, m::Int)
|
||||
inFerm = FeynmanParticle(FermionStateful{Incoming, SpinUp}, 1)
|
||||
outFerm = FeynmanParticle(FermionStateful{Outgoing, SpinUp}, 1)
|
||||
|
||||
perms = [permutations([i for i in 1:(n + m)])...]
|
||||
|
||||
diagrams = [Vector{FeynmanDiagram}() for i in 1:nthreads()]
|
||||
@threads for order in perms
|
||||
push!(diagrams[threadid()], gen_compton_diagram_from_order(order, inFerm, outFerm, n, m))
|
||||
end
|
||||
|
||||
return vcat(diagrams...)
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
gen_compton_diagrams_one_side(n::Int, m::Int)
|
||||
|
||||
Special case diagram generation for Compton processes, i.e., processes of the form k^ne->k^me, but generating from one end, yielding larger diagrams
|
||||
"""
|
||||
function gen_compton_diagrams_one_side(n::Int, m::Int)
|
||||
inFerm = FeynmanParticle(FermionStateful{Incoming, SpinUp}, 1)
|
||||
outFerm = FeynmanParticle(FermionStateful{Outgoing, SpinUp}, 1)
|
||||
|
||||
perms = [permutations([i for i in 1:(n + m)])...]
|
||||
|
||||
diagrams = [Vector{FeynmanDiagram}() for i in 1:nthreads()]
|
||||
@threads for order in perms
|
||||
push!(diagrams[threadid()], gen_compton_diagram_from_order_one_side(order, inFerm, outFerm, n, m))
|
||||
end
|
||||
|
||||
return vcat(diagrams...)
|
||||
end
|
||||
|
||||
"""
|
||||
gen_diagrams(fd::FeynmanDiagram)
|
||||
|
||||
From a given feynman diagram in its initial state, e.g. when created through the [`FeynmanDiagram(pd::ProcessDescription)`](@ref) constructor, generate and return all possible [`FeynmanDiagram`](@ref)s that describe that process.
|
||||
"""
|
||||
function gen_diagrams(fd::FeynmanDiagram)
|
||||
if is_compton(fd)
|
||||
return gen_compton_diagrams_one_side(
|
||||
fd.type_ids[PhotonStateful{Incoming, PolX}],
|
||||
fd.type_ids[PhotonStateful{Outgoing, PolX}],
|
||||
)
|
||||
end
|
||||
|
||||
working = Set{FeynmanDiagram}()
|
||||
results = Set{FeynmanDiagram}()
|
||||
|
||||
|
@@ -313,7 +313,7 @@ Return the factor of a vertex in a QED feynman diagram.
|
||||
return -1im * e * gamma()
|
||||
end
|
||||
|
||||
@inline function QED_inner_edge(p::QEDParticle)
|
||||
@inline function QED_inner_edge(p::QEDParticle)::DiracMatrix
|
||||
return propagator(particle(p), p.momentum)
|
||||
end
|
||||
|
||||
|
@@ -42,10 +42,10 @@ Create a short string suitable as a filename or similar, describing the given pr
|
||||
julia> using MetagraphOptimization
|
||||
|
||||
julia> String(parse_process("ke->ke", QEDModel()))
|
||||
qed_ke-ke
|
||||
"qed_ke-ke"
|
||||
|
||||
julia> print(parse_process("kk->ep", QEDModel()))
|
||||
qed_kk-ep
|
||||
QED Process: 'kk->ep'
|
||||
```
|
||||
"""
|
||||
function String(process::QEDProcessDescription)
|
||||
|
@@ -1,32 +1,32 @@
|
||||
# TODO use correct numbers
|
||||
# compute effort numbers were measured on a home pc system using likwid
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskQED_S1)
|
||||
|
||||
Return the compute effort of an S1 task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskQED_S1)::Float64 = 11.0
|
||||
compute_effort(t::ComputeTaskQED_S1)::Float64 = 475.0
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskQED_S2)
|
||||
|
||||
Return the compute effort of an S2 task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskQED_S2)::Float64 = 12.0
|
||||
compute_effort(t::ComputeTaskQED_S2)::Float64 = 505.0
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskQED_U)
|
||||
|
||||
Return the compute effort of a U task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskQED_U)::Float64 = 1.0
|
||||
compute_effort(t::ComputeTaskQED_U)::Float64 = (291.0 + 467.0 + 16.0 + 17.0) / 4.0 # The exact FLOPS count depends heavily on the type of particle, take an average value here
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskQED_V)
|
||||
|
||||
Return the compute effort of a V task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskQED_V)::Float64 = 6.0
|
||||
compute_effort(t::ComputeTaskQED_V)::Float64 = (1150.0 + 764.0 + 828.0) / 3.0
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskQED_P)
|
||||
|
@@ -3,7 +3,7 @@ using UUIDs
|
||||
using Base.Threads
|
||||
|
||||
# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/)
|
||||
rng = [Random.MersenneTwister(0) for _ in 1:64]
|
||||
rng = [Random.MersenneTwister(0) for _ in 1:128]
|
||||
|
||||
"""
|
||||
Node
|
||||
|
@@ -197,8 +197,7 @@ function generate_operations(graph::DAG)
|
||||
|
||||
# launch thread for node reduction insertion
|
||||
# remove duplicates
|
||||
nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
|
||||
schedule(nr_task)
|
||||
nr_task = @spawn nr_insertion!(graph.possibleOperations, generatedReductions)
|
||||
|
||||
# --- find possible node fusions ---
|
||||
@threads for node in nodeArray
|
||||
@@ -223,8 +222,7 @@ function generate_operations(graph::DAG)
|
||||
end
|
||||
|
||||
# launch thread for node fusion insertion
|
||||
nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
|
||||
schedule(nf_task)
|
||||
nf_task = @spawn nf_insertion!(graph, graph.possibleOperations, generatedFusions)
|
||||
|
||||
# find possible node splits
|
||||
@threads for node in nodeArray
|
||||
@@ -234,8 +232,7 @@ function generate_operations(graph::DAG)
|
||||
end
|
||||
|
||||
# launch thread for node split insertion
|
||||
ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
|
||||
schedule(ns_task)
|
||||
ns_task = @spawn ns_insertion!(graph.possibleOperations, generatedSplits)
|
||||
|
||||
empty!(graph.dirtyNodes)
|
||||
|
||||
|
36
src/optimization/fuse.jl
Normal file
36
src/optimization/fuse.jl
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
FusionOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeFusion`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeFusion`](@ref)s in the graph.
|
||||
|
||||
See also: [`SplitOptimizer`](@ref), [`ReductionOptimizer`](@ref)
|
||||
"""
|
||||
struct FusionOptimizer <: AbstractOptimizer end
|
||||
|
||||
function optimize_step!(optimizer::FusionOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if fixpoint_reached(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, first(operations.nodeFusions))
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::FusionOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
return isempty(operations.nodeFusions)
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::FusionOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function String(::FusionOptimizer)
|
||||
return "fusion_optimizer"
|
||||
end
|
@@ -21,7 +21,7 @@ function optimize_step!(optimizer::GreedyOptimizer, graph::DAG)
|
||||
lowestCost = reduce(
|
||||
(acc, op) -> begin
|
||||
op_cost = operation_effect(optimizer.estimator, graph, op)
|
||||
if op_cost < acc
|
||||
if isless(op_cost, acc)
|
||||
result = op
|
||||
return op_cost
|
||||
end
|
||||
@@ -50,7 +50,7 @@ function fixpoint_reached(optimizer::GreedyOptimizer, graph::DAG)
|
||||
lowestCost = reduce(
|
||||
(acc, op) -> begin
|
||||
op_cost = operation_effect(optimizer.estimator, graph, op)
|
||||
if op_cost < acc
|
||||
if isless(op_cost, acc)
|
||||
return op_cost
|
||||
end
|
||||
return acc
|
||||
|
@@ -2,6 +2,8 @@
|
||||
ReductionOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
|
||||
|
||||
See also: [`FusionOptimizer`](@ref), [`SplitOptimizer`](@ref)
|
||||
"""
|
||||
struct ReductionOptimizer <: AbstractOptimizer end
|
||||
|
||||
|
36
src/optimization/split.jl
Normal file
36
src/optimization/split.jl
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
SplitOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeSplit`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeSplit`](@ref)s in the graph.
|
||||
|
||||
See also: [`FusionOptimizer`](@ref), [`ReductionOptimizer`](@ref)
|
||||
"""
|
||||
struct SplitOptimizer <: AbstractOptimizer end
|
||||
|
||||
function optimize_step!(optimizer::SplitOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if fixpoint_reached(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, first(operations.nodeSplits))
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::SplitOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
return isempty(operations.nodeSplits)
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::SplitOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function String(::SplitOptimizer)
|
||||
return "split_optimizer"
|
||||
end
|
@@ -18,7 +18,7 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
|
||||
sizehint!(schedule, length(graph.nodes))
|
||||
|
||||
# keep an accumulated cost of things scheduled to this device so far
|
||||
deviceAccCost = PriorityQueue{AbstractDevice, Int}()
|
||||
deviceAccCost = PriorityQueue{AbstractDevice, Float64}()
|
||||
for device in machine.devices
|
||||
enqueue!(deviceAccCost, device => 0)
|
||||
end
|
||||
|
@@ -39,7 +39,7 @@ function get_function_call(node::ComputeTaskNode)
|
||||
@assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
|
||||
@assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
|
||||
|
||||
if (length(node.children) <= 50)
|
||||
if (length(node.children) <= 800)
|
||||
#only use an SVector when there are few children
|
||||
return get_function_call(
|
||||
node.task,
|
||||
|
Reference in New Issue
Block a user