From 4eee23f08121fd383bae437901f5389cf4a1fa2f Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Tue, 9 Jul 2024 23:19:25 +0200 Subject: [PATCH] Run on GPU --- examples/congruent_in_ph.jl | 152 ++++++++++------------- src/MetagraphOptimization.jl | 3 - src/QEDprocesses_patch.jl | 46 ------- src/models/physics_models/qed/compute.jl | 2 +- 4 files changed, 66 insertions(+), 137 deletions(-) diff --git a/examples/congruent_in_ph.jl b/examples/congruent_in_ph.jl index 19acc24..8b805b7 100644 --- a/examples/congruent_in_ph.jl +++ b/examples/congruent_in_ph.jl @@ -6,8 +6,13 @@ using QEDcore using QEDprocesses using Random using UUIDs + +using CUDA + +using NamedDims using CSV using JLD2 +using FlexiMaps RNG = Random.MersenneTwister(123) @@ -79,7 +84,7 @@ function congruent_input_momenta_scenario_2( # ---------- # now calculate the final_momenta from omega, cos_theta and phi - n = number_particles(processDescription, ParticleStateful{Incoming, Photon, SFourMomentum}) + n = number_particles(processDescription, Incoming(), Photon()) cos_theta = cos(theta) omega_prime = (n * omega) / (1 + n * omega * (1 - cos_theta)) @@ -108,109 +113,82 @@ end with_stacksize(f, n) = fetch(schedule(Task(f, n))) # scenario 2 -N = 1000 -M = 1000 +N = 1024 # thetas +M = 1024 # phis +K = 64 # omegas thetas = collect(LinRange(0, 2π, N)) phis = collect(LinRange(0, 2π, M)) +omegas = collect(maprange(log, 2e-2, 2e-7, K)) -for photons in 1:6 +for photons in 1:5 # temp process to generate momenta - for omega in [2e-3, 2e-6] - println("Generating $(N*M) inputs for $photons photons (Scenario 2 grid walk)...") - temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp()) + println("Generating $(K*N*M) inputs for $photons photons (Scenario 2 grid walk)...") + temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp()) - input_momenta = [ - congruent_input_momenta_scenario_2(temp_process, omega, theta, phi) for - (theta, phi) in Iterators.product(thetas, phis) - ] - results = Array{Float64}(undef, size(input_momenta)) - fill!(results, 0.0) + input_momenta = + Array{typeof(congruent_input_momenta_scenario_2(temp_process, omegas[1], thetas[1], phis[1]))}(undef, (K, N, M)) - i = 1 - for (in_pol, in_spin, out_pol, out_spin) in - Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()]) + Threads.@threads for k in 1:K + Threads.@threads for i in 1:N + Threads.@threads for j in 1:M + input_momenta[k, i, j] = congruent_input_momenta_scenario_2(temp_process, omegas[k], thetas[i], phis[j]) + end + end + end - print( - "[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ", - ) - process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin) - inputs = build_psp.(Ref(process), input_momenta) - print("Preparing graph... ") - graph = gen_graph(process) - optimize_to_fixpoint!(ReductionOptimizer(), graph) - print("Preparing function... ") - func = get_compute_function(graph, process, mock_machine()) - func(inputs[1]) + cu_results = CuArray{Float64}(undef, size(input_momenta)) + fill!(cu_results, 0.0) - print("Calculating... ") + i = 1 + for (in_pol, in_spin, out_pol, out_spin) in + Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()]) + + print( + "[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ", + ) + process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin) + + inputs = Array{typeof(build_psp(process, input_momenta[1, 1, 1]))}(undef, (K, N, M)) + #println("input_momenta: $input_momenta") + Threads.@threads for k in 1:K Threads.@threads for i in 1:N Threads.@threads for j in 1:M - return results[i, j] += abs2(func(inputs[i, j])) + inputs[k, i, j] = build_psp(process, input_momenta[k, i, j]) end end - println("Done.") - i += 1 end + cu_inputs = CuArray(inputs) - println("Writing results") + print("Preparing graph... ") + graph = gen_graph(process) + optimize_to_fixpoint!(ReductionOptimizer(), graph) + print("Preparing function... ") + kernel! = get_cuda_kernel(graph, process, mock_machine()) + #func = get_compute_function(graph, process, mock_machine()) - out_ph_moms = getindex.(getindex.(input_momenta, 2), 1) - out_el_moms = getindex.(getindex.(input_momenta, 2), 2) + print("Calculating... ") + ts = 32 + bs = Int64(length(cu_inputs) / 32) - @save "$(photons)_congruent_photons_omega_$(omega)_grid.jld2" out_ph_moms out_el_moms results - end -end - -exit(0) - -# scenario 1 (disabled) -n = 1000000 - -# n is the number of incoming photons -# omega is the number - -for photons in 1:6 - # temp process to generate momenta - for omega in [2e-3, 2e-6] - println("Generating $n inputs for $photons photons...") - temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp()) - - input_momenta = [congruent_input_momenta(temp_process, omega) for _ in 1:n] - results = Array{Float64}(undef, size(input_momenta)) - fill!(results, 0.0) - - i = 1 - for (in_pol, in_spin, out_pol, out_spin) in - Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()]) - - print( - "[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ", - ) - process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin) - inputs = build_psp.(Ref(process), input_momenta) - - print("Preparing graph... ") - # prepare function - graph = gen_graph(process) - optimize_to_fixpoint!(ReductionOptimizer(), graph) - print("Preparing function... ") - func = get_compute_function(graph, process, mock_machine()) - - print("Calculating... ") - Threads.@threads for i in 1:n - results[i] += abs2(func(inputs[i])) - end - println("Done.") - i += 1 - end - - println("Writing results") - - out_ph_moms = getindex.(getindex.(input_momenta, 2), 1) - out_el_moms = getindex.(getindex.(input_momenta, 2), 2) - - @save "$(photons)_congruent_photons_omega_$(omega).jld2" out_ph_moms out_el_moms results + outputs = CuArray{ComplexF64}(undef, size(cu_inputs)) + + @cuda threads = ts blocks = bs always_inline = true kernel!(cu_inputs, outputs, length(cu_inputs)) + CUDA.device_synchronize() + cu_results += abs2.(outputs) + + println("Done.") + i += 1 end + + println("Writing results") + + out_ph_moms = getindex.(getindex.(input_momenta, 2), 1) + out_el_moms = getindex.(getindex.(input_momenta, 2), 2) + + results = NamedDimsArray{(:omegas, :thetas, :phis)}(Array(cu_results)) + println("Named results array: $(typeof(results))") + + @save "$(photons)_congruent_photons_grid.jld2" omegas thetas phis results end diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl index 77a85eb..874de89 100644 --- a/src/MetagraphOptimization.jl +++ b/src/MetagraphOptimization.jl @@ -100,9 +100,6 @@ export ==, in, show, isempty, delete!, length export bytes_to_human_readable -# TODO: this is probably not good -import QEDprocesses.compute - import Base.length import Base.show import Base.== diff --git a/src/QEDprocesses_patch.jl b/src/QEDprocesses_patch.jl index 1eba98c..282d713 100644 --- a/src/QEDprocesses_patch.jl +++ b/src/QEDprocesses_patch.jl @@ -1,25 +1,5 @@ # patch QEDprocesses # see issue https://github.com/QEDjl-project/QEDprocesses.jl/issues/77 -@inline function QEDprocesses.number_particles( - proc_def::QEDbase.AbstractProcessDefinition, - dir::DIR, - ::PT, -) where {DIR <: QEDbase.ParticleDirection, PT <: QEDbase.AbstractParticleType} - return count(x -> x isa PT, particles(proc_def, dir)) -end - -@inline function QEDprocesses.number_particles( - proc_def::QEDbase.AbstractProcessDefinition, - ::PS, -) where { - DIR <: QEDbase.ParticleDirection, - PT <: QEDbase.AbstractParticleType, - EL <: AbstractFourMomentum, - PS <: ParticleStateful{DIR, PT, EL}, -} - return QEDprocesses.number_particles(proc_def, DIR(), PT()) -end - @inline function QEDprocesses.number_particles( proc_def::QEDbase.AbstractProcessDefinition, ::Type{PS}, @@ -43,29 +23,3 @@ end ) where {DIR <: ParticleDirection, SPECIES <: AbstractParticleType, EL <: AbstractFourMomentum} return ParticleStateful(DIR(), SPECIES(), mom) end - -@inline function QEDbase.momentum( - psp::AbstractPhaseSpacePoint{MODEL, PROC, PS_DEF, INT, OUTT}, - dir::ParticleDirection, - species::AbstractParticleType, - n::Int, -) where {MODEL, PROC, PS_DEF, INT, OUTT} - # TODO: can be done through fancy template recursion too with 0 overhead - i = 0 - c = n - for p in particles(psp, dir) - i += 1 - if particle_species(p) isa typeof(species) - c -= 1 - end - if c == 0 - break - end - end - - if c != 0 || n <= 0 - throw(InvalidInputError("could not get $n-th momentum of $dir $species, does not exist")) - end - - return momenta(psp, dir)[i] -end diff --git a/src/models/physics_models/qed/compute.jl b/src/models/physics_models/qed/compute.jl index 01b1459..3650955 100644 --- a/src/models/physics_models/qed/compute.jl +++ b/src/models/physics_models/qed/compute.jl @@ -17,7 +17,7 @@ function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbo return Meta.parse( "ParticleValueSP( - $type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), $index)), + $type(momentum($psp_symbol, $(construction_string(particle_direction(type))), $(construction_string(particle_species(type))), Val($index))), 0.0im, $(construction_string(spin_or_pol(instance, type, index))), )",