Add iterator for PossibleOperations data structure

This commit is contained in:
Anton Reinhard 2023-11-20 16:56:42 +01:00
parent 992450374c
commit c73053f991
4 changed files with 51 additions and 0 deletions

View File

@ -68,6 +68,10 @@ export get_compute_function
export cost_type, graph_cost, operation_effect
export GlobalMetricEstimator, CDCost
# optimization
export AbstractOptimizer, GreedyOptimizer
export optimize_step!, optimize!
# machine info
export Machine
export get_machine_info
@ -117,6 +121,7 @@ include("node/properties.jl")
include("node/validate.jl")
include("operation/utility.jl")
include("operation/iterate.jl")
include("operation/apply.jl")
include("operation/clean.jl")
include("operation/find.jl")

View File

@ -29,6 +29,10 @@ function -(cost1::CDCost, cost2::CDCost)::CDCost
return (data = d, computeEffort = ce, computeIntensity = ce / d)::CDCost
end
function zero(type::Type{CDCost})
return (data = 0.0, computeEffort = 00.0, computeIntensity = 0.0)::CDCost
end
struct GlobalMetricEstimator <: AbstractEstimator end
function cost_type(estimator::GlobalMetricEstimator)

36
src/operation/iterate.jl Normal file
View File

@ -0,0 +1,36 @@
import Base.iterate
const _POSSIBLE_OPERATIONS_FIELDS = fieldnames(PossibleOperations)
function iterate(possibleOperations::PossibleOperations)
for fieldname in _POSSIBLE_OPERATIONS_FIELDS
iterator = iterate(getfield(possibleOperations, fieldname))
if (!isnothing(iterator))
return (result = iterator[1], state = (fieldname, iterator[2]))
end
end
return nothing
end
function iterate(possibleOperations::PossibleOperations, state)
newStateSym = state[1]
newStateIt = iterate(getfield(possibleOperations, newStateSym), state[2])
if !isnothing(newStateIt)
return (result = newStateIt[1], state = (newStateSym, newStateIt[2]))
end
# cycle to next field
index = findfirst(x -> x == newStateSym, _POSSIBLE_OPERATIONS_FIELDS) + 1
while index <= length(_POSSIBLE_OPERATIONS_FIELDS)
newStateSym = _POSSIBLE_OPERATIONS_FIELDS[index]
newStateIt = iterate(getfield(possibleOperations, newStateSym))
if !isnothing(newStateIt)
return (result = newStateIt[1], state = (newStateSym, newStateIt[2]))
end
index += 1
end
return nothing
end

View File

@ -135,6 +135,12 @@ import MetagraphOptimization.partners
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
@test length(graph.dirtyNodes) == 0
i = 0
for op in operations
i += 1
end
@test i == 10
@test operations == get_operations(graph)
nf = first(operations.nodeFusions)