WIP refactor
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 7m59s
MetagraphOptimization_CI / docs (push) Failing after 8m6s

This commit is contained in:
Anton Reinhard 2024-06-24 23:31:30 +02:00
parent 2921882fd4
commit 1b4ba285c3
6 changed files with 52 additions and 17 deletions

View File

@ -90,8 +90,6 @@ function gen_input_assignment_code(
machine::Machine,
problemInputSymbol::Symbol = :input,
)
@assert length(inputSymbols) >= sum(values(in_particles(instance))) + sum(values(out_particles(instance))) "Number of input Symbols is smaller than the number of particles in the process description"
assignInputs = Vector{FunctionCall}()
for (name, symbols) in inputSymbols
(type, index) = type_index_from_name(model(instance), name)
@ -104,8 +102,8 @@ function gen_input_assignment_code(
FunctionCall(
# x is the process input
part_from_x,
SVector{1, Symbol}(problemInputSymbol),
SVector{2, Any}(type, index),
SVector{1, Symbol}(problemInputSymbol),
symbol,
device,
),
@ -117,14 +115,19 @@ function gen_input_assignment_code(
end
"""
gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine)
gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine, scheduler::AbstractScheduler = GreedyScheduler())
Generate the code for a given graph. The return value is a [`Tape`](@ref).
See also: [`execute`](@ref), [`execute_tape`](@ref)
"""
function gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine)
schedule = schedule_dag(GreedyScheduler(), graph, machine)
function gen_tape(
graph::DAG,
instance::AbstractProblemInstance,
machine::Machine,
scheduler::AbstractScheduler = GreedyScheduler(),
)
schedule = schedule_dag(scheduler, graph, machine)
# get inSymbols
inputSyms = Dict{String, Vector{Symbol}}()
@ -156,7 +159,8 @@ function execute_tape(tape::Tape, input)
cache = Dict{Symbol, Any}()
cache[:input] = input
# simply execute all the code snippets here
# TODO: `@assert` that process input fits the tape.process
@assert typeof(input) == input_type(tape.instance)
# TODO: `@assert` that input fits the tape.instance
for expr in tape.initCachesCode
@eval $expr
end

View File

@ -2,15 +2,45 @@
"""
AbstractModel
Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed.
See also: [`problem_instance`](@ref)
"""
abstract type AbstractModel end
"""
problem_instance(::AbstractModel, ::Vararg)
Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters.
"""
function problem_instance end
"""
AbstractProblemInstance
Base type for process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a process in that model.
Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model.
See also: [`parse_process`](@ref)
"""
abstract type AbstractProblemInstance end
"""
input_type(problem::AbstractProblemInstance)
Return the fully specified input type for a specific [`AbstractProblemInstance`](@ref).
"""
function input_type end
"""
graph(::AbstractProblemInstance)
Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names.
"""
function graph end
"""
input_expr(::AbstractProblemInstance, input_sym::Symbol, name::String)
For a given [`AbstractProblemInstance`](@ref), a `Symbol` on which the problem input is available (see [`input_type`](@ref)), and entry node name, return an `Expr` getting that specific input value from the
"""
function input_expr end

View File

@ -4,7 +4,7 @@
A greedy implementation of a scheduler, creating a topological ordering of nodes and naively balancing them onto the different devices.
"""
struct GreedyScheduler end
struct GreedyScheduler <: AbstractScheduler end
function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
nodeQueue = PriorityQueue{Node, Int}()

View File

@ -1,10 +1,10 @@
"""
Scheduler
AbstractScheduler
Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
"""
abstract type Scheduler end
abstract type AbstractScheduler end
"""
schedule_dag(::Scheduler, ::DAG, ::Machine)

View File

@ -5,10 +5,11 @@ using StaticArrays
Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on.
"""
struct FunctionCall{VectorType <: AbstractVector, M}
struct FunctionCall{VectorType <: AbstractVector, N}
func::Function
arguments::VectorType
additional_arguments::SVector{M, Any} # additional arguments (as values) for the function call, will be prepended to the other arguments
# TODO: this should be a tuple
value_arguments::SVector{N, Any} # value arguments for the function call, will be prepended to the other arguments
arguments::VectorType # symbols of the inputs to the function call
return_symbol::Symbol
device::AbstractDevice
end

View File

@ -32,7 +32,7 @@ function get_function_call(
inSymbols::AbstractVector,
outSymbol::Symbol,
) where {CompTask <: AbstractComputeTask}
return [FunctionCall(compute, inSymbols, SVector{1, Any}(t), outSymbol, device)]
return [FunctionCall(compute, SVector{1, Any}(t), inSymbols, outSymbol, device)]
end
function get_function_call(node::ComputeTaskNode)
@ -64,8 +64,8 @@ function get_function_call(node::DataTaskNode)
return [
FunctionCall(
unpack_identity,
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
SVector{0, Any}(),
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
Symbol(to_var_name(node.id)),
first(children(node)).device,
),
@ -77,8 +77,8 @@ function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
return FunctionCall(
unpack_identity,
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
SVector{0, Any}(),
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
Symbol(to_var_name(node.id)),
device,
)