Tape Machine (#30)

Adds a tape machine way of executing the code.
The tape machine is a series of FunctionCall objects, which can either be called one by one, or be used to generate expressions to make up a function.

Reviewed-on: Rubydragon/MetagraphOptimization.jl#30
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
2024-01-03 16:38:32 +01:00
committed by Anton Reinhard
parent 92e0eeaaef
commit 82ed774b7e
21 changed files with 398 additions and 502 deletions

40
src/code_gen/function.jl Normal file
View File

@@ -0,0 +1,40 @@
"""
get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
Return a function of signature `compute_<id>(input::AbstractProcessInput)`, which will return the result of the DAG computation on the given input.
"""
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
tape = gen_tape(graph, process, machine)
initCaches = Expr(:block, tape.initCachesCode...)
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(data_input::AbstractProcessInput) $(initCaches); $(assignInputs); $code; return $resSym; end",
)
func = eval(expr)
return func
end
"""
execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
Execute the code of the given `graph` on the given input particles.
This is essentially shorthand for
```julia
tape = gen_tape(graph, process, machine)
return execute_tape(tape, input)
```
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
tape = gen_tape(graph, process, machine)
return execute_tape(tape, input)
end

View File

@@ -1,150 +0,0 @@
"""
gen_code(graph::DAG)
Generate the code for a given graph. The return value is a named tuple of:
- `code::Expr`: The julia expression containing the code for the whole graph.
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
- `outputSymbol::Symbol`: The symbol of the final calculated value
See also: [`execute`](@ref)
"""
function gen_code(graph::DAG, machine::Machine)
sched = schedule_dag(GreedyScheduler(), graph, machine)
codeAcc = Vector{Expr}()
sizehint!(codeAcc, length(graph.nodes))
for node in sched
# TODO: this is kind of ugly, should init nodes be scheduled differently from the rest?
if (node isa DataTaskNode && length(node.children) == 0)
push!(codeAcc, get_init_expression(node, entry_device(machine)))
continue
end
push!(codeAcc, get_expression(node))
end
# get inSymbols
inputSyms = Dict{String, Vector{Symbol}}()
for node in get_entry_nodes(graph)
if !haskey(inputSyms, node.name)
inputSyms[node.name] = Vector{Symbol}()
end
push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
end
# get outSymbol
outSym = Symbol(to_var_name(get_exit_node(graph).id))
return (code = Expr(:block, codeAcc...), inputSymbols = inputSyms, outputSymbol = outSym)
end
function gen_cache_init_code(machine::Machine)
initializeCaches = Vector{Expr}()
for device in machine.devices
push!(initializeCaches, gen_cache_init_code(device))
end
return Expr(:block, initializeCaches...)
end
function gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
processDescription::AbstractProcessDescription,
machine::Machine,
processInputSymbol::Symbol = :input,
)
@assert length(inputSymbols) >=
sum(values(in_particles(processDescription))) + sum(values(out_particles(processDescription))) "Number of input Symbols is smaller than the number of particles in the process description"
assignInputs = Vector{Expr}()
for (name, symbols) in inputSymbols
(type, index) = type_index_from_name(model(processDescription), name)
p = "get_particle($(processInputSymbol), $(type), $(index))"
for symbol in symbols
device = entry_device(machine)
evalExpr = eval(gen_access_expr(device, symbol))
push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue{$type, ComplexF64}($p, one(ComplexF64))"))
end
end
return Expr(:block, assignInputs...)
end
"""
get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
Return a function of signature `compute_<id>(input::AbstractProcessInput)`, which will return the result of the DAG computation on the given input.
"""
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
(code, inputSymbols, outputSymbol) = gen_code(graph, machine)
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
)
func = eval(expr)
return func
end
"""
execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
Execute the code of the given `graph` on the given input particles.
This is essentially shorthand for
```julia
compute_graph = get_compute_function(graph, process)
result = compute_graph(particles)
```
If an exception occurs during the execution of the generated code, it will be printed for investigation.
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
(code, inputSymbols, outputSymbol) = gen_code(graph, machine)
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
)
func = eval(expr)
result = 0
try
result = @eval $func($input)
#functionStr = string(expr)
#println("Function:\n$functionStr")
catch e
println("Error while evaluating: $e")
# if we find a uuid in the exception we can color it in so it's easier to spot
uuidRegex = r"[0-9a-f]{8}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{12}"
m = match(uuidRegex, string(e))
functionStr = string(expr)
if (isa(m, RegexMatch))
functionStr = replace(functionStr, m.match => "\033[31m$(m.match)\033[0m")
end
println("Function:\n$functionStr")
@assert false
end
return result
end

View File

@@ -0,0 +1,182 @@
function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{2}}
cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]], cache[fc.arguments[2]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{2}}
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]], cache[fc.arguments[2]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT}
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], getindex.(Ref(cache), fc.arguments)...)
return nothing
end
"""
call_fc(fc::FunctionCall, cache::Dict{Symbol, Any})
Execute the given [`FunctionCall`](@ref) on the dictionary.
Several more specialized versions of this function exist to reduce vector unrolling work for common cases.
"""
function call_fc(fc::FunctionCall{VectorT, M}, cache::Dict{Symbol, Any}) where {VectorT, M}
cache[fc.return_symbol] = fc.func(fc.additional_arguments..., getindex.(Ref(cache), fc.arguments)...)
return nothing
end
function expr_from_fc(fc::FunctionCall{VectorT, 0}) where {VectorT}
return Meta.parse(
"$(eval(gen_access_expr(fc.device, fc.return_symbol))) = $(fc.func)($(unroll_symbol_vector(eval.(gen_access_expr.(Ref(fc.device), fc.arguments)))))",
)
end
"""
expr_from_fc(fc::FunctionCall)
For a given function call, return an expression evaluating it.
"""
function expr_from_fc(fc::FunctionCall{VectorT, M}) where {VectorT, M}
func_call = Expr(
:call,
Symbol(fc.func),
fc.additional_arguments...,
eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...,
)
expr = :($(eval(gen_access_expr(fc.device, fc.return_symbol))) = $func_call)
return expr
end
"""
gen_cache_init_code(machine::Machine)
For each [`AbstractDevice`](@ref) in the given [`Machine`](@ref), returning a `Vector{Expr}` doing the initialization.
"""
function gen_cache_init_code(machine::Machine)
initializeCaches = Vector{Expr}()
for device in machine.devices
push!(initializeCaches, gen_cache_init_code(device))
end
return initializeCaches
end
"""
part_from_x(type::Type, index::Int, x::AbstractProcessInput)
Return the [`ParticleValue`](@ref) of the given type of particle with the given `index` from the given process input.
Function is wrapped into a [`FunctionCall`](@ref) in [`gen_input_assignment_code`](@ref).
"""
part_from_x(type::Type, index::Int, x::AbstractProcessInput) =
ParticleValue{type, ComplexF64}(get_particle(x, type, index), one(ComplexF64))
"""
gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
processDescription::AbstractProcessDescription,
machine::Machine,
processInputSymbol::Symbol = :input,
)
Return a `Vector{Expr}` doing the input assignments from the given `processInputSymbol` onto the `inputSymbols`.
"""
function gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
processDescription::AbstractProcessDescription,
machine::Machine,
processInputSymbol::Symbol = :input,
)
@assert length(inputSymbols) >=
sum(values(in_particles(processDescription))) + sum(values(out_particles(processDescription))) "Number of input Symbols is smaller than the number of particles in the process description"
assignInputs = Vector{FunctionCall}()
for (name, symbols) in inputSymbols
(type, index) = type_index_from_name(model(processDescription), name)
# make a function for this, since we can't use anonymous functions in the FunctionCall
for symbol in symbols
device = entry_device(machine)
push!(
assignInputs,
FunctionCall(
# x is the process input
part_from_x,
SVector{1, Symbol}(processInputSymbol),
SVector{2, Any}(type, index),
symbol,
device,
),
)
end
end
return assignInputs
end
"""
gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
Generate the code for a given graph. The return value is a [`Tape`](@ref).
See also: [`execute`](@ref), [`execute_tape`](@ref)
"""
function gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
schedule = schedule_dag(GreedyScheduler(), graph, machine)
# get inSymbols
inputSyms = Dict{String, Vector{Symbol}}()
for node in get_entry_nodes(graph)
if !haskey(inputSyms, node.name)
inputSyms[node.name] = Vector{Symbol}()
end
push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
end
# get outSymbol
outSym = Symbol(to_var_name(get_exit_node(graph).id))
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSyms, process, machine, :input)
return Tape(initCaches, assignInputs, schedule, inputSyms, outSym, Dict(), process, machine)
end
"""
execute_tape(tape::Tape, input::AbstractProcessInput)
Execute the given tape with the given input.
For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.
"""
function execute_tape(tape::Tape, input::AbstractProcessInput)
cache = Dict{Symbol, Any}()
cache[:input] = input
# simply execute all the code snippets here
# TODO: `@assert` that process input fits the tape.process
for expr in tape.initCachesCode
@eval $expr
end
for function_call in tape.inputAssignCode
call_fc(function_call, cache)
end
for function_call in tape.computeCode
call_fc(function_call, cache)
end
return cache[tape.outputSymbol]
end

19
src/code_gen/type.jl Normal file
View File

@@ -0,0 +1,19 @@
"""
Tape
TODO: update docs
- `code::Vector{Expr}`: The julia expression containing the code for the whole graph.
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
- `outputSymbol::Symbol`: The symbol of the final calculated value
"""
struct Tape
initCachesCode::Vector{Expr}
inputAssignCode::Vector{FunctionCall}
computeCode::Vector{FunctionCall}
inputSymbols::Dict{String, Vector{Symbol}}
outputSymbol::Symbol
cache::Dict{Symbol, Any}
process::AbstractProcessDescription
machine::Machine
end