diff --git a/examples/congruent_in_ph.jl b/examples/congruent_in_ph.jl
index 19acc24..8b805b7 100644
--- a/examples/congruent_in_ph.jl
+++ b/examples/congruent_in_ph.jl
@@ -6,8 +6,13 @@ using QEDcore
 using QEDprocesses
 using Random
 using UUIDs
+
+using CUDA
+
+using NamedDims
 using CSV
 using JLD2
+using FlexiMaps
 
 RNG = Random.MersenneTwister(123)
 
@@ -79,7 +84,7 @@ function congruent_input_momenta_scenario_2(
     # ----------
 
     # now calculate the final_momenta from omega, cos_theta and phi
-    n = number_particles(processDescription, ParticleStateful{Incoming, Photon, SFourMomentum})
+    n = number_particles(processDescription, Incoming(), Photon())
 
     cos_theta = cos(theta)
     omega_prime = (n * omega) / (1 + n * omega * (1 - cos_theta))
@@ -108,109 +113,82 @@ end
 with_stacksize(f, n) = fetch(schedule(Task(f, n)))
 
 # scenario 2
-N = 1000
-M = 1000
+N = 1024    # thetas
+M = 1024    # phis
+K = 64      # omegas
 
 thetas = collect(LinRange(0, 2π, N))
 phis = collect(LinRange(0, 2π, M))
+omegas = collect(maprange(log, 2e-2, 2e-7; length = K))    # K values, uniform steps in log(omega)
 
-for photons in 1:6
+for photons in 1:5
     # temp process to generate momenta
-    for omega in [2e-3, 2e-6]
-        println("Generating $(N*M) inputs for $photons photons (Scenario 2 grid walk)...")
-        temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
+    println("Generating $(K*N*M) inputs for $photons photons (Scenario 2 grid walk)...")
+    temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
 
-        input_momenta = [
-            congruent_input_momenta_scenario_2(temp_process, omega, theta, phi) for
-            (theta, phi) in Iterators.product(thetas, phis)
-        ]
-        results = Array{Float64}(undef, size(input_momenta))
-        fill!(results, 0.0)
+    input_momenta =
+        Array{typeof(congruent_input_momenta_scenario_2(temp_process, omegas[1], thetas[1], phis[1]))}(undef, (K, N, M))
 
-        i = 1
-        for (in_pol, in_spin, out_pol, out_spin) in
-            Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
+    Threads.@threads for k in 1:K
+        Threads.@threads for i in 1:N
+            Threads.@threads for j in 1:M
+                input_momenta[k, i, j] = congruent_input_momenta_scenario_2(temp_process, omegas[k], thetas[i], phis[j])
+            end
+        end
+    end
 
-            print(
-                "[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
-            )
-            process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
-            inputs = build_psp.(Ref(process), input_momenta)
 
-            print("Preparing graph... ")
-            graph = gen_graph(process)
-            optimize_to_fixpoint!(ReductionOptimizer(), graph)
-            print("Preparing function... ")
-            func = get_compute_function(graph, process, mock_machine())
-            func(inputs[1])
+    cu_results = CuArray{Float64}(undef, size(input_momenta))
+    fill!(cu_results, 0.0)
 
-            print("Calculating... ")
+    i = 1
+    for (in_pol, in_spin, out_pol, out_spin) in
+        Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
+
+        print(
+            "[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
+        )
+        process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
+
+        inputs = Array{typeof(build_psp(process, input_momenta[1, 1, 1]))}(undef, (K, N, M))
+        #println("input_momenta: $input_momenta")
+        Threads.@threads for k in 1:K
             Threads.@threads for i in 1:N
                 Threads.@threads for j in 1:M
-                    return results[i, j] += abs2(func(inputs[i, j]))
+                    inputs[k, i, j] = build_psp(process, input_momenta[k, i, j])
                 end
             end
-            println("Done.")
-            i += 1
         end
+        cu_inputs = CuArray(inputs)
 
-        println("Writing results")
+        print("Preparing graph... ")
+        graph = gen_graph(process)
+        optimize_to_fixpoint!(ReductionOptimizer(), graph)
+        print("Preparing function... ")
+        kernel! = get_cuda_kernel(graph, process, mock_machine())
+        #func = get_compute_function(graph, process, mock_machine())
 
-        out_ph_moms = getindex.(getindex.(input_momenta, 2), 1)
-        out_el_moms = getindex.(getindex.(input_momenta, 2), 2)
+        print("Calculating... ")
+        ts = 32    # threads per block (passed as `threads` to @cuda below)
+        bs = cld(length(cu_inputs), ts)    # blocks: integer ceil-division, no Float64 round-trip / InexactError
 
-        @save "$(photons)_congruent_photons_omega_$(omega)_grid.jld2" out_ph_moms out_el_moms results
-    end
-end
-
-exit(0)
-
-# scenario 1 (disabled)
-n = 1000000
-
-# n is the number of incoming photons
-# omega is the number
-
-for photons in 1:6
-    # temp process to generate momenta
-    for omega in [2e-3, 2e-6]
-        println("Generating $n inputs for $photons photons...")
-        temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
-
-        input_momenta = [congruent_input_momenta(temp_process, omega) for _ in 1:n]
-        results = Array{Float64}(undef, size(input_momenta))
-        fill!(results, 0.0)
-
-        i = 1
-        for (in_pol, in_spin, out_pol, out_spin) in
-            Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
-
-            print(
-                "[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
-            )
-            process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
-            inputs = build_psp.(Ref(process), input_momenta)
-
-            print("Preparing graph... ")
-            # prepare function
-            graph = gen_graph(process)
-            optimize_to_fixpoint!(ReductionOptimizer(), graph)
-            print("Preparing function... ")
-            func = get_compute_function(graph, process, mock_machine())
-
-            print("Calculating... ")
-            Threads.@threads for i in 1:n
-                results[i] += abs2(func(inputs[i]))
-            end
-            println("Done.")
-            i += 1
-        end
-
-        println("Writing results")
-
-        out_ph_moms = getindex.(getindex.(input_momenta, 2), 1)
-        out_el_moms = getindex.(getindex.(input_momenta, 2), 2)
-
-        @save "$(photons)_congruent_photons_omega_$(omega).jld2" out_ph_moms out_el_moms results
+        outputs = CuArray{ComplexF64}(undef, size(cu_inputs))
+
+        @cuda threads = ts blocks = bs always_inline = true kernel!(cu_inputs, outputs, length(cu_inputs))
+        CUDA.device_synchronize()
+        cu_results .+= abs2.(outputs)    # fused in-place accumulation; no new device array per config
+
+        println("Done.")
+        i += 1
     end
+
+    println("Writing results")
+
+    # The outgoing photon/electron momenta are deliberately not extracted here: they are
+    # not written by the @save below, which stores only the grid axes and the results.
+
+    results = NamedDimsArray{(:omegas, :thetas, :phis)}(Array(cu_results))
+    println("Named results array: $(typeof(results))")
+
+    @save "$(photons)_congruent_photons_grid.jld2" omegas thetas phis results
 end
diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl
index 77a85eb..874de89 100644
--- a/src/MetagraphOptimization.jl
+++ b/src/MetagraphOptimization.jl
@@ -100,9 +100,6 @@ export ==, in, show, isempty, delete!, length
 
 export bytes_to_human_readable
 
-# TODO: this is probably not good
-import QEDprocesses.compute
-
 import Base.length
 import Base.show
 import Base.==
diff --git a/src/QEDprocesses_patch.jl b/src/QEDprocesses_patch.jl
index 1eba98c..282d713 100644
--- a/src/QEDprocesses_patch.jl
+++ b/src/QEDprocesses_patch.jl
@@ -1,25 +1,5 @@
 # patch QEDprocesses
 # see issue https://github.com/QEDjl-project/QEDprocesses.jl/issues/77
-@inline function QEDprocesses.number_particles(
-    proc_def::QEDbase.AbstractProcessDefinition,
-    dir::DIR,
-    ::PT,
-) where {DIR <: QEDbase.ParticleDirection, PT <: QEDbase.AbstractParticleType}
-    return count(x -> x isa PT, particles(proc_def, dir))
-end
-
-@inline function QEDprocesses.number_particles(
-    proc_def::QEDbase.AbstractProcessDefinition,
-    ::PS,
-) where {
-    DIR <: QEDbase.ParticleDirection,
-    PT <: QEDbase.AbstractParticleType,
-    EL <: AbstractFourMomentum,
-    PS <: ParticleStateful{DIR, PT, EL},
-}
-    return QEDprocesses.number_particles(proc_def, DIR(), PT())
-end
-
 @inline function QEDprocesses.number_particles(
     proc_def::QEDbase.AbstractProcessDefinition,
     ::Type{PS},
@@ -43,29 +23,3 @@ end
 ) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType, EL <: AbstractFourMomentum}
     return ParticleStateful(DIR(), SPECIES(), mom)
 end
-
-@inline function QEDbase.momentum(
-    psp::AbstractPhaseSpacePoint{MODEL, PROC, PS_DEF, INT, OUTT},
-    dir::ParticleDirection,
-    species::AbstractParticleType,
-    n::Int,
-) where {MODEL, PROC, PS_DEF, INT, OUTT}
-    # TODO: can be done through fancy template recursion too with 0 overhead
-    i = 0
-    c = n
-    for p in particles(psp, dir)
-        i += 1
-        if particle_species(p) isa typeof(species)
-            c -= 1
-        end
-        if c == 0
-            break
-        end
-    end
-
-    if c != 0 || n <= 0
-        throw(InvalidInputError("could not get $n-th momentum of $dir $species, does not exist"))
-    end
-
-    return momenta(psp, dir)[i]
-end
diff --git a/src/models/physics_models/qed/compute.jl b/src/models/physics_models/qed/compute.jl
index 01b1459..3650955 100644
--- a/src/models/physics_models/qed/compute.jl
+++ b/src/models/physics_models/qed/compute.jl
@@ -17,7 +17,7 @@ function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbo
 
     return Meta.parse(
         "ParticleValueSP(
-    $type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), $index)),
+    $type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))),
     0.0im,
     $(construction_string(spin_or_pol(instance, type, index))),
 )",