heterogeneity (#27)

Prepare things to work with heterogeneity, make things work on GPU

Reviewed-on: Rubydragon/MetagraphOptimization.jl#27
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
2023-12-18 14:31:52 +01:00
committed by Anton Reinhard
parent c90346e948
commit 92e0eeaaef
42 changed files with 1631 additions and 238 deletions

View File

@@ -62,21 +62,12 @@ function gen_input_assignment_code(
assignInputs = Vector{Expr}()
for (name, symbols) in inputSymbols
(type, index) = type_index_from_name(model(processDescription), name)
p = nothing
if (index > get(in_particles(processDescription), type, 0))
index -= get(in_particles(processDescription), type, 0)
@assert index <= out_particles(processDescription)[type] "Too few particles of type $type in input particles for this process"
p = "filter(x -> typeof(x) <: $type, out_particles($(processInputSymbol)))[$(index)]"
else
p = "filter(x -> typeof(x) <: $type, in_particles($(processInputSymbol)))[$(index)]"
end
p = "get_particle($(processInputSymbol), $(type), $(index))"
for symbol in symbols
device = entry_device(machine)
evalExpr = eval(gen_access_expr(device, symbol))
push!(assignInputs, Meta.parse("$(evalExpr)::ParticleValue{$type} = ParticleValue($p, one(ComplexF64))"))
push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue{$type, ComplexF64}($p, one(ComplexF64))"))
end
end
@@ -111,10 +102,12 @@ end
Execute the code of the given `graph` on the given input particles.
This is essentially shorthand for
```julia
compute_graph = get_compute_function(graph, process)
result = compute_graph(particles)
```
```julia
compute_graph = get_compute_function(graph, process)
result = compute_graph(particles)
```
If an exception occurs during the execution of the generated code, it will be printed for investigation.
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
@@ -135,6 +128,8 @@ function execute(graph::DAG, process::AbstractProcessDescription, machine::Machi
result = 0
try
result = @eval $func($input)
#functionStr = string(expr)
#println("Function:\n$functionStr")
catch e
println("Error while evaluating: $e")