Add iterator for PossibleOperations data structure
This commit is contained in:
parent
992450374c
commit
c73053f991
@ -68,6 +68,10 @@ export get_compute_function
|
||||
export cost_type, graph_cost, operation_effect
|
||||
export GlobalMetricEstimator, CDCost
|
||||
|
||||
# optimization
|
||||
export AbstractOptimizer, GreedyOptimizer
|
||||
export optimize_step!, optimize!
|
||||
|
||||
# machine info
|
||||
export Machine
|
||||
export get_machine_info
|
||||
@ -117,6 +121,7 @@ include("node/properties.jl")
|
||||
include("node/validate.jl")
|
||||
|
||||
include("operation/utility.jl")
|
||||
include("operation/iterate.jl")
|
||||
include("operation/apply.jl")
|
||||
include("operation/clean.jl")
|
||||
include("operation/find.jl")
|
||||
|
@ -29,6 +29,10 @@ function -(cost1::CDCost, cost2::CDCost)::CDCost
|
||||
return (data = d, computeEffort = ce, computeIntensity = ce / d)::CDCost
|
||||
end
|
||||
|
||||
function zero(type::Type{CDCost})
|
||||
return (data = 0.0, computeEffort = 00.0, computeIntensity = 0.0)::CDCost
|
||||
end
|
||||
|
||||
struct GlobalMetricEstimator <: AbstractEstimator end
|
||||
|
||||
function cost_type(estimator::GlobalMetricEstimator)
|
||||
|
36
src/operation/iterate.jl
Normal file
36
src/operation/iterate.jl
Normal file
@ -0,0 +1,36 @@
|
||||
import Base.iterate
|
||||
|
||||
const _POSSIBLE_OPERATIONS_FIELDS = fieldnames(PossibleOperations)
|
||||
|
||||
function iterate(possibleOperations::PossibleOperations)
|
||||
for fieldname in _POSSIBLE_OPERATIONS_FIELDS
|
||||
iterator = iterate(getfield(possibleOperations, fieldname))
|
||||
if (!isnothing(iterator))
|
||||
return (result = iterator[1], state = (fieldname, iterator[2]))
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function iterate(possibleOperations::PossibleOperations, state)
|
||||
newStateSym = state[1]
|
||||
newStateIt = iterate(getfield(possibleOperations, newStateSym), state[2])
|
||||
if !isnothing(newStateIt)
|
||||
return (result = newStateIt[1], state = (newStateSym, newStateIt[2]))
|
||||
end
|
||||
|
||||
# cycle to next field
|
||||
index = findfirst(x -> x == newStateSym, _POSSIBLE_OPERATIONS_FIELDS) + 1
|
||||
|
||||
while index <= length(_POSSIBLE_OPERATIONS_FIELDS)
|
||||
newStateSym = _POSSIBLE_OPERATIONS_FIELDS[index]
|
||||
newStateIt = iterate(getfield(possibleOperations, newStateSym))
|
||||
if !isnothing(newStateIt)
|
||||
return (result = newStateIt[1], state = (newStateSym, newStateIt[2]))
|
||||
end
|
||||
index += 1
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
@ -135,6 +135,12 @@ import MetagraphOptimization.partners
|
||||
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
|
||||
i = 0
|
||||
for op in operations
|
||||
i += 1
|
||||
end
|
||||
@test i == 10
|
||||
|
||||
@test operations == get_operations(graph)
|
||||
nf = first(operations.nodeFusions)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user