From 4f0da3dffb41c3861b29167c3f374fab9e33cbbd Mon Sep 17 00:00:00 2001 From: Rubydragon Date: Wed, 17 Jul 2024 15:45:54 +0200 Subject: [PATCH] WIP on DAG generation --- notebooks/diagram_generation.ipynb | 193 ++++++++++++++++- src/metagraph_impl/compute.jl | 40 +++- src/metagraph_impl/generation.jl | 324 ++++++++++++++++++++++++++--- 3 files changed, 509 insertions(+), 48 deletions(-) diff --git a/notebooks/diagram_generation.ipynb b/notebooks/diagram_generation.ipynb index c4d164c..0aa89a6 100644 --- a/notebooks/diagram_generation.ipynb +++ b/notebooks/diagram_generation.ipynb @@ -9,6 +9,8 @@ "name": "stderr", "output_type": "stream", "text": [ + "WARNING: Method definition (::Type{QEDcore.ParticleStateful{DIR, SPECIES, ELEMENT} where ELEMENT<:QEDbase.AbstractFourMomentum})(QEDbase.AbstractFourMomentum) where {DIR<:QEDbase.ParticleDirection, SPECIES<:QEDbase.AbstractParticleType} in module QEDcore at /home/antonr/.julia/packages/QEDcore/uVldP/src/phase_spaces/create.jl:7 overwritten in module MetagraphOptimization at /home/antonr/.julia/packages/MetagraphOptimization/mvCVq/src/QEDprocesses_patch.jl:15.\n", + "ERROR: Method overwriting is not permitted during Module precompilation. Use `__precompile__(false)` to opt-out of precompilation.\n", "WARNING: Method definition (::Type{QEDcore.ParticleStateful{DIR, SPECIES, ELEMENT} where ELEMENT<:QEDbase.AbstractFourMomentum})(QEDbase.AbstractFourMomentum) where {DIR<:QEDbase.ParticleDirection, SPECIES<:QEDbase.AbstractParticleType} in module QEDcore at /home/antonr/.julia/packages/QEDcore/uVldP/src/phase_spaces/create.jl:7 overwritten in module MetagraphOptimization at /home/antonr/.julia/packages/MetagraphOptimization/mvCVq/src/QEDprocesses_patch.jl:15.\n", "ERROR: Method overwriting is not permitted during Module precompilation. Use `__precompile__(false)` to opt-out of precompilation.\n" ] @@ -21,9 +23,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "14-element Vector{VirtualParticle{GenericQEDProcess{Tuple{Photon, Photon, Photon, Electron}, Tuple{Photon, Electron}, Tuple{AllPolarization, AllPolarization, AllPolarization, AllSpin}, Tuple{AllPolarization, AllSpin}}, PT, 4, 2} where PT<:AbstractParticleType}:\n", + " positron: \t0000 | 11\n", + " electron: \t0001 | 10\n", + " positron: \t0010 | 01\n", + " electron: \t0011 | 00\n", + " positron: \t0100 | 01\n", + " electron: \t0101 | 00\n", + " positron: \t1000 | 01\n", + " electron: \t1001 | 00\n", + " positron: \t1000 | 11\n", + " electron: \t1001 | 10\n", + " positron: \t1010 | 01\n", + " electron: \t1011 | 00\n", + " positron: \t1100 | 01\n", + " electron: \t1101 | 00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: both QEDcore and QEDbase export \"mul\"; uses of it in module FeynmanDiagramGenerator must be qualified\n" + ] + } + ], "source": [ "proc = GenericQEDProcess(3, 1, 1, 1, 0, 0)\n", "all_particles = Set()\n", @@ -36,27 +69,103 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "OrderedCollections.OrderedDict{VirtualParticle, Vector{Tuple{VirtualParticle, VirtualParticle}}} with 14 entries:\n", + " positron: \t0000 | 11 => [(positron: \t0000 | 01, photon: \t0000 | 10)]\n", + " electron: \t0001 | 10 => [(photon: \t0000 | 10, electron: \t0001 | 00)]\n", + " positron: \t0010 | 01 => [(positron: \t0000 | 01, photon: \t0010 | 00)]\n", + " electron: \t0011 | 00 => [(electron: \t0001 | 00, photon: \t0010 | 00)]\n", + " positron: \t0100 | 01 => [(positron: \t0000 | 01, photon: \t0100 | 00)]\n", + " electron: \t0101 | 00 => [(electron: \t0001 | 00, photon: \t0100 | 00)]\n", + " positron: \t1000 | 01 => [(positron: \t0000 | 01, photon: \t1000 | 00)]\n", + " electron: \t1001 | 00 => [(electron: \t0001 | 00, photon: \t1000 | 00)]\n", + " positron: \t1000 | 11 => [(photon: \t0000 | 10, positron: \t1000 | 01), (photon: \t10…\n", + " electron: \t1001 | 10 => [(photon: \t0000 | 10, electron: \t1001 | 00), (photon: \t10…\n", + " positron: \t1010 | 01 => [(photon: \t0010 | 00, positron: \t1000 | 01), (photon: \t10…\n", + " electron: \t1011 | 00 => [(photon: \t0010 | 00, electron: \t1001 | 00), (photon: \t10…\n", + " positron: \t1100 | 01 => [(photon: \t0100 | 00, positron: \t1000 | 01), (photon: \t10…\n", + " electron: \t1101 | 00 => [(photon: \t0100 | 00, electron: \t1001 | 00), (photon: \t10…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "pairs = FeynmanDiagramGenerator.particle_pairs(all_particles)" + "pairs = sort(FeynmanDiagramGenerator.particle_pairs(all_particles))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "12-element Vector{Tuple{VirtualParticle, VirtualParticle, VirtualParticle}}:\n", + " (photon: \t0000 | 10, electron: \t0011 | 00, positron: \t1100 | 01)\n", + " (photon: \t0000 | 10, electron: \t0101 | 00, positron: \t1010 | 01)\n", + " (photon: \t0000 | 10, electron: \t1101 | 00, positron: \t0010 | 01)\n", + " (photon: \t0000 | 10, electron: \t1011 | 00, positron: \t0100 | 01)\n", + " (photon: \t0010 | 00, electron: \t0001 | 10, positron: \t1100 | 01)\n", + " (photon: \t0010 | 00, electron: \t0101 | 00, positron: \t1000 | 11)\n", + " (photon: \t0010 | 00, electron: \t1101 | 00, positron: \t0000 | 11)\n", + " (photon: \t0010 | 00, electron: \t1001 | 10, positron: \t0100 | 01)\n", + " (photon: \t0100 | 00, electron: \t0001 | 10, positron: \t1010 | 01)\n", + " (photon: \t0100 | 00, electron: \t0011 | 00, positron: \t1000 | 11)\n", + " (photon: \t0100 | 00, electron: \t1011 | 00, positron: \t0000 | 11)\n", + " (photon: \t0100 | 00, electron: \t1001 | 10, positron: \t0010 | 01)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "triples = FeynmanDiagramGenerator.total_particle_triples(all_particles)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "s: 24, should be: 24\n", + "number of triples: 12\n" + ] + }, + { + "data": { + "text/plain": [ + "12-element Vector{Tuple{VirtualParticle, VirtualParticle, VirtualParticle}}:\n", + " (photon: \t0000 | 10, electron: \t0011 | 00, positron: \t1100 | 01)\n", + " (photon: \t0000 | 10, electron: \t0101 | 00, positron: \t1010 | 01)\n", + " (photon: \t0000 | 10, electron: \t1011 | 00, positron: \t0100 | 01)\n", + " (photon: \t0000 | 10, electron: \t1101 | 00, positron: \t0010 | 01)\n", + " (photon: \t0010 | 00, electron: \t0001 | 10, positron: \t1100 | 01)\n", + " (photon: \t0010 | 00, electron: \t0101 | 00, positron: \t1000 | 11)\n", + " (photon: \t0010 | 00, electron: \t1001 | 10, positron: \t0100 | 01)\n", + " (photon: \t0010 | 00, electron: \t1101 | 00, positron: \t0000 | 11)\n", + " (photon: \t0100 | 00, electron: \t0001 | 10, positron: \t1010 | 01)\n", + " (photon: \t0100 | 00, electron: \t0011 | 00, positron: \t1000 | 11)\n", + " (photon: \t0100 | 00, electron: \t1001 | 10, positron: \t0010 | 01)\n", + " (photon: \t0100 | 00, electron: \t1011 | 00, positron: \t0000 | 11)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "function n(vp::VirtualParticle)\n", " if !haskey(pairs, vp)\n", @@ -80,13 +189,77 @@ "sort(triples)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Graph:\n", + " Nodes: Total: 2320, FeynmanDiagramGenerator.ComputeTask_CollectTriples: 64, MetagraphOptimization.DataTask: 1173, \n", + " FeynmanDiagramGenerator.ComputeTask_CollectPairs: 80, FeynmanDiagramGenerator.ComputeTask_SpinPolCumulation: 1, FeynmanDiagramGenerator.ComputeTask_Propagator: 14, \n", + " FeynmanDiagramGenerator.ComputeTask_Triple: 768, FeynmanDiagramGenerator.ComputeTask_BaseState: 12, FeynmanDiagramGenerator.ComputeTask_PropagatePairs: 80, \n", + " FeynmanDiagramGenerator.ComputeTask_Pair: 128\n", + " Edges: 4853\n", + " Total Compute Effort: 0.0\n", + " Total Data Transfer: 0.0\n", + " Total Compute Intensity: 0.0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "graph = generate_DAG(proc)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "compute__2e0e67fe_4441_11ef_36f2_5fda31178519 (generic function with 1 method)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using MetagraphOptimization\n", + "using UUIDs\n", + "\n", + "function mock_machine()\n", + " return Machine(\n", + " [\n", + " MetagraphOptimization.NumaNode(\n", + " 0,\n", + " 1,\n", + " MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),\n", + " -1.0,\n", + " UUIDs.uuid1(),\n", + " ),\n", + " ],\n", + " [-1.0;;],\n", + " )\n", + "end\n", + "\n", + "func = get_compute_function(graph, proc, mock_machine())" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "generate_DAG(proc)" + "psp = PhaseSpacePoint(proc, PerturbativeQED(), PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()), [rand(SFourMomentum) for _ in 1:number_incoming_particles(proc)], [rand(SFourMomentum) for _ in 1:number_outgoing_particles(proc)])" ] } ], diff --git a/src/metagraph_impl/compute.jl b/src/metagraph_impl/compute.jl index 8700670..c206a62 100644 --- a/src/metagraph_impl/compute.jl +++ b/src/metagraph_impl/compute.jl @@ -1,13 +1,20 @@ -struct ComputeTask_BaseState <: AbstractComputeTask end # calculate the base state of an external particle -struct ComputeTask_Propagator <: AbstractComputeTask end # calculate the propagator term of a virtual particle -struct ComputeTask_Pair <: AbstractComputeTask end # from a pair of virtual particle currents, calculate the product -struct ComputeTask_CollectPairs <: AbstractComputeTask end # for a list of virtual particle current pair products, sum -struct ComputeTask_PropagatePairs <: AbstractComputeTask end # for the result of a CollectPairs compute task and a propagator, propagate the sum -struct ComputeTask_Triple <: AbstractComputeTask end # from a triple of virtual particle currents, calculate the diagram result -struct ComputeTask_CollectTriples <: AbstractComputeTask end # sum over triples results and +struct ComputeTask_BaseState <: AbstractComputeTask end # calculate the base state of an external particle +struct ComputeTask_Propagator <: AbstractComputeTask end # calculate the propagator term of a virtual particle +struct ComputeTask_Pair <: AbstractComputeTask end # from a pair of virtual particle currents, calculate the product +struct ComputeTask_CollectPairs <: AbstractComputeTask # for a list of virtual particle current pair products, sum + children::Int +end +struct ComputeTask_PropagatePairs <: AbstractComputeTask end # for the result of a CollectPairs compute task and a propagator, propagate the sum +struct ComputeTask_Triple <: AbstractComputeTask end # from a triple of virtual particle currents, calculate the diagram result +struct ComputeTask_CollectTriples <: AbstractComputeTask # sum over triples results and + children::Int +end +struct ComputeTask_SpinPolCumulation <: AbstractComputeTask # abs2 sum over all spin/pol configs + children::Int +end # import compute so we don't have to repeat it all the time -import MetagraphOptimization: compute, compute_effort +import MetagraphOptimization: compute, compute_effort, children compute_effort(::ComputeTask_BaseState) = 0 compute_effort(::ComputeTask_Propagator) = 0 @@ -16,6 +23,16 @@ compute_effort(::ComputeTask_CollectPairs) = 0 compute_effort(::ComputeTask_PropagatePairs) = 0 compute_effort(::ComputeTask_Triple) = 0 compute_effort(::ComputeTask_CollectTriples) = 0 +compute_effort(::ComputeTask_SpinPolCumulation) = 0 + +children(::ComputeTask_BaseState) = 1 +children(::ComputeTask_Propagator) = 1 +children(::ComputeTask_Pair) = 2 +children(t::ComputeTask_CollectPairs) = t.children +children(::ComputeTask_PropagatePairs) = 2 +children(::ComputeTask_Triple) = 3 +children(t::ComputeTask_CollectTriples) = t.children +children(t::ComputeTask_SpinPolCumulation) = t.children struct BaseStateInput{PS_T<:AbstractParticleStateful,SPIN_POL_T<:AbstractSpinOrPolarization} particle::PS_T @@ -142,3 +159,10 @@ end # use a summation algorithm with more accuracy and/or parallelization @inline compute(::ComputeTask_CollectPairs, args::Vararg{N,T}) where {N,T} = sum(args) @inline compute(::ComputeTask_CollectTriples, args::Vararg{N,T}) where {N,T} = sum(args) +@inline function compute(::ComputeTask_SpinPolCumulation, args::Vararg{N,T}) where {N,T} + sum = 0.0 + for arg in args + sum += abs2(arg) + end + return sum +end diff --git a/src/metagraph_impl/generation.jl b/src/metagraph_impl/generation.jl index 77be83e..695bc65 100644 --- a/src/metagraph_impl/generation.jl +++ b/src/metagraph_impl/generation.jl @@ -22,7 +22,7 @@ function _parse_particle(name::String) throw(InvalidInputError("failed to parse particle direction from \"$name\"")) end - name = name[4:end] + name = name[5:end] local species if startswith(name, "el") @@ -35,7 +35,7 @@ function _parse_particle(name::String) throw(InvalidInputError("failed to parse particle species from name \"$name\"")) end - name = name[3:end] + name = name[4:end] local spin_pol if startswith(name, "su") @@ -54,13 +54,15 @@ function _parse_particle(name::String) ) end - name = name[3:end] + name = name[4:end] index = parse(Int, name) return (dir, species, spin_pol, index) end -function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbol) +function MetagraphOptimization.input_expr( + proc::AbstractProcessDefinition, name::String, psp_symbol::Symbol +) if startswith(name, "bs_") (dir, species, spin_pol, index) = _parse_particle(name[4:end]) dir_str = _construction_string(dir) @@ -81,7 +83,7 @@ function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbo return Meta.parse("PropagatorInput( VirtualParticle( process($psp_symbol), - $species_str, + $(_species_str(particle_species(vp))), $(vp.in_particle_contributions), $(vp.out_particle_contributions) ), @@ -92,6 +94,19 @@ function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbo end end +function MetagraphOptimization.input_type(p::AbstractProcessDefinition) + in_t = QEDcore._assemble_tuple_type(incoming_particles(p), Incoming(), SFourMomentum) + out_t = QEDcore._assemble_tuple_type(outgoing_particles(p), Outgoing(), SFourMomentum) + return PhaseSpacePoint{ + typeof(p), + PerturbativeQED, + PhasespaceDefinition{SphericalCoordinateSystem,ElectronRestFrame}, + Tuple{in_t...}, + Tuple{out_t...}, + SFourMomentum, + } +end + _species_str(::Photon) = "ph" _species_str(::Electron) = "el" _species_str(::Positron) = "po" @@ -101,6 +116,22 @@ _spin_pol_str(::SpinDown) = "sd" _spin_pol_str(::PolX) = "px" _spin_pol_str(::PolY) = "py" +function Base.parse(::Type{AbstractSpinOrPolarization}, s::AbstractString) + if s == "su" + return SpinUp() + end + if s == "sd" + return SpinDown() + end + if s == "px" + return PolX() + end + if s == "py" + return PolY() + end + throw(InvalidInputError("invalid string \"$s\" to parse to AbstractSpinOrPolarization")) +end + _dir_str(::Incoming) = "inc" _dir_str(::Outgoing) = "out" @@ -113,6 +144,49 @@ _spin_pols(::PolY) = (PolY(),) _is_external(p::VirtualParticle) = number_contributions(p) == 1 +function _total_index( + proc::AbstractProcessDefinition, + dir::ParticleDirection, + species::AbstractParticleType, + n::Int, +) + # find particle index of all particles given n-th particle of dir and species (inverse of _species_index) + total_index = 0 + species_count = 0 + for p in particles(proc, dir) + total_index += 1 + if species == p + species_count += 1 + end + if species_count == n + return if dir == Incoming() + total_index + else + number_incoming_particles(proc) + total_index + end + end + end + + throw("did not find $n-th $dir $species") +end + +function _species_index( + proc::AbstractProcessDefinition, + dir::ParticleDirection, + species::AbstractParticleType, + n::Int, +) + # find particle index of n-th particle of *this species and dir* + species_index = 0 + for i in 1:n + if particles(proc, dir)[i] == species + species_index += 1 + end + end + + return species_index +end + function _base_state_name(p::VirtualParticle) proc = process(p) @@ -130,13 +204,7 @@ function _base_state_name(p::VirtualParticle) species = particles(proc, dir)[index] - # find particle index of *this species* - species_index = 0 - for i in 1:index - if particles(proc, dir)[i] == species - species_index += 1 - end - end + species_index = _species_index(proc, dir, species, index) spin_pol = spin_or_pol(proc, dir, species, species_index) @@ -147,11 +215,70 @@ function _base_state_name(p::VirtualParticle) ) end +# from two or three node names like "1_su-2-px"... return a single tuple of the indices and spin/pols in sorted +function _parse_node_names(name1::String, name2::String) + split_strings_1 = split.(split(name1, "-"), "_") + split_strings_2 = split.(split(name2, "-"), "_") + + return tuple( + # TODO: could use merge sort since the sub lists are sorted already + sort([ + tuple.( + parse.(Int, getindex.(split_strings_1, 1)), + parse.(AbstractSpinOrPolarization, getindex.(split_strings_1, 2)), + )..., + tuple.( + parse.(Int, getindex.(split_strings_2, 1)), + parse.(AbstractSpinOrPolarization, getindex.(split_strings_2, 2)), + )..., + ])..., + ) +end +function _parse_node_names(name1::String, name2::String, name3::String) + split_strings_1 = split.(split(name1, "-"), "_") + split_strings_2 = split.(split(name2, "-"), "_") + split_strings_3 = split.(split(name3, "-"), "_") + + return tuple( + # TODO: could use merge sort since the sub lists are sorted already + sort([ + tuple.( + parse.(Int, getindex.(split_strings_1, 1)), + parse.(AbstractSpinOrPolarization, getindex.(split_strings_1, 2)), + )..., + tuple.( + parse.(Int, getindex.(split_strings_2, 1)), + parse.(AbstractSpinOrPolarization, getindex.(split_strings_2, 2)), + )..., + tuple.( + parse.(Int, getindex.(split_strings_3, 1)), + parse.(AbstractSpinOrPolarization, getindex.(split_strings_3, 2)), + )..., + ])..., + ) +end + +function _make_node_name( + spin_pols::NTuple{N,Tuple{Int,AbstractSpinOrPolarization}} +) where {N} + node_name = "" + first = true + for spin_pol_tuple in spin_pols + if !first + node_name *= "-" + else + first = false + end + node_name *= "$(spin_pol_tuple[1])_$(_spin_pol_str(spin_pol_tuple[2]))" + end + return node_name +end + function generate_DAG(proc::GenericQEDProcess) external_particles = _pseudo_virtual_particles(proc) # external particles that will be input to base_state tasks particles = virtual_particles(proc) # virtual particles that will be input to propagator tasks - pairs = particle_pairs(particles) # pairs to generate the pair tasks - triples = total_particle_triples(particles) # triples to generate the triple tasks + pairs = sort(particle_pairs(particles)) # pairs to generate the pair tasks + triples = sort(total_particle_triples(particles)) # triples to generate the triple tasks graph = DAG() @@ -181,7 +308,13 @@ function generate_DAG(proc::GenericQEDProcess) ) data_out = insert_node!( - graph, make_node(DataTask(0)); track=false, invalidate_cache=false + graph, + make_node( + DataTask(0), + "$(_total_index(proc, dir, species, index))_$(_spin_pol_str(spin_pol))", + ); + track=false, + invalidate_cache=false, ) insert_edge!(graph, data_in, compute_base_state) @@ -202,7 +335,10 @@ function generate_DAG(proc::GenericQEDProcess) data_node_name = "pr_$vp_index" data_in = insert_node!( - graph, make_node(DataTask(0)); track=false, invalidate_cache=false + graph, + make_node(DataTask(0), data_node_name); + track=false, + invalidate_cache=false, ) compute_vp_propagator = insert_node!( graph, make_node(ComputeTask_Propagator()); track=false, invalidate_cache=false @@ -214,15 +350,19 @@ function generate_DAG(proc::GenericQEDProcess) insert_edge!(graph, data_in, compute_vp_propagator) insert_edge!(graph, compute_vp_propagator, data_out) - propagator_task_outputs[data_node_name] = data_out + propagator_task_outputs[vp] = data_out end # -- Pair Tasks -- - pair_task_outputs = Dict() + pair_task_outputs = Dict{VirtualParticle,Vector{Node}}() for (product_particle, input_particle_vector) in pairs - # for all spins/pols of particles in product_particles do ... + pair_task_outputs[product_particle] = Vector{Node}() - pair_task_outputs[product_particle] = Vector() + # make a dictionary of vectors to collect the outputs depending on spin/pol configs of the input particles + N = number_contributions(product_particle) + pair_output_nodes_by_spin_pol = Dict{ + NTuple{N,Tuple{Int,AbstractSpinOrPolarization}},Vector{DataTaskNode} + }() for input_particles in input_particle_vector particles_data_out_nodes = (Vector(), Vector()) @@ -237,28 +377,152 @@ function generate_DAG(proc::GenericQEDProcess) ) else # grab from propagated particles - push!(particles_date_out_nodes[c], pair_task_outputs[p]) + append!(particles_data_out_nodes[c], pair_task_outputs[p]) end end - for in_nodes in Iterators.product(input_particles...) + for in_nodes in Iterators.product(particles_data_out_nodes...) # make the compute pair nodes for every combination of the found input_particle_nodes to get all spin/pol combinations + compute_pair = insert_node!( + graph, + make_node(ComputeTask_Pair()); + track=false, + invalidate_cache=false, + ) + pair_data_out = insert_node!( + graph, make_node(DataTask(0)); track=false, invalidate_cache=false + ) - #insert_node!(graph, ) + insert_edge!(graph, in_nodes[1], compute_pair) + insert_edge!(graph, in_nodes[2], compute_pair) + insert_edge!(graph, compute_pair, pair_data_out) + # get the spin/pol config of the input particles from the data_out names + index = _parse_node_names(in_nodes[1].name, in_nodes[2].name) + + if !haskey(pair_output_nodes_by_spin_pol, index) + pair_output_nodes_by_spin_pol[index] = Vector() + end + push!(pair_output_nodes_by_spin_pol[index], pair_data_out) end - # make the collect pair and propagate nodes - end - data_out_propagated = insert_node!( - graph, make_node(DataTask(0)); track=false, invalidate_caches=false - ) + propagator_node = propagator_task_outputs[product_particle] - pair_task_outputs[p] = data_out_propagated + for (index, nodes_to_sum) in pair_output_nodes_by_spin_pol + compute_pairs_sum = insert_node!( + graph, + make_node(ComputeTask_CollectPairs(length(nodes_to_sum))); + track=false, + invalidate_cache=false, + ) + data_pairs_sum = insert_node!( + graph, make_node(DataTask(0)); track=false, invalidate_cache=false + ) + compute_propagated = insert_node!( + graph, + make_node(ComputeTask_PropagatePairs()); + track=false, + invalidate_cache=false, + ) + # give this out node the correct name + data_out_propagated = insert_node!( + graph, + make_node(DataTask(0), _make_node_name(index)); + track=false, + invalidate_cache=false, + ) - # ... end do + for node in nodes_to_sum + insert_edge!(graph, node, compute_pairs_sum) + end + + insert_edge!(graph, compute_pairs_sum, data_pairs_sum) + + insert_edge!(graph, propagator_node, compute_propagated) + insert_edge!(graph, data_pairs_sum, compute_propagated) + + insert_edge!(graph, compute_propagated, data_out_propagated) + + push!(pair_task_outputs[product_particle], data_out_propagated) + end end + # -- Triples -- + triples_results = Dict() + for (ph, el, po) in triples # for each triple (each "diagram") + photons = if is_external(ph) + getindex.(Ref(base_state_task_outputs), _base_state_name(ph)) + else + pair_task_outputs[ph] + end + electrons = if is_external(el) + getindex.(Ref(base_state_task_outputs), _base_state_name(el)) + else + pair_task_outputs[el] + end + positrons = if is_external(po) + getindex.(Ref(base_state_task_outputs), _base_state_name(po)) + else + pair_task_outputs[po] + end + for (a, b, c) in Iterators.product(photons, electrons, positrons) # for each spin/pol config of each part + compute_triples = insert_node!( + graph, make_node(ComputeTask_Triple()); track=false, invalidate_cache=false + ) + data_triples = insert_node!( + graph, make_node(DataTask(0)); track=false, invalidate_cache=false + ) + + insert_edge!(graph, a, compute_triples) + insert_edge!(graph, b, compute_triples) + insert_edge!(graph, c, compute_triples) + + insert_edge!(graph, compute_triples, data_triples) + + index = _parse_node_names(a.name, b.name, c.name) + if !haskey(triples_results, index) + triples_results[index] = Vector{DataTaskNode}() + end + push!(triples_results[index], data_triples) + end + end + + # -- Collect Triples -- + collected_triples = Vector{DataTaskNode}() + for (index, results) in triples_results + compute_collect_triples = insert_node!( + graph, + make_node(ComputeTask_CollectTriples(length(results))); + track=false, + invalidate_cache=false, + ) + data_collect_triples = insert_node!( + graph, make_node(DataTask(0)); track=false, invalidate_cache=false + ) + + for triple in results + insert_edge!(graph, triple, compute_collect_triples) + end + insert_edge!(graph, compute_collect_triples, data_collect_triples) + + push!(collected_triples, data_collect_triples) + end + + # Finally, abs2 sum over spin/pol configurations + compute_total_result = insert_node!( + graph, + make_node(ComputeTask_SpinPolCumulation(length(collected_triples))); + track=false, + invalidate_cache=false, + ) + for finished_triple in collected_triples + insert_edge!(graph, finished_triple, compute_total_result) + end + + final_data_out = insert_node!( + graph, make_node(DataTask(0)); track=false, invalidate_cache=false + ) + insert_edge!(graph, compute_total_result, final_data_out) return graph end