Run on GPU
Some checks failed
MetagraphOptimization_CI / docs (push) Failing after 5m43s
MetagraphOptimization_CI / test (push) Failing after 6m2s

This commit is contained in:
2024-07-09 23:19:25 +02:00
parent dee44dad66
commit 4eee23f081
4 changed files with 66 additions and 137 deletions

View File

@@ -6,8 +6,13 @@ using QEDcore
using QEDprocesses
using Random
using UUIDs
using CUDA
using NamedDims
using CSV
using JLD2
using FlexiMaps
RNG = Random.MersenneTwister(123)
@@ -79,7 +84,7 @@ function congruent_input_momenta_scenario_2(
# ----------
# now calculate the final_momenta from omega, cos_theta and phi
n = number_particles(processDescription, ParticleStateful{Incoming, Photon, SFourMomentum})
n = number_particles(processDescription, Incoming(), Photon())
cos_theta = cos(theta)
omega_prime = (n * omega) / (1 + n * omega * (1 - cos_theta))
@@ -108,109 +113,82 @@ end
"""
    with_stacksize(f, n)

Run the zero-argument callable `f` on a newly created `Task` and return its result.

`n` is forwarded unchanged as the second argument of the `Task` constructor
(judging by the function name, the stack size reserved for the task). The call
blocks until the task has finished, because the result is obtained with `fetch`.
"""
function with_stacksize(f, n)
    task = Task(f, n)
    # `schedule` hands the task to the scheduler; `fetch` waits for completion
    schedule(task)
    return fetch(task)
end
# scenario 2
N = 1000
M = 1000
N = 1024 # thetas
M = 1024 # phis
K = 64 # omegas
thetas = collect(LinRange(0, 2π, N))
phis = collect(LinRange(0, 2π, M))
omegas = collect(maprange(log, 2e-2, 2e-7, K))
for photons in 1:6
for photons in 1:5
# temp process to generate momenta
for omega in [2e-3, 2e-6]
println("Generating $(N*M) inputs for $photons photons (Scenario 2 grid walk)...")
temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
println("Generating $(K*N*M) inputs for $photons photons (Scenario 2 grid walk)...")
temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
input_momenta = [
congruent_input_momenta_scenario_2(temp_process, omega, theta, phi) for
(theta, phi) in Iterators.product(thetas, phis)
]
results = Array{Float64}(undef, size(input_momenta))
fill!(results, 0.0)
input_momenta =
Array{typeof(congruent_input_momenta_scenario_2(temp_process, omegas[1], thetas[1], phis[1]))}(undef, (K, N, M))
i = 1
for (in_pol, in_spin, out_pol, out_spin) in
Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
Threads.@threads for k in 1:K
Threads.@threads for i in 1:N
Threads.@threads for j in 1:M
input_momenta[k, i, j] = congruent_input_momenta_scenario_2(temp_process, omegas[k], thetas[i], phis[j])
end
end
end
print(
"[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
)
process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
inputs = build_psp.(Ref(process), input_momenta)
print("Preparing graph... ")
graph = gen_graph(process)
optimize_to_fixpoint!(ReductionOptimizer(), graph)
print("Preparing function... ")
func = get_compute_function(graph, process, mock_machine())
func(inputs[1])
cu_results = CuArray{Float64}(undef, size(input_momenta))
fill!(cu_results, 0.0)
print("Calculating... ")
i = 1
for (in_pol, in_spin, out_pol, out_spin) in
Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
print(
"[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
)
process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
inputs = Array{typeof(build_psp(process, input_momenta[1, 1, 1]))}(undef, (K, N, M))
#println("input_momenta: $input_momenta")
Threads.@threads for k in 1:K
Threads.@threads for i in 1:N
Threads.@threads for j in 1:M
return results[i, j] += abs2(func(inputs[i, j]))
inputs[k, i, j] = build_psp(process, input_momenta[k, i, j])
end
end
println("Done.")
i += 1
end
cu_inputs = CuArray(inputs)
println("Writing results")
print("Preparing graph... ")
graph = gen_graph(process)
optimize_to_fixpoint!(ReductionOptimizer(), graph)
print("Preparing function... ")
kernel! = get_cuda_kernel(graph, process, mock_machine())
#func = get_compute_function(graph, process, mock_machine())
out_ph_moms = getindex.(getindex.(input_momenta, 2), 1)
out_el_moms = getindex.(getindex.(input_momenta, 2), 2)
print("Calculating... ")
ts = 32
bs = Int64(length(cu_inputs) / 32)
@save "$(photons)_congruent_photons_omega_$(omega)_grid.jld2" out_ph_moms out_el_moms results
end
end
exit(0)
# scenario 1 (disabled)
n = 1000000
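# n is the number of inputs generated per (photons, omega) combination;
# the number of incoming photons is the loop variable `photons`
# omega is forwarded to congruent_input_momenta for every generated input
# (presumably the energy of each incoming photon; TODO confirm)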
for photons in 1:6
# temp process to generate momenta
for omega in [2e-3, 2e-6]
println("Generating $n inputs for $photons photons...")
temp_process = parse_process("k"^photons * "e->ke", QEDModel(), PolX(), SpinUp(), PolX(), SpinUp())
input_momenta = [congruent_input_momenta(temp_process, omega) for _ in 1:n]
results = Array{Float64}(undef, size(input_momenta))
fill!(results, 0.0)
i = 1
for (in_pol, in_spin, out_pol, out_spin) in
Iterators.product([PolX(), PolY()], [SpinUp(), SpinDown()], [PolX(), PolY()], [SpinUp(), SpinDown()])
print(
"[$i/16] Calculating for spin/pol config: $in_pol, $in_spin -> $out_pol, $out_spin... Preparing inputs... ",
)
process = parse_process("k"^photons * "e->ke", QEDModel(), in_pol, in_spin, out_pol, out_spin)
inputs = build_psp.(Ref(process), input_momenta)
print("Preparing graph... ")
# prepare function
graph = gen_graph(process)
optimize_to_fixpoint!(ReductionOptimizer(), graph)
print("Preparing function... ")
func = get_compute_function(graph, process, mock_machine())
print("Calculating... ")
Threads.@threads for i in 1:n
results[i] += abs2(func(inputs[i]))
end
println("Done.")
i += 1
end
println("Writing results")
out_ph_moms = getindex.(getindex.(input_momenta, 2), 1)
out_el_moms = getindex.(getindex.(input_momenta, 2), 2)
@save "$(photons)_congruent_photons_omega_$(omega).jld2" out_ph_moms out_el_moms results
outputs = CuArray{ComplexF64}(undef, size(cu_inputs))
@cuda threads = ts blocks = bs always_inline = true kernel!(cu_inputs, outputs, length(cu_inputs))
CUDA.device_synchronize()
cu_results += abs2.(outputs)
println("Done.")
i += 1
end
println("Writing results")
out_ph_moms = getindex.(getindex.(input_momenta, 2), 1)
out_el_moms = getindex.(getindex.(input_momenta, 2), 2)
results = NamedDimsArray{(:omegas, :thetas, :phis)}(Array(cu_results))
println("Named results array: $(typeof(results))")
@save "$(photons)_congruent_photons_grid.jld2" omegas thetas phis results
end