Tape Machine (#30)

Adds a tape machine as a new way of executing the generated code.
The tape machine is a sequence of FunctionCall objects, which can either be interpreted one by one, or be used to generate the expressions that make up a compiled function.
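
For orientation, a minimal usage sketch of the two modes, assuming `graph::DAG`, a process description, a `Machine`, and a matching process input have already been constructed (e.g. via `parse_dag`, `parse_process`, and `gen_process_input`):

```julia
# Sketch only; `graph`, `process`, `machine`, and `input` are assumed to exist.
tape = gen_tape(graph, process, machine)

# Mode 1: interpret the tape, executing its FunctionCalls one after another.
result = execute_tape(tape, input)

# Mode 2: generate expressions from the tape and compile them into one function.
f = get_compute_function(graph, process, machine)
result = Base.invokelatest(f, input)  # invokelatest bridges the world-age gap
```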

Reviewed-on: Rubydragon/MetagraphOptimization.jl#30
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
Anton Reinhard 2024-01-03 16:38:32 +01:00 committed by Anton Reinhard
parent 92e0eeaaef
commit 82ed774b7e
21 changed files with 398 additions and 502 deletions

1
.gitignore vendored

@ -5,6 +5,7 @@
# Files generated by invoking Julia with --track-allocation
*.mem
*.pb.gz
# System-specific files and directories generated by the BinaryProvider and BinDeps packages
# They contain absolute paths specific to the host computer, and so should not be committed


@ -1,8 +1,23 @@
# Code Generation
## Main
## Types
```@autodocs
Modules = [MetagraphOptimization]
Pages = ["code_gen/main.jl"]
Pages = ["code_gen/type.jl"]
Order = [:type, :constant, :function]
```
## Function Generation
Implementations for the generation of a callable function. A function generated this way cannot be called immediately: one Julia world age has to pass first, which happens when the global Julia scope advances. If the DAG, and therefore the generated function, becomes too large, use the tape machine instead, since compiling very large functions becomes infeasible.
```@autodocs
Modules = [MetagraphOptimization]
Pages = ["code_gen/function.jl"]
Order = [:function]
```
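
For illustration, a minimal sketch of working around the world-age restriction, assuming `graph`, `process`, `machine`, and `input` are already set up:

```julia
f = get_compute_function(graph, process, machine)
# f(input)                           # would fail: the method is too new for the current world age
result = Base.invokelatest(f, input) # works immediately
# A plain call f(input) works from the next top-level statement, e.g. at the REPL.
```
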
## Tape Machine
```@autodocs
Modules = [MetagraphOptimization]
Pages = ["code_gen/tabe_machine.jl"]
Order = [:function]
```


@ -7,6 +7,13 @@ Pages = ["scheduler/interface.jl"]
Order = [:type, :function]
```
## Types
```@autodocs
Modules = [MetagraphOptimization]
Pages = ["scheduler/type.jl"]
Order = [:type, :function]
```
## Greedy
```@autodocs
Modules = [MetagraphOptimization]


@ -34,10 +34,3 @@ Modules = [MetagraphOptimization]
Pages = ["task/properties.jl"]
Order = [:function]
```
## Print
```@autodocs
Modules = [MetagraphOptimization]
Pages = ["task/print.jl"]
Order = [:function]
```


@ -79,6 +79,7 @@ export execute
export parse_dag, parse_process
export gen_process_input
export get_compute_function
export gen_tape, execute_tape
# estimator
export cost_type, graph_cost, operation_effect
@ -120,6 +121,7 @@ include("diff/type.jl")
include("properties/type.jl")
include("operation/type.jl")
include("graph/type.jl")
include("scheduler/type.jl")
include("trie.jl")
include("utility.jl")
@ -155,7 +157,6 @@ include("properties/utility.jl")
include("task/create.jl")
include("task/compare.jl")
include("task/compute.jl")
include("task/print.jl")
include("task/properties.jl")
include("estimator/interface.jl")
@ -200,6 +201,8 @@ include("devices/cuda/impl.jl")
include("scheduler/interface.jl")
include("scheduler/greedy.jl")
include("code_gen/main.jl")
include("code_gen/type.jl")
include("code_gen/tape_machine.jl")
include("code_gen/function.jl")
end # module MetagraphOptimization

40
src/code_gen/function.jl Normal file

@ -0,0 +1,40 @@
"""
get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
Return a function of signature `compute_<id>(input::AbstractProcessInput)`, which will return the result of the DAG computation on the given input.
"""
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
tape = gen_tape(graph, process, machine)
initCaches = Expr(:block, tape.initCachesCode...)
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(data_input::AbstractProcessInput) $(initCaches); $(assignInputs); $code; return $resSym; end",
)
func = eval(expr)
return func
end
"""
execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
Execute the code of the given `graph` on the given input particles.
This is essentially shorthand for
```julia
tape = gen_tape(graph, process, machine)
return execute_tape(tape, input)
```
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
tape = gen_tape(graph, process, machine)
return execute_tape(tape, input)
end


@ -1,150 +0,0 @@
"""
gen_code(graph::DAG)
Generate the code for a given graph. The return value is a named tuple of:
- `code::Expr`: The julia expression containing the code for the whole graph.
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
- `outputSymbol::Symbol`: The symbol of the final calculated value
See also: [`execute`](@ref)
"""
function gen_code(graph::DAG, machine::Machine)
sched = schedule_dag(GreedyScheduler(), graph, machine)
codeAcc = Vector{Expr}()
sizehint!(codeAcc, length(graph.nodes))
for node in sched
# TODO: this is kind of ugly, should init nodes be scheduled differently from the rest?
if (node isa DataTaskNode && length(node.children) == 0)
push!(codeAcc, get_init_expression(node, entry_device(machine)))
continue
end
push!(codeAcc, get_expression(node))
end
# get inSymbols
inputSyms = Dict{String, Vector{Symbol}}()
for node in get_entry_nodes(graph)
if !haskey(inputSyms, node.name)
inputSyms[node.name] = Vector{Symbol}()
end
push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
end
# get outSymbol
outSym = Symbol(to_var_name(get_exit_node(graph).id))
return (code = Expr(:block, codeAcc...), inputSymbols = inputSyms, outputSymbol = outSym)
end
function gen_cache_init_code(machine::Machine)
initializeCaches = Vector{Expr}()
for device in machine.devices
push!(initializeCaches, gen_cache_init_code(device))
end
return Expr(:block, initializeCaches...)
end
function gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
processDescription::AbstractProcessDescription,
machine::Machine,
processInputSymbol::Symbol = :input,
)
@assert length(inputSymbols) >=
sum(values(in_particles(processDescription))) + sum(values(out_particles(processDescription))) "Number of input Symbols is smaller than the number of particles in the process description"
assignInputs = Vector{Expr}()
for (name, symbols) in inputSymbols
(type, index) = type_index_from_name(model(processDescription), name)
p = "get_particle($(processInputSymbol), $(type), $(index))"
for symbol in symbols
device = entry_device(machine)
evalExpr = eval(gen_access_expr(device, symbol))
push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue{$type, ComplexF64}($p, one(ComplexF64))"))
end
end
return Expr(:block, assignInputs...)
end
"""
get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
Return a function of signature `compute_<id>(input::AbstractProcessInput)`, which will return the result of the DAG computation on the given input.
"""
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
(code, inputSymbols, outputSymbol) = gen_code(graph, machine)
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
)
func = eval(expr)
return func
end
"""
execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
Execute the code of the given `graph` on the given input particles.
This is essentially shorthand for
```julia
compute_graph = get_compute_function(graph, process)
result = compute_graph(particles)
```
If an exception occurs during the execution of the generated code, it will be printed for investigation.
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
(code, inputSymbols, outputSymbol) = gen_code(graph, machine)
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
)
func = eval(expr)
result = 0
try
result = @eval $func($input)
#functionStr = string(expr)
#println("Function:\n$functionStr")
catch e
println("Error while evaluating: $e")
# if we find a uuid in the exception we can color it in so it's easier to spot
uuidRegex = r"[0-9a-f]{8}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{12}"
m = match(uuidRegex, string(e))
functionStr = string(expr)
if (isa(m, RegexMatch))
functionStr = replace(functionStr, m.match => "\033[31m$(m.match)\033[0m")
end
println("Function:\n$functionStr")
@assert false
end
return result
end


@ -0,0 +1,182 @@
function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{1}}
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 0}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{2}}
cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]], cache[fc.arguments[2]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT <: SVector{2}}
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], cache[fc.arguments[1]], cache[fc.arguments[2]])
return nothing
end
function call_fc(fc::FunctionCall{VectorT, 1}, cache::Dict{Symbol, Any}) where {VectorT}
cache[fc.return_symbol] = fc.func(fc.additional_arguments[1], getindex.(Ref(cache), fc.arguments)...)
return nothing
end
"""
call_fc(fc::FunctionCall, cache::Dict{Symbol, Any})
Execute the given [`FunctionCall`](@ref) on the dictionary.
Several more specialized versions of this function exist to reduce vector unrolling work for common cases.
"""
function call_fc(fc::FunctionCall{VectorT, M}, cache::Dict{Symbol, Any}) where {VectorT, M}
cache[fc.return_symbol] = fc.func(fc.additional_arguments..., getindex.(Ref(cache), fc.arguments)...)
return nothing
end
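
A minimal sketch of building and interpreting a single `FunctionCall` by hand, assuming `dev` is some `AbstractDevice` (e.g. `entry_device(machine)`):

```julia
using StaticArrays

cache = Dict{Symbol, Any}(:a => 2.0, :b => 3.0)

fc = FunctionCall(
    +,                           # function to call
    SVector{2, Symbol}(:a, :b),  # argument symbols, looked up in the cache
    SVector{0, Any}(),           # no additional value arguments
    :c,                          # symbol to store the result under
    dev,                         # assumed AbstractDevice
)

call_fc(fc, cache)  # dispatches to the SVector{2}, zero-additional-arguments method
cache[:c]           # 5.0
```
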
function expr_from_fc(fc::FunctionCall{VectorT, 0}) where {VectorT}
return Meta.parse(
"$(eval(gen_access_expr(fc.device, fc.return_symbol))) = $(fc.func)($(unroll_symbol_vector(eval.(gen_access_expr.(Ref(fc.device), fc.arguments)))))",
)
end
"""
expr_from_fc(fc::FunctionCall)
For a given function call, return an expression evaluating it.
"""
function expr_from_fc(fc::FunctionCall{VectorT, M}) where {VectorT, M}
func_call = Expr(
:call,
Symbol(fc.func),
fc.additional_arguments...,
eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...,
)
expr = :($(eval(gen_access_expr(fc.device, fc.return_symbol))) = $func_call)
return expr
end
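
Continuing the sketch from `call_fc` above: assuming `gen_access_expr` on `dev` yields plain symbols, the generated expression looks roughly like this:

```julia
ex = expr_from_fc(fc)
# ex is roughly :(c = +(a, b)), an assignment that get_compute_function splices
# into the generated function body instead of interpreting the call at runtime.
```
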
"""
gen_cache_init_code(machine::Machine)
Return a `Vector{Expr}` containing a cache initialization expression for each [`AbstractDevice`](@ref) in the given [`Machine`](@ref).
"""
function gen_cache_init_code(machine::Machine)
initializeCaches = Vector{Expr}()
for device in machine.devices
push!(initializeCaches, gen_cache_init_code(device))
end
return initializeCaches
end
"""
part_from_x(type::Type, index::Int, x::AbstractProcessInput)
Return the [`ParticleValue`](@ref) of the given type of particle with the given `index` from the given process input.
Function is wrapped into a [`FunctionCall`](@ref) in [`gen_input_assignment_code`](@ref).
"""
part_from_x(type::Type, index::Int, x::AbstractProcessInput) =
ParticleValue{type, ComplexF64}(get_particle(x, type, index), one(ComplexF64))
"""
gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
processDescription::AbstractProcessDescription,
machine::Machine,
processInputSymbol::Symbol = :input,
)
Return a `Vector{FunctionCall}` doing the input assignments from the given `processInputSymbol` onto the `inputSymbols`.
"""
function gen_input_assignment_code(
inputSymbols::Dict{String, Vector{Symbol}},
processDescription::AbstractProcessDescription,
machine::Machine,
processInputSymbol::Symbol = :input,
)
@assert length(inputSymbols) >=
sum(values(in_particles(processDescription))) + sum(values(out_particles(processDescription))) "Number of input Symbols is smaller than the number of particles in the process description"
assignInputs = Vector{FunctionCall}()
for (name, symbols) in inputSymbols
(type, index) = type_index_from_name(model(processDescription), name)
# make a function for this, since we can't use anonymous functions in the FunctionCall
for symbol in symbols
device = entry_device(machine)
push!(
assignInputs,
FunctionCall(
# x is the process input
part_from_x,
SVector{1, Symbol}(processInputSymbol),
SVector{2, Any}(type, index),
symbol,
device,
),
)
end
end
return assignInputs
end
"""
gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
Generate the code for a given graph. The return value is a [`Tape`](@ref).
See also: [`execute`](@ref), [`execute_tape`](@ref)
"""
function gen_tape(graph::DAG, process::AbstractProcessDescription, machine::Machine)
schedule = schedule_dag(GreedyScheduler(), graph, machine)
# get inSymbols
inputSyms = Dict{String, Vector{Symbol}}()
for node in get_entry_nodes(graph)
if !haskey(inputSyms, node.name)
inputSyms[node.name] = Vector{Symbol}()
end
push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
end
# get outSymbol
outSym = Symbol(to_var_name(get_exit_node(graph).id))
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSyms, process, machine, :input)
return Tape(initCaches, assignInputs, schedule, inputSyms, outSym, Dict(), process, machine)
end
"""
execute_tape(tape::Tape, input::AbstractProcessInput)
Execute the given tape with the given input.
For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.
"""
function execute_tape(tape::Tape, input::AbstractProcessInput)
cache = Dict{Symbol, Any}()
cache[:input] = input
# simply execute all the code snippets here
# TODO: `@assert` that process input fits the tape.process
for expr in tape.initCachesCode
@eval $expr
end
for function_call in tape.inputAssignCode
call_fc(function_call, cache)
end
for function_call in tape.computeCode
call_fc(function_call, cache)
end
return cache[tape.outputSymbol]
end

19
src/code_gen/type.jl Normal file

@ -0,0 +1,19 @@
"""
Tape
Lowered representation of a graph computation, generated by [`gen_tape`](@ref):
- `initCachesCode::Vector{Expr}`: Expressions initializing the device caches.
- `inputAssignCode::Vector{FunctionCall}`: [`FunctionCall`](@ref)s assigning the process input to the input symbols.
- `computeCode::Vector{FunctionCall}`: [`FunctionCall`](@ref)s performing the actual computation, in scheduled order.
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary mapping the names of the graph's entry nodes to the symbols their inputs should be provided on.
- `outputSymbol::Symbol`: The symbol under which the final result is stored.
The remaining fields hold the cache dictionary, the process description and the machine the tape was generated for.
"""
struct Tape
initCachesCode::Vector{Expr}
inputAssignCode::Vector{FunctionCall}
computeCode::Vector{FunctionCall}
inputSymbols::Dict{String, Vector{Symbol}}
outputSymbol::Symbol
cache::Dict{Symbol, Any}
process::AbstractProcessDescription
machine::Machine
end
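
A short sketch of inspecting a generated tape, again assuming `graph`, `process` and `machine` exist:

```julia
tape = gen_tape(graph, process, machine)

length(tape.computeCode)  # number of FunctionCalls in the scheduled computation
keys(tape.inputSymbols)   # names of the graph's entry nodes
tape.outputSymbol         # symbol under which the final result is stored
```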


@ -76,91 +76,17 @@ function compute(::ComputeTaskABC_S1, data::ABCParticleValue{P})::ABCParticleVal
end
"""
compute(::ComputeTaskABC_Sum, data::StaticVector)
compute(::ComputeTaskABC_Sum, data...)
compute(::ComputeTaskABC_Sum, data::AbstractArray)
Compute a sum over the vector. Use an algorithm that accounts for accumulated errors in long sums with potentially large differences in magnitude of the summands.
Linearly many FLOP with growing data.
"""
function compute(::ComputeTaskABC_Sum, data::StaticVector)::Float64
function compute(::ComputeTaskABC_Sum, data...)::Float64
return sum(data)
end
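
With the new varargs signature the summands are passed directly, for example:

```julia
compute(ComputeTaskABC_Sum(), 1.0, 2.0, 3.0)  # 6.0
```
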
"""
get_expression(::ComputeTaskABC_P, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate and return code evaluating [`ComputeTaskABC_P`](@ref) on `inSyms`, providing the output on `outSym`.
"""
function get_expression(::ComputeTaskABC_P, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskABC_P(), $(in[1]))")
end
"""
get_expression(::ComputeTaskABC_U, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskABC_U`](@ref) on `inSyms`, providing the output on `outSym`.
`inSyms` should be of type [`ABCParticleValue`](@ref), `outSym` will be of type [`ABCParticleValue`](@ref).
"""
function get_expression(::ComputeTaskABC_U, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskABC_U(), $(in[1]))")
end
"""
get_expression(::ComputeTaskABC_V, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskABC_V`](@ref) on `inSyms`, providing the output on `outSym`.
`inSym[1]` and `inSym[2]` should be of type [`ABCParticleValue`](@ref), `outSym` will be of type [`ABCParticleValue`](@ref).
"""
function get_expression(::ComputeTaskABC_V, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1]), eval(inExprs[2])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskABC_V(), $(in[1]), $(in[2]))")
end
"""
get_expression(::ComputeTaskABC_S2, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskABC_S2`](@ref) on `inSyms`, providing the output on `outSym`.
`inSyms[1]` and `inSyms[2]` should be of type [`ABCParticleValue`](@ref), `outSym` will be of type `Float64`.
"""
function get_expression(::ComputeTaskABC_S2, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1]), eval(inExprs[2])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskABC_S2(), $(in[1]), $(in[2]))")
end
"""
get_expression(::ComputeTaskABC_S1, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskABC_S1`](@ref) on `inSyms`, providing the output on `outSym`.
`inSyms` should be of type [`ABCParticleValue`](@ref), `outSym` will be of type [`ABCParticleValue`](@ref).
"""
function get_expression(::ComputeTaskABC_S1, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskABC_S1(), $(in[1]))")
end
"""
get_expression(::ComputeTaskABC_Sum, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskABC_Sum`](@ref) on `inSyms`, providing the output on `outSym`.
`inSyms` should be of type [`Float64`], `outSym` will be of type [`Float64`].
"""
function get_expression(::ComputeTaskABC_Sum, device::AbstractDevice, inExprs::Vector, outExpr)
in = eval.(inExprs)
out = eval(outExpr)
return Meta.parse(
"$out = compute(ComputeTaskABC_Sum(), SVector{$(length(inExprs)), Float64}($(unroll_symbol_vector(in))))",
)
function compute(::ComputeTaskABC_Sum, data::AbstractArray)::Float64
return sum(data)
end


@ -43,48 +43,6 @@ this doesn't matter.
"""
compute_effort(t::ComputeTaskABC_Sum)::Float64 = 1.0
"""
show(io::IO, t::ComputeTaskABC_S1)
Print the S1 task to io.
"""
show(io::IO, t::ComputeTaskABC_S1) = print(io, "ComputeS1")
"""
show(io::IO, t::ComputeTaskABC_S2)
Print the S2 task to io.
"""
show(io::IO, t::ComputeTaskABC_S2) = print(io, "ComputeS2")
"""
show(io::IO, t::ComputeTaskABC_P)
Print the P task to io.
"""
show(io::IO, t::ComputeTaskABC_P) = print(io, "ComputeP")
"""
show(io::IO, t::ComputeTaskABC_U)
Print the U task to io.
"""
show(io::IO, t::ComputeTaskABC_U) = print(io, "ComputeU")
"""
show(io::IO, t::ComputeTaskABC_V)
Print the V task to io.
"""
show(io::IO, t::ComputeTaskABC_V) = print(io, "ComputeV")
"""
show(io::IO, t::ComputeTaskABC_Sum)
Print the sum task to io.
"""
show(io::IO, t::ComputeTaskABC_Sum) = print(io, "ComputeSum")
"""
children(::ComputeTaskABC_S1)


@ -106,92 +106,19 @@ function compute(::ComputeTaskQED_S1, data::QEDParticleValue{P}) where {P <: QED
end
"""
compute(::ComputeTaskQED_Sum, data::StaticVector)
compute(::ComputeTaskQED_Sum, data...)
compute(::ComputeTaskQED_Sum, data::AbstractArray)
Compute a sum over the vector. Use an algorithm that accounts for accumulated errors in long sums with potentially large differences in magnitude of the summands.
Linearly many FLOP with growing data.
"""
function compute(::ComputeTaskQED_Sum, data::StaticVector)::ComplexF64
function compute(::ComputeTaskQED_Sum, data...)::ComplexF64
# TODO: want to use sum_kbn here but it doesn't seem to support ComplexF64, do it element-wise?
return sum(data)
end
"""
get_expression(::ComputeTaskQED_P, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate and return code evaluating [`ComputeTaskQED_P`](@ref) on `inSyms`, providing the output on `outSym`.
"""
function get_expression(::ComputeTaskQED_P, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskQED_P(), $(in[1]))")
end
"""
get_expression(::ComputeTaskQED_U, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskQED_U`](@ref) on `inSyms`, providing the output on `outSym`.
`inSyms` should be of type [`QEDParticleValue`](@ref), `outSym` will be of type [`QEDParticleValue`](@ref).
"""
function get_expression(::ComputeTaskQED_U, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskQED_U(), $(in[1]))")
end
"""
get_expression(::ComputeTaskQED_V, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskQED_V`](@ref) on `inSyms`, providing the output on `outSym`.
`inSym[1]` and `inSym[2]` should be of type [`QEDParticleValue`](@ref), `outSym` will be of type [`QEDParticleValue`](@ref).
"""
function get_expression(::ComputeTaskQED_V, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1]), eval(inExprs[2])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskQED_V(), $(in[1]), $(in[2]))")
end
"""
get_expression(::ComputeTaskQED_S2, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskQED_S2`](@ref) on `inSyms`, providing the output on `outSym`.
`inSyms[1]` and `inSyms[2]` should be of type [`QEDParticleValue`](@ref), `outSym` will be of type `Float64`.
"""
function get_expression(::ComputeTaskQED_S2, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1]), eval(inExprs[2])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskQED_S2(), $(in[1]), $(in[2]))")
end
"""
get_expression(::ComputeTaskQED_S1, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskQED_S1`](@ref) on `inSyms`, providing the output on `outSym`.
`inSyms` should be of type [`QEDParticleValue`](@ref), `outSym` will be of type [`QEDParticleValue`](@ref).
"""
function get_expression(::ComputeTaskQED_S1, device::AbstractDevice, inExprs::Vector, outExpr)
in = [eval(inExprs[1])]
out = eval(outExpr)
return Meta.parse("$out = compute(ComputeTaskQED_S1(), $(in[1]))")
end
"""
get_expression(::ComputeTaskQED_Sum, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
Generate code evaluating [`ComputeTaskQED_Sum`](@ref) on `inSyms`, providing the output on `outSym`.
`inSyms` should be of type [`Float64`], `outSym` will be of type [`Float64`].
"""
function get_expression(::ComputeTaskQED_Sum, device::AbstractDevice, inExprs::Vector, outExpr)
in = eval.(inExprs)
out = eval(outExpr)
return Meta.parse(
"$out = compute(ComputeTaskQED_Sum(), SVector{$(length(inExprs)), ComplexF64}($(unroll_symbol_vector(in))))",
)
function compute(::ComputeTaskQED_Sum, data::AbstractArray)::ComplexF64
# TODO: want to use sum_kbn here but it doesn't seem to support ComplexF64, do it element-wise?
return sum(data)
end


@ -170,6 +170,12 @@ end
String(::Type{Incoming}) = "Incoming"
String(::Type{Outgoing}) = "Outgoing"
String(::Type{PolX}) = "polx"
String(::Type{PolY}) = "poly"
String(::Type{SpinUp}) = "spinup"
String(::Type{SpinDown}) = "spindown"
String(::Incoming) = "i"
String(::Outgoing) = "o"
@ -183,6 +189,16 @@ function String(::Type{<:AntiFermionStateful})
return "p"
end
function unique_name(::Type{PhotonStateful{Dir, Pol}}) where {Dir, Pol}
return String(PhotonStateful) * String(Dir) * String(Pol)
end
function unique_name(::Type{FermionStateful{Dir, Spin}}) where {Dir, Spin}
return String(FermionStateful) * String(Dir) * String(Spin)
end
function unique_name(::Type{AntiFermionStateful{Dir, Spin}}) where {Dir, Spin}
return String(AntiFermionStateful) * String(Dir) * String(Spin)
end
@inline particle(::PhotonStateful) = Photon()
@inline particle(::FermionStateful) = Electron()
@inline particle(::AntiFermionStateful) = Positron()


@ -45,48 +45,6 @@ this doesn't matter.
"""
compute_effort(t::ComputeTaskQED_Sum)::Float64 = 1.0
"""
show(io::IO, t::ComputeTaskQED_S1)
Print the S1 task to io.
"""
show(io::IO, t::ComputeTaskQED_S1) = print(io, "ComputeS1")
"""
show(io::IO, t::ComputeTaskQED_S2)
Print the S2 task to io.
"""
show(io::IO, t::ComputeTaskQED_S2) = print(io, "ComputeS2")
"""
show(io::IO, t::ComputeTaskQED_P)
Print the P task to io.
"""
show(io::IO, t::ComputeTaskQED_P) = print(io, "ComputeP")
"""
show(io::IO, t::ComputeTaskQED_U)
Print the U task to io.
"""
show(io::IO, t::ComputeTaskQED_U) = print(io, "ComputeU")
"""
show(io::IO, t::ComputeTaskQED_V)
Print the V task to io.
"""
show(io::IO, t::ComputeTaskQED_V) = print(io, "ComputeV")
"""
show(io::IO, t::ComputeTaskQED_Sum)
Print the sum task to io.
"""
show(io::IO, t::ComputeTaskQED_Sum) = print(io, "ComputeSum")
"""
children(::ComputeTaskQED_S1)


@ -14,7 +14,7 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
enqueue!(nodeQueue, node => 0)
end
schedule = Vector{Node}()
schedule = Vector{FunctionCall}()
sizehint!(schedule, length(graph.nodes))
# keep an accumulated cost of things scheduled to this device so far
@ -35,7 +35,12 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
deviceAccCost[lowestDevice] = compute_effort(task(node))
end
push!(schedule, node)
if (node isa DataTaskNode && length(node.children) == 0)
push!(schedule, get_init_function_call(node, entry_device(machine)))
else
push!(schedule, get_function_call(node)...)
end
for parent in parents(node)
# reduce the priority of all parents by one
if (!haskey(nodeQueue, parent))


@ -14,5 +14,7 @@ Interface functions that must be implemented for implementations of [`Scheduler`
The function assigns each [`ComputeTaskNode`](@ref) of the [`DAG`](@ref) to one of the devices in the given [`Machine`](@ref).
[`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`AbstractDevice`](@ref) of their child to all [`AbstractDevice`](@ref)s of its parents.
Return a `Vector{FunctionCall}` representing the computation in a topological ordering; see [`FunctionCall`](@ref).
"""
function schedule_dag end
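
A sketch of calling an implementation, assuming `graph` and `machine` exist (this is also how `gen_tape` uses it):

```julia
# Returns the flat, topologically ordered Vector{FunctionCall} for the whole DAG.
fn_calls = schedule_dag(GreedyScheduler(), graph, machine)
```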

14
src/scheduler/type.jl Normal file

@ -0,0 +1,14 @@
using StaticArrays
"""
FunctionCall{VectorT, M}
Type representing a single function call with `M` additional value arguments. Contains the function to call, the argument symbols, the additional arguments (prepended to the other arguments on call), the return symbol and the device to execute on.
"""
struct FunctionCall{VectorType <: AbstractVector, M}
func::Function
arguments::VectorType
additional_arguments::SVector{M, Any} # additional arguments (as values) for the function call, will be prepended to the other arguments
return_symbol::Symbol
device::AbstractDevice
end
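
A hedged sketch of how a compute task is wrapped into a `FunctionCall` by the scheduler: the task instance is stored as an additional argument and prepended on call, so interpreting the call amounts to `compute(task, args...)`. Here `dev`, `:x1` and `:y1` are placeholders and `ComputeTaskABC_P` is just an example task:

```julia
fc = FunctionCall(
    compute,
    SVector{1, Symbol}(:x1),              # symbol of the task's single input
    SVector{1, Any}(ComputeTaskABC_P()),  # the task instance, prepended on call
    :y1,                                  # symbol to store the result under
    dev,
)
# call_fc(fc, cache) then effectively runs cache[:y1] = compute(ComputeTaskABC_P(), cache[:x1])
```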


@ -1,89 +1,85 @@
using StaticArrays
"""
compute(t::FusedComputeTask, data)
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
"""
function compute(t::FusedComputeTask, data)
@assert false "This is not implemented and should never be called"
function compute(t::FusedComputeTask, data...)
inter = compute(t.first_task, data...)
return compute(t.second_task, inter, data...)
end
"""
get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector{String}, outExpr::String)
get_function_call(n::Node)
get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing the output on `outExpr`.
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
For a node, or a task together with the necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task.
For ordinary compute or data tasks the vector contains exactly one element; for a [`FusedComputeTask`](@ref) it can contain any number of function calls greater than 1.
"""
function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
inExprs1 = Vector()
for sym in t.t1_inputs
push!(inExprs1, gen_access_expr(device, sym))
end
outExpr1 = gen_access_expr(device, t.t1_output)
inExprs2 = Vector()
for sym in t.t2_inputs
push!(inExprs2, gen_access_expr(device, sym))
end
expr1 = get_expression(t.first_task, device, inExprs1, outExpr1)
expr2 = get_expression(t.second_task, device, [inExprs2..., outExpr1], outExpr)
full_expr = Expr(:block, expr1, expr2)
return full_expr
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
# sort out the symbols to the correct tasks
return [
get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
]
end
"""
get_expression(node::ComputeTaskNode)
function get_function_call(
t::CompTask,
device::AbstractDevice,
inSymbols::AbstractVector,
outSymbol::Symbol,
) where {CompTask <: AbstractComputeTask}
return [FunctionCall(compute, inSymbols, SVector{1, Any}(t), outSymbol, device)]
end
Generate and return code for a given [`ComputeTaskNode`](@ref).
"""
function get_expression(node::ComputeTaskNode)
function get_function_call(node::ComputeTaskNode)
@assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
@assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
inExprs = Vector()
for id in getfield.(children(node), :id)
push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
if (length(node.children) <= 50)
#only use an SVector when there are few children
return get_function_call(
node.task,
node.device,
SVector{length(node.children), Symbol}(Symbol.(to_var_name.(getfield.(children(node), :id)))...),
Symbol(to_var_name(node.id)),
)
else
return get_function_call(
node.task,
node.device,
Symbol.(to_var_name.(getfield.(children(node), :id))),
Symbol(to_var_name(node.id)),
)
end
outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
return get_expression(task(node), node.device, inExprs, outExpr)
end
"""
get_function_call(node::DataTaskNode)
Return the [`FunctionCall`](@ref) for a given [`DataTaskNode`](@ref), moving the data from its child's device.
"""
function get_expression(node::DataTaskNode)
function get_function_call(node::DataTaskNode)
@assert length(children(node)) == 1 "Trying to call get_function_call on a data task node that has $(length(node.children)) children instead of 1"
# TODO: dispatch to device implementations generating the copy commands
child = children(node)[1]
inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
dataTransportExp = Meta.parse("$outExpr = $inExpr")
return dataTransportExp
return [
FunctionCall(
unpack_identity,
SVector{1, Symbol}(Symbol(to_var_name(first(children(node)).id))),
SVector{0, Any}(),
Symbol(to_var_name(node.id)),
first(children(node)).device,
),
]
end
"""
get_init_function_call(node::DataTaskNode, device::AbstractDevice)
Return a [`FunctionCall`](@ref) reading the initial input for [`DataTaskNode`](@ref)s with 0 children, i.e., entry nodes.
See also: [`get_entry_nodes`](@ref)
"""
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
@assert isempty(children(node)) "Trying to call get_init_function_call on a data task node that is not an entry node."
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
dataTransportExp = Meta.parse("$outExpr = $inExpr")
return dataTransportExp
return FunctionCall(
unpack_identity,
SVector{1, Symbol}(Symbol("$(to_var_name(node.id))_in")),
SVector{0, Any}(),
Symbol(to_var_name(node.id)),
device,
)
end


@ -1,17 +0,0 @@
"""
show(io::IO, t::FusedComputeTask)
Print a string representation of the fused compute task to io.
"""
function show(io::IO, t::FusedComputeTask)
return print(io, "ComputeFuse($(t.first_task), $(t.second_task))")
end
"""
show(io::IO, t::DataTask)
Print the data task to io.
"""
function show(io::IO, t::DataTask)
return print(io, "Data", t.data)
end


@ -3,28 +3,10 @@
Fallback implementation of the compute function of a compute task, throwing an error.
"""
function compute(t::AbstractTask; data...)
function compute(t::AbstractTask, data...)
return error("Need to implement compute()")
end
"""
compute(t::FusedComputeTask; data...)
Compute a fused compute task.
"""
function compute(t::FusedComputeTask; data...)
(T1, T2) = collect(typeof(t).parameters)
return compute(T2(), compute(T1(), data))
end
"""
compute(t::AbstractDataTask; data...)
The compute function of a data task, always the identity function, regardless of the specific task.
"""
compute(t::AbstractDataTask; data...) = data
"""
compute_effort(t::AbstractTask)


@ -1,3 +1,18 @@
"""
noop()
Function with no arguments, returns nothing, does nothing. Useful for noop [`FunctionCall`](@ref)s.
"""
@inline noop() = nothing
"""
unpack_identity(x::SVector)
Function taking an `SVector`, returning it unpacked.
"""
@inline unpack_identity(x::SVector{1, <:Any}) = x[1]
@inline unpack_identity(x) = x
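
For example:

```julia
using StaticArrays

unpack_identity(SVector(1.0))  # 1.0
unpack_identity(2.5)           # 2.5; non-SVector values pass through unchanged
```
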
"""
bytes_to_human_readable(bytes)
@ -104,6 +119,10 @@ function unroll_symbol_vector(vec::Vector)
return result
end
function unroll_symbol_vector(vec::SVector)
return unroll_symbol_vector(Vector(vec))
end
####################