WIP refactor
This commit is contained in:
parent
2921882fd4
commit
1b4ba285c3
@ -90,8 +90,6 @@ function gen_input_assignment_code(
|
||||
machine::Machine,
|
||||
problemInputSymbol::Symbol = :input,
|
||||
)
|
||||
@assert length(inputSymbols) >= sum(values(in_particles(instance))) + sum(values(out_particles(instance))) "Number of input Symbols is smaller than the number of particles in the process description"
|
||||
|
||||
assignInputs = Vector{FunctionCall}()
|
||||
for (name, symbols) in inputSymbols
|
||||
(type, index) = type_index_from_name(model(instance), name)
|
||||
@ -104,8 +102,8 @@ function gen_input_assignment_code(
|
||||
FunctionCall(
|
||||
# x is the process input
|
||||
part_from_x,
|
||||
SVector{1, Symbol}(problemInputSymbol),
|
||||
SVector{2, Any}(type, index),
|
||||
SVector{1, Symbol}(problemInputSymbol),
|
||||
symbol,
|
||||
device,
|
||||
),
|
||||
@ -117,14 +115,19 @@ function gen_input_assignment_code(
|
||||
end
|
||||
|
||||
"""
|
||||
gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine)
|
||||
gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine, scheduler::AbstractScheduler = GreedyScheduler())
|
||||
|
||||
Generate the code for a given graph. The return value is a [`Tape`](@ref).
|
||||
|
||||
See also: [`execute`](@ref), [`execute_tape`](@ref)
|
||||
"""
|
||||
function gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine)
|
||||
schedule = schedule_dag(GreedyScheduler(), graph, machine)
|
||||
function gen_tape(
|
||||
graph::DAG,
|
||||
instance::AbstractProblemInstance,
|
||||
machine::Machine,
|
||||
scheduler::AbstractScheduler = GreedyScheduler(),
|
||||
)
|
||||
schedule = schedule_dag(scheduler, graph, machine)
|
||||
|
||||
# get inSymbols
|
||||
inputSyms = Dict{String, Vector{Symbol}}()
|
||||
@ -156,7 +159,8 @@ function execute_tape(tape::Tape, input)
|
||||
cache = Dict{Symbol, Any}()
|
||||
cache[:input] = input
|
||||
# simply execute all the code snippets here
|
||||
# TODO: `@assert` that process input fits the tape.process
|
||||
@assert typeof(input) == input_type(tape.instance)
|
||||
# TODO: `@assert` that input fits the tape.instance
|
||||
for expr in tape.initCachesCode
|
||||
@eval $expr
|
||||
end
|
||||
|
@ -2,15 +2,45 @@
|
||||
"""
|
||||
AbstractModel
|
||||
|
||||
Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed.
|
||||
|
||||
See also: [`problem_instance`](@ref)
|
||||
"""
|
||||
abstract type AbstractModel end
|
||||
|
||||
"""
|
||||
problem_instance(::AbstractModel, ::Vararg)
|
||||
|
||||
Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters.
|
||||
"""
|
||||
function problem_instance end
|
||||
|
||||
"""
|
||||
AbstractProblemInstance
|
||||
|
||||
Base type for process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a process in that model.
|
||||
Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model.
|
||||
|
||||
See also: [`parse_process`](@ref)
|
||||
"""
|
||||
abstract type AbstractProblemInstance end
|
||||
|
||||
"""
|
||||
input_type(problem::AbstractProblemInstance)
|
||||
|
||||
Return the fully specified input type for a specific [`AbstractProblemInstance`](@ref).
|
||||
"""
|
||||
function input_type end
|
||||
|
||||
"""
|
||||
graph(::AbstractProblemInstance)
|
||||
|
||||
Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names.
|
||||
"""
|
||||
function graph end
|
||||
|
||||
"""
|
||||
input_expr(::AbstractProblemInstance, input_sym::Symbol, name::String)
|
||||
|
||||
For a given [`AbstractProblemInstance`](@ref), a `Symbol` on which the problem input is available (see [`input_type`](@ref)), and entry node name, return an `Expr` getting that specific input value from the
|
||||
"""
|
||||
function input_expr end
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
A greedy implementation of a scheduler, creating a topological ordering of nodes and naively balancing them onto the different devices.
|
||||
"""
|
||||
struct GreedyScheduler end
|
||||
struct GreedyScheduler <: AbstractScheduler end
|
||||
|
||||
function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
|
||||
nodeQueue = PriorityQueue{Node, Int}()
|
||||
|
@ -1,10 +1,10 @@
|
||||
|
||||
"""
|
||||
Scheduler
|
||||
AbstractScheduler
|
||||
|
||||
Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
|
||||
"""
|
||||
abstract type Scheduler end
|
||||
abstract type AbstractScheduler end
|
||||
|
||||
"""
|
||||
schedule_dag(::Scheduler, ::DAG, ::Machine)
|
||||
|
@ -5,10 +5,11 @@ using StaticArrays
|
||||
|
||||
Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on.
|
||||
"""
|
||||
struct FunctionCall{VectorType <: AbstractVector, M}
|
||||
struct FunctionCall{VectorType <: AbstractVector, N}
|
||||
func::Function
|
||||
arguments::VectorType
|
||||
additional_arguments::SVector{M, Any} # additional arguments (as values) for the function call, will be prepended to the other arguments
|
||||
# TODO: this should be a tuple
|
||||
value_arguments::SVector{N, Any} # value arguments for the function call, will be prepended to the other arguments
|
||||
arguments::VectorType # symbols of the inputs to the function call
|
||||
return_symbol::Symbol
|
||||
device::AbstractDevice
|
||||
end
|
||||
|
@ -32,7 +32,7 @@ function get_function_call(
|
||||
inSymbols::AbstractVector,
|
||||
outSymbol::Symbol,
|
||||
) where {CompTask <: AbstractComputeTask}
|
||||
return [FunctionCall(compute, inSymbols, SVector{1, Any}(t), outSymbol, device)]
|
||||
return [FunctionCall(compute, SVector{1, Any}(t), inSymbols, outSymbol, device)]
|
||||
end
|
||||
|
||||
function get_function_call(node::ComputeTaskNode)
|
||||
@ -64,8 +64,8 @@ function get_function_call(node::DataTaskNode)
|
||||
return [
|
||||
FunctionCall(
|
||||
unpack_identity,
|
||||
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
|
||||
SVector{0, Any}(),
|
||||
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
|
||||
Symbol(to_var_name(node.id)),
|
||||
first(children(node)).device,
|
||||
),
|
||||
@ -77,8 +77,8 @@ function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
|
||||
|
||||
return FunctionCall(
|
||||
unpack_identity,
|
||||
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
|
||||
SVector{0, Any}(),
|
||||
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
|
||||
Symbol(to_var_name(node.id)),
|
||||
device,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user