From 1b4ba285c3220ad8a1c53f669c054fe4ef5450b2 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Mon, 24 Jun 2024 23:31:30 +0200 Subject: [PATCH] WIP refactor --- src/code_gen/tape_machine.jl | 18 +++++++++++------- src/models/interface.jl | 32 +++++++++++++++++++++++++++++++- src/scheduler/greedy.jl | 2 +- src/scheduler/interface.jl | 4 ++-- src/scheduler/type.jl | 7 ++++--- src/task/compute.jl | 6 +++--- 6 files changed, 52 insertions(+), 17 deletions(-) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index ea42b9b..e986ad4 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -90,8 +90,6 @@ function gen_input_assignment_code( machine::Machine, problemInputSymbol::Symbol = :input, ) - @assert length(inputSymbols) >= sum(values(in_particles(instance))) + sum(values(out_particles(instance))) "Number of input Symbols is smaller than the number of particles in the process description" - assignInputs = Vector{FunctionCall}() for (name, symbols) in inputSymbols (type, index) = type_index_from_name(model(instance), name) @@ -104,8 +102,8 @@ function gen_input_assignment_code( FunctionCall( # x is the process input part_from_x, - SVector{1, Symbol}(problemInputSymbol), SVector{2, Any}(type, index), + SVector{1, Symbol}(problemInputSymbol), symbol, device, ), @@ -117,14 +115,19 @@ function gen_input_assignment_code( end """ - gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine) + gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine, scheduler::AbstractScheduler = GreedyScheduler()) Generate the code for a given graph. The return value is a [`Tape`](@ref). See also: [`execute`](@ref), [`execute_tape`](@ref) """ -function gen_tape(graph::DAG, instance::AbstractProblemInstance, machine::Machine) - schedule = schedule_dag(GreedyScheduler(), graph, machine) +function gen_tape( + graph::DAG, + instance::AbstractProblemInstance, + machine::Machine, + scheduler::AbstractScheduler = GreedyScheduler(), +) + schedule = schedule_dag(scheduler, graph, machine) # get inSymbols inputSyms = Dict{String, Vector{Symbol}}() @@ -156,7 +159,8 @@ function execute_tape(tape::Tape, input) cache = Dict{Symbol, Any}() cache[:input] = input # simply execute all the code snippets here - # TODO: `@assert` that process input fits the tape.process + @assert typeof(input) == input_type(tape.instance) + # TODO: `@assert` that input fits the tape.instance for expr in tape.initCachesCode @eval $expr end diff --git a/src/models/interface.jl b/src/models/interface.jl index 66f6514..3ed1b34 100644 --- a/src/models/interface.jl +++ b/src/models/interface.jl @@ -2,15 +2,45 @@ """ AbstractModel +Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed. +See also: [`problem_instance`](@ref) """ abstract type AbstractModel end +""" + problem_instance(::AbstractModel, ::Vararg) + +Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters. +""" +function problem_instance end + """ AbstractProblemInstance -Base type for process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a process in that model. +Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model. See also: [`parse_process`](@ref) """ abstract type AbstractProblemInstance end + +""" + input_type(problem::AbstractProblemInstance) + +Return the fully specified input type for a specific [`AbstractProblemInstance`](@ref). +""" +function input_type end + +""" + graph(::AbstractProblemInstance) + +Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names. +""" +function graph end + +""" + input_expr(::AbstractProblemInstance, input_sym::Symbol, name::String) + +For a given [`AbstractProblemInstance`](@ref), a `Symbol` on which the problem input is available (see [`input_type`](@ref)), and entry node name, return an `Expr` getting that specific input value from the +""" +function input_expr end diff --git a/src/scheduler/greedy.jl b/src/scheduler/greedy.jl index ae7ae92..ee90af3 100644 --- a/src/scheduler/greedy.jl +++ b/src/scheduler/greedy.jl @@ -4,7 +4,7 @@ A greedy implementation of a scheduler, creating a topological ordering of nodes and naively balancing them onto the different devices. """ -struct GreedyScheduler end +struct GreedyScheduler <: AbstractScheduler end function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine) nodeQueue = PriorityQueue{Node, Int}() diff --git a/src/scheduler/interface.jl b/src/scheduler/interface.jl index 91c11c7..b420788 100644 --- a/src/scheduler/interface.jl +++ b/src/scheduler/interface.jl @@ -1,10 +1,10 @@ """ - Scheduler + AbstractScheduler Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks. """ -abstract type Scheduler end +abstract type AbstractScheduler end """ schedule_dag(::Scheduler, ::DAG, ::Machine) diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl index 55c1bd5..332a327 100644 --- a/src/scheduler/type.jl +++ b/src/scheduler/type.jl @@ -5,10 +5,11 @@ using StaticArrays Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on. """ -struct FunctionCall{VectorType <: AbstractVector, M} +struct FunctionCall{VectorType <: AbstractVector, N} func::Function - arguments::VectorType - additional_arguments::SVector{M, Any} # additional arguments (as values) for the function call, will be prepended to the other arguments + # TODO: this should be a tuple + value_arguments::SVector{N, Any} # value arguments for the function call, will be prepended to the other arguments + arguments::VectorType # symbols of the inputs to the function call return_symbol::Symbol device::AbstractDevice end diff --git a/src/task/compute.jl b/src/task/compute.jl index 2c06613..fdc3a0e 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -32,7 +32,7 @@ function get_function_call( inSymbols::AbstractVector, outSymbol::Symbol, ) where {CompTask <: AbstractComputeTask} - return [FunctionCall(compute, inSymbols, SVector{1, Any}(t), outSymbol, device)] + return [FunctionCall(compute, SVector{1, Any}(t), inSymbols, outSymbol, device)] end function get_function_call(node::ComputeTaskNode) @@ -64,8 +64,8 @@ function get_function_call(node::DataTaskNode) return [ FunctionCall( unpack_identity, - SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))), SVector{0, Any}(), + SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))), Symbol(to_var_name(node.id)), first(children(node)).device, ), @@ -77,8 +77,8 @@ function get_init_function_call(node::DataTaskNode, device::AbstractDevice) return FunctionCall( unpack_identity, - SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")), SVector{0, Any}(), + SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")), Symbol(to_var_name(node.id)), device, )