From 86ad9ed5e8ecd43152382d66c60127147c0bd4f0 Mon Sep 17 00:00:00 2001
From: Anton Reinhard <anton.reinhard@proton.me>
Date: Wed, 6 Mar 2024 23:41:13 +0100
Subject: [PATCH] Add kernel generating function

---
 Project.toml                |  3 ++-
 examples/full_node_bench.jl |  6 +++---
 src/code_gen/function.jl    | 32 ++++++++++++++++++++++++++++++++
 src/models/qed/compute.jl   |  4 ++--
 src/models/qed/particle.jl  |  2 +-
 5 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/Project.toml b/Project.toml
index 0335c8c..c495944 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
+authors = ["Anton Reinhard <anton.reinhard@proton.me>"]
 name = "MetagraphOptimization"
 uuid = "3e869610-d48d-4942-ba70-c1b702a33ca4"
-authors = ["Anton Reinhard <anton.reinhard@proton.me>"]
 version = "0.1.0"
 
 [deps]
@@ -20,6 +20,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 
 [extras]
+CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
diff --git a/examples/full_node_bench.jl b/examples/full_node_bench.jl
index bada3f4..351096a 100644
--- a/examples/full_node_bench.jl
+++ b/examples/full_node_bench.jl
@@ -36,7 +36,7 @@ if isfile(results_filename)
     df = CSV.read(results_filename, DataFrame)
 end
 
-nInputs = 1_073_741_824 # 2^30
+nInputs = 16_777_216 # 2^30
 
 lck = ReentrantLock()
 
@@ -151,7 +151,7 @@ function bench(compute_function, inputs, chunk_size)
 
     bench = @benchmark begin
         full_compute($compute_function, $inputs, $chunk_size)
-    end gcsample = true seconds = 600
+    end gcsample = true seconds = 30
 
     time = median(bench.times) / 1e9
     s = std(bench.times) / 1e9
@@ -212,7 +212,7 @@ machine = Machine(
 )
 
 optimizer = ReductionOptimizer()
-processes = ["ke->ke", "ke->kke", "ke->kkke", "ke->kkkke", "ke->kkkkke"]
+processes = [#="ke->ke", "ke->kke", "ke->kkke", =#"ke->kkkke", "ke->kkkkke"]
 
 for proc in processes
     process = parse_process(proc, QEDModel())
diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl
index 8e21b3e..97df44e 100644
--- a/src/code_gen/function.jl
+++ b/src/code_gen/function.jl
@@ -21,6 +21,38 @@ function get_compute_function(graph::DAG, process::AbstractProcessDescription, m
     return func
 end
 
+"""
+    get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
+
+Return a function of signature `compute_<id>(input::CuVector, output::CuVector, n::Int64)`, which will return the result of the DAG computation of the input on the given output variable.
+"""
+function get_cuda_kernel(graph::DAG, process::AbstractProcessDescription, machine::Machine)
+    tape = gen_tape(graph, process, machine)
+
+    initCaches = Expr(:block, tape.initCachesCode...)
+    assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
+    code = Expr(:block, expr_from_fc.(tape.computeCode)...)
+
+    functionId = to_var_name(UUIDs.uuid1(rng[1]))
+    resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
+    expr = Meta.parse("function compute_$(functionId)(input_vector, output_vector, n::Int64)
+                          id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
+                          if (id > n)
+                              return
+                          end
+                          @inline data_input = input_vector[id]
+                          $(initCaches)
+                          $(assignInputs)
+                          $code
+                          @inline output_vector[id] = $resSym
+                          return nothing
+                      end")
+
+    func = eval(expr)
+
+    return func
+end
+
 """
     execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
 
diff --git a/src/models/qed/compute.jl b/src/models/qed/compute.jl
index 84859d0..b30994f 100644
--- a/src/models/qed/compute.jl
+++ b/src/models/qed/compute.jl
@@ -72,9 +72,9 @@ function compute(
 
     # inner edge is just a "scalar", data1 and data2 are bispinor/adjointbispinnor, need to keep correct order
     if typeof(data1.v) <: BiSpinor
-        return data2.v * inner * data1.v
+        return (data2.v)::AdjointBiSpinor * inner * (data1.v)::BiSpinor
     else
-        return data1.v * inner * data2.v
+        return (data1.v)::AdjointBiSpinor * inner * (data2.v)::BiSpinor
     end
 end
 
diff --git a/src/models/qed/particle.jl b/src/models/qed/particle.jl
index d51fedd..152b7b2 100644
--- a/src/models/qed/particle.jl
+++ b/src/models/qed/particle.jl
@@ -313,7 +313,7 @@ Return the factor of a vertex in a QED feynman diagram.
     return -1im * e * gamma()
 end
 
-@inline function QED_inner_edge(p::QEDParticle)
+@inline function QED_inner_edge(p::QEDParticle)::DiracMatrix
     return propagator(particle(p), p.momentum)
 end