WIP on DAG generation
This commit is contained in:
		| @@ -9,6 +9,8 @@ | |||||||
|      "name": "stderr", |      "name": "stderr", | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|  |       "WARNING: Method definition (::Type{QEDcore.ParticleStateful{DIR, SPECIES, ELEMENT} where ELEMENT<:QEDbase.AbstractFourMomentum})(QEDbase.AbstractFourMomentum) where {DIR<:QEDbase.ParticleDirection, SPECIES<:QEDbase.AbstractParticleType} in module QEDcore at /home/antonr/.julia/packages/QEDcore/uVldP/src/phase_spaces/create.jl:7 overwritten in module MetagraphOptimization at /home/antonr/.julia/packages/MetagraphOptimization/mvCVq/src/QEDprocesses_patch.jl:15.\n", | ||||||
|  |       "ERROR: Method overwriting is not permitted during Module precompilation. Use `__precompile__(false)` to opt-out of precompilation.\n", | ||||||
|       "WARNING: Method definition (::Type{QEDcore.ParticleStateful{DIR, SPECIES, ELEMENT} where ELEMENT<:QEDbase.AbstractFourMomentum})(QEDbase.AbstractFourMomentum) where {DIR<:QEDbase.ParticleDirection, SPECIES<:QEDbase.AbstractParticleType} in module QEDcore at /home/antonr/.julia/packages/QEDcore/uVldP/src/phase_spaces/create.jl:7 overwritten in module MetagraphOptimization at /home/antonr/.julia/packages/MetagraphOptimization/mvCVq/src/QEDprocesses_patch.jl:15.\n", |       "WARNING: Method definition (::Type{QEDcore.ParticleStateful{DIR, SPECIES, ELEMENT} where ELEMENT<:QEDbase.AbstractFourMomentum})(QEDbase.AbstractFourMomentum) where {DIR<:QEDbase.ParticleDirection, SPECIES<:QEDbase.AbstractParticleType} in module QEDcore at /home/antonr/.julia/packages/QEDcore/uVldP/src/phase_spaces/create.jl:7 overwritten in module MetagraphOptimization at /home/antonr/.julia/packages/MetagraphOptimization/mvCVq/src/QEDprocesses_patch.jl:15.\n", | ||||||
|       "ERROR: Method overwriting is not permitted during Module precompilation. Use `__precompile__(false)` to opt-out of precompilation.\n" |       "ERROR: Method overwriting is not permitted during Module precompilation. Use `__precompile__(false)` to opt-out of precompilation.\n" | ||||||
|      ] |      ] | ||||||
| @@ -21,9 +23,40 @@ | |||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": null, |    "execution_count": 2, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "data": { | ||||||
|  |       "text/plain": [ | ||||||
|  |        "14-element Vector{VirtualParticle{GenericQEDProcess{Tuple{Photon, Photon, Photon, Electron}, Tuple{Photon, Electron}, Tuple{AllPolarization, AllPolarization, AllPolarization, AllSpin}, Tuple{AllPolarization, AllSpin}}, PT, 4, 2} where PT<:AbstractParticleType}:\n", | ||||||
|  |        " positron: \t0000 | 11\n", | ||||||
|  |        " electron: \t0001 | 10\n", | ||||||
|  |        " positron: \t0010 | 01\n", | ||||||
|  |        " electron: \t0011 | 00\n", | ||||||
|  |        " positron: \t0100 | 01\n", | ||||||
|  |        " electron: \t0101 | 00\n", | ||||||
|  |        " positron: \t1000 | 01\n", | ||||||
|  |        " electron: \t1001 | 00\n", | ||||||
|  |        " positron: \t1000 | 11\n", | ||||||
|  |        " electron: \t1001 | 10\n", | ||||||
|  |        " positron: \t1010 | 01\n", | ||||||
|  |        " electron: \t1011 | 00\n", | ||||||
|  |        " positron: \t1100 | 01\n", | ||||||
|  |        " electron: \t1101 | 00" | ||||||
|  |       ] | ||||||
|  |      }, | ||||||
|  |      "metadata": {}, | ||||||
|  |      "output_type": "display_data" | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |      "name": "stderr", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "WARNING: both QEDcore and QEDbase export \"mul\"; uses of it in module FeynmanDiagramGenerator must be qualified\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|    "source": [ |    "source": [ | ||||||
|     "proc = GenericQEDProcess(3, 1, 1, 1, 0, 0)\n", |     "proc = GenericQEDProcess(3, 1, 1, 1, 0, 0)\n", | ||||||
|     "all_particles = Set()\n", |     "all_particles = Set()\n", | ||||||
| @@ -36,27 +69,103 @@ | |||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": null, |    "execution_count": 3, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "data": { | ||||||
|  |       "text/plain": [ | ||||||
|  |        "OrderedCollections.OrderedDict{VirtualParticle, Vector{Tuple{VirtualParticle, VirtualParticle}}} with 14 entries:\n", | ||||||
|  |        "  positron: \t0000 | 11 => [(positron: \t0000 | 01, photon: \t0000 | 10)]\n", | ||||||
|  |        "  electron: \t0001 | 10 => [(photon: \t0000 | 10, electron: \t0001 | 00)]\n", | ||||||
|  |        "  positron: \t0010 | 01 => [(positron: \t0000 | 01, photon: \t0010 | 00)]\n", | ||||||
|  |        "  electron: \t0011 | 00 => [(electron: \t0001 | 00, photon: \t0010 | 00)]\n", | ||||||
|  |        "  positron: \t0100 | 01 => [(positron: \t0000 | 01, photon: \t0100 | 00)]\n", | ||||||
|  |        "  electron: \t0101 | 00 => [(electron: \t0001 | 00, photon: \t0100 | 00)]\n", | ||||||
|  |        "  positron: \t1000 | 01 => [(positron: \t0000 | 01, photon: \t1000 | 00)]\n", | ||||||
|  |        "  electron: \t1001 | 00 => [(electron: \t0001 | 00, photon: \t1000 | 00)]\n", | ||||||
|  |        "  positron: \t1000 | 11 => [(photon: \t0000 | 10, positron: \t1000 | 01), (photon: \t10…\n", | ||||||
|  |        "  electron: \t1001 | 10 => [(photon: \t0000 | 10, electron: \t1001 | 00), (photon: \t10…\n", | ||||||
|  |        "  positron: \t1010 | 01 => [(photon: \t0010 | 00, positron: \t1000 | 01), (photon: \t10…\n", | ||||||
|  |        "  electron: \t1011 | 00 => [(photon: \t0010 | 00, electron: \t1001 | 00), (photon: \t10…\n", | ||||||
|  |        "  positron: \t1100 | 01 => [(photon: \t0100 | 00, positron: \t1000 | 01), (photon: \t10…\n", | ||||||
|  |        "  electron: \t1101 | 00 => [(photon: \t0100 | 00, electron: \t1001 | 00), (photon: \t10…" | ||||||
|  |       ] | ||||||
|  |      }, | ||||||
|  |      "metadata": {}, | ||||||
|  |      "output_type": "display_data" | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|    "source": [ |    "source": [ | ||||||
|     "pairs = FeynmanDiagramGenerator.particle_pairs(all_particles)" |     "pairs = sort(FeynmanDiagramGenerator.particle_pairs(all_particles))" | ||||||
|    ] |    ] | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": null, |    "execution_count": 4, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "data": { | ||||||
|  |       "text/plain": [ | ||||||
|  |        "12-element Vector{Tuple{VirtualParticle, VirtualParticle, VirtualParticle}}:\n", | ||||||
|  |        " (photon: \t0000 | 10, electron: \t0011 | 00, positron: \t1100 | 01)\n", | ||||||
|  |        " (photon: \t0000 | 10, electron: \t0101 | 00, positron: \t1010 | 01)\n", | ||||||
|  |        " (photon: \t0000 | 10, electron: \t1101 | 00, positron: \t0010 | 01)\n", | ||||||
|  |        " (photon: \t0000 | 10, electron: \t1011 | 00, positron: \t0100 | 01)\n", | ||||||
|  |        " (photon: \t0010 | 00, electron: \t0001 | 10, positron: \t1100 | 01)\n", | ||||||
|  |        " (photon: \t0010 | 00, electron: \t0101 | 00, positron: \t1000 | 11)\n", | ||||||
|  |        " (photon: \t0010 | 00, electron: \t1101 | 00, positron: \t0000 | 11)\n", | ||||||
|  |        " (photon: \t0010 | 00, electron: \t1001 | 10, positron: \t0100 | 01)\n", | ||||||
|  |        " (photon: \t0100 | 00, electron: \t0001 | 10, positron: \t1010 | 01)\n", | ||||||
|  |        " (photon: \t0100 | 00, electron: \t0011 | 00, positron: \t1000 | 11)\n", | ||||||
|  |        " (photon: \t0100 | 00, electron: \t1011 | 00, positron: \t0000 | 11)\n", | ||||||
|  |        " (photon: \t0100 | 00, electron: \t1001 | 10, positron: \t0010 | 01)" | ||||||
|  |       ] | ||||||
|  |      }, | ||||||
|  |      "metadata": {}, | ||||||
|  |      "output_type": "display_data" | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|    "source": [ |    "source": [ | ||||||
|     "triples = FeynmanDiagramGenerator.total_particle_triples(all_particles)" |     "triples = FeynmanDiagramGenerator.total_particle_triples(all_particles)" | ||||||
|    ] |    ] | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": null, |    "execution_count": 5, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "s: 24, should be: 24\n", | ||||||
|  |       "number of triples: 12\n" | ||||||
|  |      ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |      "data": { | ||||||
|  |       "text/plain": [ | ||||||
|  |        "12-element Vector{Tuple{VirtualParticle, VirtualParticle, VirtualParticle}}:\n", | ||||||
|  |        " (photon: \t0000 | 10, electron: \t0011 | 00, positron: \t1100 | 01)\n", | ||||||
|  |        " (photon: \t0000 | 10, electron: \t0101 | 00, positron: \t1010 | 01)\n", | ||||||
|  |        " (photon: \t0000 | 10, electron: \t1011 | 00, positron: \t0100 | 01)\n", | ||||||
|  |        " (photon: \t0000 | 10, electron: \t1101 | 00, positron: \t0010 | 01)\n", | ||||||
|  |        " (photon: \t0010 | 00, electron: \t0001 | 10, positron: \t1100 | 01)\n", | ||||||
|  |        " (photon: \t0010 | 00, electron: \t0101 | 00, positron: \t1000 | 11)\n", | ||||||
|  |        " (photon: \t0010 | 00, electron: \t1001 | 10, positron: \t0100 | 01)\n", | ||||||
|  |        " (photon: \t0010 | 00, electron: \t1101 | 00, positron: \t0000 | 11)\n", | ||||||
|  |        " (photon: \t0100 | 00, electron: \t0001 | 10, positron: \t1010 | 01)\n", | ||||||
|  |        " (photon: \t0100 | 00, electron: \t0011 | 00, positron: \t1000 | 11)\n", | ||||||
|  |        " (photon: \t0100 | 00, electron: \t1001 | 10, positron: \t0010 | 01)\n", | ||||||
|  |        " (photon: \t0100 | 00, electron: \t1011 | 00, positron: \t0000 | 11)" | ||||||
|  |       ] | ||||||
|  |      }, | ||||||
|  |      "metadata": {}, | ||||||
|  |      "output_type": "display_data" | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|    "source": [ |    "source": [ | ||||||
|     "function n(vp::VirtualParticle)\n", |     "function n(vp::VirtualParticle)\n", | ||||||
|     "    if !haskey(pairs, vp)\n", |     "    if !haskey(pairs, vp)\n", | ||||||
| @@ -80,13 +189,77 @@ | |||||||
|     "sort(triples)" |     "sort(triples)" | ||||||
|    ] |    ] | ||||||
|   }, |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 6, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "data": { | ||||||
|  |       "text/plain": [ | ||||||
|  |        "Graph:\n", | ||||||
|  |        "  Nodes: Total: 2320, FeynmanDiagramGenerator.ComputeTask_CollectTriples: 64, MetagraphOptimization.DataTask: 1173, \n", | ||||||
|  |        "         FeynmanDiagramGenerator.ComputeTask_CollectPairs: 80, FeynmanDiagramGenerator.ComputeTask_SpinPolCumulation: 1, FeynmanDiagramGenerator.ComputeTask_Propagator: 14, \n", | ||||||
|  |        "         FeynmanDiagramGenerator.ComputeTask_Triple: 768, FeynmanDiagramGenerator.ComputeTask_BaseState: 12, FeynmanDiagramGenerator.ComputeTask_PropagatePairs: 80, \n", | ||||||
|  |        "         FeynmanDiagramGenerator.ComputeTask_Pair: 128\n", | ||||||
|  |        "  Edges: 4853\n", | ||||||
|  |        "  Total Compute Effort: 0.0\n", | ||||||
|  |        "  Total Data Transfer: 0.0\n", | ||||||
|  |        "  Total Compute Intensity: 0.0\n" | ||||||
|  |       ] | ||||||
|  |      }, | ||||||
|  |      "metadata": {}, | ||||||
|  |      "output_type": "display_data" | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "graph = generate_DAG(proc)" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 16, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "data": { | ||||||
|  |       "text/plain": [ | ||||||
|  |        "compute__2e0e67fe_4441_11ef_36f2_5fda31178519 (generic function with 1 method)" | ||||||
|  |       ] | ||||||
|  |      }, | ||||||
|  |      "metadata": {}, | ||||||
|  |      "output_type": "display_data" | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "using MetagraphOptimization\n", | ||||||
|  |     "using UUIDs\n", | ||||||
|  |     "\n", | ||||||
|  |     "function mock_machine()\n", | ||||||
|  |     "    return Machine(\n", | ||||||
|  |     "        [\n", | ||||||
|  |     "            MetagraphOptimization.NumaNode(\n", | ||||||
|  |     "                0,\n", | ||||||
|  |     "                1,\n", | ||||||
|  |     "                MetagraphOptimization.default_strategy(MetagraphOptimization.NumaNode),\n", | ||||||
|  |     "                -1.0,\n", | ||||||
|  |     "                UUIDs.uuid1(),\n", | ||||||
|  |     "            ),\n", | ||||||
|  |     "        ],\n", | ||||||
|  |     "        [-1.0;;],\n", | ||||||
|  |     "    )\n", | ||||||
|  |     "end\n", | ||||||
|  |     "\n", | ||||||
|  |     "func = get_compute_function(graph, proc, mock_machine())" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": null, |    "execution_count": null, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [], | ||||||
|    "source": [ |    "source": [ | ||||||
|     "generate_DAG(proc)" |     "psp = PhaseSpacePoint(proc, PerturbativeQED(), PhasespaceDefinition(SphericalCoordinateSystem(), ElectronRestFrame()), [rand(SFourMomentum) for _ in 1:number_incoming_particles(proc)], [rand(SFourMomentum) for _ in 1:number_outgoing_particles(proc)])" | ||||||
|    ] |    ] | ||||||
|   } |   } | ||||||
|  ], |  ], | ||||||
|   | |||||||
| @@ -1,13 +1,20 @@ | |||||||
| struct ComputeTask_BaseState <: AbstractComputeTask end         # calculate the base state of an external particle | struct ComputeTask_BaseState <: AbstractComputeTask end             # calculate the base state of an external particle | ||||||
| struct ComputeTask_Propagator <: AbstractComputeTask end        # calculate the propagator term of a virtual particle | struct ComputeTask_Propagator <: AbstractComputeTask end            # calculate the propagator term of a virtual particle | ||||||
| struct ComputeTask_Pair <: AbstractComputeTask end              # from a pair of virtual particle currents, calculate the product | struct ComputeTask_Pair <: AbstractComputeTask end                  # from a pair of virtual particle currents, calculate the product | ||||||
| struct ComputeTask_CollectPairs <: AbstractComputeTask end      # for a list of virtual particle current pair products, sum | struct ComputeTask_CollectPairs <: AbstractComputeTask              # for a list of virtual particle current pair products, sum | ||||||
| struct ComputeTask_PropagatePairs <: AbstractComputeTask end    # for the result of a CollectPairs compute task and a propagator, propagate the sum |     children::Int | ||||||
| struct ComputeTask_Triple <: AbstractComputeTask end            # from a triple of virtual particle currents, calculate the diagram result | end | ||||||
| struct ComputeTask_CollectTriples <: AbstractComputeTask end    # sum over triples results and  | struct ComputeTask_PropagatePairs <: AbstractComputeTask end        # for the result of a CollectPairs compute task and a propagator, propagate the sum | ||||||
|  | struct ComputeTask_Triple <: AbstractComputeTask end                # from a triple of virtual particle currents, calculate the diagram result | ||||||
|  | struct ComputeTask_CollectTriples <: AbstractComputeTask            # sum over triples results and  | ||||||
|  |     children::Int | ||||||
|  | end | ||||||
|  | struct ComputeTask_SpinPolCumulation <: AbstractComputeTask         # abs2 sum over all spin/pol configs | ||||||
|  |     children::Int | ||||||
|  | end | ||||||
|  |  | ||||||
| # import compute so we don't have to repeat it all the time | # import compute so we don't have to repeat it all the time | ||||||
| import MetagraphOptimization: compute, compute_effort | import MetagraphOptimization: compute, compute_effort, children | ||||||
|  |  | ||||||
| compute_effort(::ComputeTask_BaseState) = 0 | compute_effort(::ComputeTask_BaseState) = 0 | ||||||
| compute_effort(::ComputeTask_Propagator) = 0 | compute_effort(::ComputeTask_Propagator) = 0 | ||||||
| @@ -16,6 +23,16 @@ compute_effort(::ComputeTask_CollectPairs) = 0 | |||||||
| compute_effort(::ComputeTask_PropagatePairs) = 0 | compute_effort(::ComputeTask_PropagatePairs) = 0 | ||||||
| compute_effort(::ComputeTask_Triple) = 0 | compute_effort(::ComputeTask_Triple) = 0 | ||||||
| compute_effort(::ComputeTask_CollectTriples) = 0 | compute_effort(::ComputeTask_CollectTriples) = 0 | ||||||
|  | compute_effort(::ComputeTask_SpinPolCumulation) = 0 | ||||||
|  |  | ||||||
|  | children(::ComputeTask_BaseState) = 1 | ||||||
|  | children(::ComputeTask_Propagator) = 1 | ||||||
|  | children(::ComputeTask_Pair) = 2 | ||||||
|  | children(t::ComputeTask_CollectPairs) = t.children | ||||||
|  | children(::ComputeTask_PropagatePairs) = 2 | ||||||
|  | children(::ComputeTask_Triple) = 3 | ||||||
|  | children(t::ComputeTask_CollectTriples) = t.children | ||||||
|  | children(t::ComputeTask_SpinPolCumulation) = t.children | ||||||
|  |  | ||||||
| struct BaseStateInput{PS_T<:AbstractParticleStateful,SPIN_POL_T<:AbstractSpinOrPolarization} | struct BaseStateInput{PS_T<:AbstractParticleStateful,SPIN_POL_T<:AbstractSpinOrPolarization} | ||||||
|     particle::PS_T |     particle::PS_T | ||||||
| @@ -142,3 +159,10 @@ end | |||||||
| # use a summation algorithm with more accuracy and/or parallelization | # use a summation algorithm with more accuracy and/or parallelization | ||||||
| @inline compute(::ComputeTask_CollectPairs, args::Vararg{N,T}) where {N,T} = sum(args) | @inline compute(::ComputeTask_CollectPairs, args::Vararg{N,T}) where {N,T} = sum(args) | ||||||
| @inline compute(::ComputeTask_CollectTriples, args::Vararg{N,T}) where {N,T} = sum(args) | @inline compute(::ComputeTask_CollectTriples, args::Vararg{N,T}) where {N,T} = sum(args) | ||||||
|  | @inline function compute(::ComputeTask_SpinPolCumulation, args::Vararg{N,T}) where {N,T} | ||||||
|  |     sum = 0.0 | ||||||
|  |     for arg in args | ||||||
|  |         sum += abs2(arg) | ||||||
|  |     end | ||||||
|  |     return sum | ||||||
|  | end | ||||||
|   | |||||||
| @@ -22,7 +22,7 @@ function _parse_particle(name::String) | |||||||
|         throw(InvalidInputError("failed to parse particle direction from \"$name\"")) |         throw(InvalidInputError("failed to parse particle direction from \"$name\"")) | ||||||
|     end |     end | ||||||
|  |  | ||||||
|     name = name[4:end] |     name = name[5:end] | ||||||
|  |  | ||||||
|     local species |     local species | ||||||
|     if startswith(name, "el") |     if startswith(name, "el") | ||||||
| @@ -35,7 +35,7 @@ function _parse_particle(name::String) | |||||||
|         throw(InvalidInputError("failed to parse particle species from name \"$name\"")) |         throw(InvalidInputError("failed to parse particle species from name \"$name\"")) | ||||||
|     end |     end | ||||||
|  |  | ||||||
|     name = name[3:end] |     name = name[4:end] | ||||||
|  |  | ||||||
|     local spin_pol |     local spin_pol | ||||||
|     if startswith(name, "su") |     if startswith(name, "su") | ||||||
| @@ -54,13 +54,15 @@ function _parse_particle(name::String) | |||||||
|         ) |         ) | ||||||
|     end |     end | ||||||
|  |  | ||||||
|     name = name[3:end] |     name = name[4:end] | ||||||
|  |  | ||||||
|     index = parse(Int, name) |     index = parse(Int, name) | ||||||
|     return (dir, species, spin_pol, index) |     return (dir, species, spin_pol, index) | ||||||
| end | end | ||||||
|  |  | ||||||
| function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbol) | function MetagraphOptimization.input_expr( | ||||||
|  |     proc::AbstractProcessDefinition, name::String, psp_symbol::Symbol | ||||||
|  | ) | ||||||
|     if startswith(name, "bs_") |     if startswith(name, "bs_") | ||||||
|         (dir, species, spin_pol, index) = _parse_particle(name[4:end]) |         (dir, species, spin_pol, index) = _parse_particle(name[4:end]) | ||||||
|         dir_str = _construction_string(dir) |         dir_str = _construction_string(dir) | ||||||
| @@ -81,7 +83,7 @@ function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbo | |||||||
|         return Meta.parse("PropagatorInput( |         return Meta.parse("PropagatorInput( | ||||||
|                               VirtualParticle( |                               VirtualParticle( | ||||||
|                                 process($psp_symbol), |                                 process($psp_symbol), | ||||||
|                                 $species_str, |                                 $(_species_str(particle_species(vp))), | ||||||
|                                 $(vp.in_particle_contributions), |                                 $(vp.in_particle_contributions), | ||||||
|                                 $(vp.out_particle_contributions) |                                 $(vp.out_particle_contributions) | ||||||
|                               ), |                               ), | ||||||
| @@ -92,6 +94,19 @@ function input_expr(instance::GenericQEDProcess, name::String, psp_symbol::Symbo | |||||||
|     end |     end | ||||||
| end | end | ||||||
|  |  | ||||||
|  | function MetagraphOptimization.input_type(p::AbstractProcessDefinition) | ||||||
|  |     in_t = QEDcore._assemble_tuple_type(incoming_particles(p), Incoming(), SFourMomentum) | ||||||
|  |     out_t = QEDcore._assemble_tuple_type(outgoing_particles(p), Outgoing(), SFourMomentum) | ||||||
|  |     return PhaseSpacePoint{ | ||||||
|  |         typeof(p), | ||||||
|  |         PerturbativeQED, | ||||||
|  |         PhasespaceDefinition{SphericalCoordinateSystem,ElectronRestFrame}, | ||||||
|  |         Tuple{in_t...}, | ||||||
|  |         Tuple{out_t...}, | ||||||
|  |         SFourMomentum, | ||||||
|  |     } | ||||||
|  | end | ||||||
|  |  | ||||||
| _species_str(::Photon) = "ph" | _species_str(::Photon) = "ph" | ||||||
| _species_str(::Electron) = "el" | _species_str(::Electron) = "el" | ||||||
| _species_str(::Positron) = "po" | _species_str(::Positron) = "po" | ||||||
| @@ -101,6 +116,22 @@ _spin_pol_str(::SpinDown) = "sd" | |||||||
| _spin_pol_str(::PolX) = "px" | _spin_pol_str(::PolX) = "px" | ||||||
| _spin_pol_str(::PolY) = "py" | _spin_pol_str(::PolY) = "py" | ||||||
|  |  | ||||||
|  | function Base.parse(::Type{AbstractSpinOrPolarization}, s::AbstractString) | ||||||
|  |     if s == "su" | ||||||
|  |         return SpinUp() | ||||||
|  |     end | ||||||
|  |     if s == "sd" | ||||||
|  |         return SpinDown() | ||||||
|  |     end | ||||||
|  |     if s == "px" | ||||||
|  |         return PolX() | ||||||
|  |     end | ||||||
|  |     if s == "py" | ||||||
|  |         return PolY() | ||||||
|  |     end | ||||||
|  |     throw(InvalidInputError("invalid string \"$s\" to parse to AbstractSpinOrPolarization")) | ||||||
|  | end | ||||||
|  |  | ||||||
| _dir_str(::Incoming) = "inc" | _dir_str(::Incoming) = "inc" | ||||||
| _dir_str(::Outgoing) = "out" | _dir_str(::Outgoing) = "out" | ||||||
|  |  | ||||||
| @@ -113,6 +144,49 @@ _spin_pols(::PolY) = (PolY(),) | |||||||
|  |  | ||||||
| _is_external(p::VirtualParticle) = number_contributions(p) == 1 | _is_external(p::VirtualParticle) = number_contributions(p) == 1 | ||||||
|  |  | ||||||
|  | function _total_index( | ||||||
|  |     proc::AbstractProcessDefinition, | ||||||
|  |     dir::ParticleDirection, | ||||||
|  |     species::AbstractParticleType, | ||||||
|  |     n::Int, | ||||||
|  | ) | ||||||
|  |     # find particle index of all particles given n-th particle of dir and species (inverse of _species_index) | ||||||
|  |     total_index = 0 | ||||||
|  |     species_count = 0 | ||||||
|  |     for p in particles(proc, dir) | ||||||
|  |         total_index += 1 | ||||||
|  |         if species == p | ||||||
|  |             species_count += 1 | ||||||
|  |         end | ||||||
|  |         if species_count == n | ||||||
|  |             return if dir == Incoming() | ||||||
|  |                 total_index | ||||||
|  |             else | ||||||
|  |                 number_incoming_particles(proc) + total_index | ||||||
|  |             end | ||||||
|  |         end | ||||||
|  |     end | ||||||
|  |  | ||||||
|  |     throw("did not find $n-th $dir $species") | ||||||
|  | end | ||||||
|  |  | ||||||
|  | function _species_index( | ||||||
|  |     proc::AbstractProcessDefinition, | ||||||
|  |     dir::ParticleDirection, | ||||||
|  |     species::AbstractParticleType, | ||||||
|  |     n::Int, | ||||||
|  | ) | ||||||
|  |     # find particle index of n-th particle of *this species and dir* | ||||||
|  |     species_index = 0 | ||||||
|  |     for i in 1:n | ||||||
|  |         if particles(proc, dir)[i] == species | ||||||
|  |             species_index += 1 | ||||||
|  |         end | ||||||
|  |     end | ||||||
|  |  | ||||||
|  |     return species_index | ||||||
|  | end | ||||||
|  |  | ||||||
| function _base_state_name(p::VirtualParticle) | function _base_state_name(p::VirtualParticle) | ||||||
|     proc = process(p) |     proc = process(p) | ||||||
|  |  | ||||||
| @@ -130,13 +204,7 @@ function _base_state_name(p::VirtualParticle) | |||||||
|  |  | ||||||
|     species = particles(proc, dir)[index] |     species = particles(proc, dir)[index] | ||||||
|  |  | ||||||
|     # find particle index of *this species* |     species_index = _species_index(proc, dir, species, index) | ||||||
|     species_index = 0 |  | ||||||
|     for i in 1:index |  | ||||||
|         if particles(proc, dir)[i] == species |  | ||||||
|             species_index += 1 |  | ||||||
|         end |  | ||||||
|     end |  | ||||||
|  |  | ||||||
|     spin_pol = spin_or_pol(proc, dir, species, species_index) |     spin_pol = spin_or_pol(proc, dir, species, species_index) | ||||||
|  |  | ||||||
| @@ -147,11 +215,70 @@ function _base_state_name(p::VirtualParticle) | |||||||
|     ) |     ) | ||||||
| end | end | ||||||
|  |  | ||||||
|  | # from two or three node names like "1_su-2-px"... return a single tuple of the indices and spin/pols in sorted | ||||||
|  | function _parse_node_names(name1::String, name2::String) | ||||||
|  |     split_strings_1 = split.(split(name1, "-"), "_") | ||||||
|  |     split_strings_2 = split.(split(name2, "-"), "_") | ||||||
|  |  | ||||||
|  |     return tuple( | ||||||
|  |         # TODO: could use merge sort since the sub lists are sorted already | ||||||
|  |         sort([ | ||||||
|  |             tuple.( | ||||||
|  |                 parse.(Int, getindex.(split_strings_1, 1)), | ||||||
|  |                 parse.(AbstractSpinOrPolarization, getindex.(split_strings_1, 2)), | ||||||
|  |             )..., | ||||||
|  |             tuple.( | ||||||
|  |                 parse.(Int, getindex.(split_strings_2, 1)), | ||||||
|  |                 parse.(AbstractSpinOrPolarization, getindex.(split_strings_2, 2)), | ||||||
|  |             )..., | ||||||
|  |         ])..., | ||||||
|  |     ) | ||||||
|  | end | ||||||
|  | function _parse_node_names(name1::String, name2::String, name3::String) | ||||||
|  |     split_strings_1 = split.(split(name1, "-"), "_") | ||||||
|  |     split_strings_2 = split.(split(name2, "-"), "_") | ||||||
|  |     split_strings_3 = split.(split(name3, "-"), "_") | ||||||
|  |  | ||||||
|  |     return tuple( | ||||||
|  |         # TODO: could use merge sort since the sub lists are sorted already | ||||||
|  |         sort([ | ||||||
|  |             tuple.( | ||||||
|  |                 parse.(Int, getindex.(split_strings_1, 1)), | ||||||
|  |                 parse.(AbstractSpinOrPolarization, getindex.(split_strings_1, 2)), | ||||||
|  |             )..., | ||||||
|  |             tuple.( | ||||||
|  |                 parse.(Int, getindex.(split_strings_2, 1)), | ||||||
|  |                 parse.(AbstractSpinOrPolarization, getindex.(split_strings_2, 2)), | ||||||
|  |             )..., | ||||||
|  |             tuple.( | ||||||
|  |                 parse.(Int, getindex.(split_strings_3, 1)), | ||||||
|  |                 parse.(AbstractSpinOrPolarization, getindex.(split_strings_3, 2)), | ||||||
|  |             )..., | ||||||
|  |         ])..., | ||||||
|  |     ) | ||||||
|  | end | ||||||
|  |  | ||||||
|  | function _make_node_name( | ||||||
|  |     spin_pols::NTuple{N,Tuple{Int,AbstractSpinOrPolarization}} | ||||||
|  | ) where {N} | ||||||
|  |     node_name = "" | ||||||
|  |     first = true | ||||||
|  |     for spin_pol_tuple in spin_pols | ||||||
|  |         if !first | ||||||
|  |             node_name *= "-" | ||||||
|  |         else | ||||||
|  |             first = false | ||||||
|  |         end | ||||||
|  |         node_name *= "$(spin_pol_tuple[1])_$(_spin_pol_str(spin_pol_tuple[2]))" | ||||||
|  |     end | ||||||
|  |     return node_name | ||||||
|  | end | ||||||
|  |  | ||||||
| function generate_DAG(proc::GenericQEDProcess) | function generate_DAG(proc::GenericQEDProcess) | ||||||
|     external_particles = _pseudo_virtual_particles(proc) # external particles that will be input to base_state tasks |     external_particles = _pseudo_virtual_particles(proc) # external particles that will be input to base_state tasks | ||||||
|     particles = virtual_particles(proc)                  # virtual particles that will be input to propagator tasks |     particles = virtual_particles(proc)                  # virtual particles that will be input to propagator tasks | ||||||
|     pairs = particle_pairs(particles)                    # pairs to generate the pair tasks |     pairs = sort(particle_pairs(particles))              # pairs to generate the pair tasks | ||||||
|     triples = total_particle_triples(particles)          # triples to generate the triple tasks |     triples = sort(total_particle_triples(particles))    # triples to generate the triple tasks | ||||||
|  |  | ||||||
|     graph = DAG() |     graph = DAG() | ||||||
|  |  | ||||||
| @@ -181,7 +308,13 @@ function generate_DAG(proc::GenericQEDProcess) | |||||||
|                     ) |                     ) | ||||||
|  |  | ||||||
|                     data_out = insert_node!( |                     data_out = insert_node!( | ||||||
|                         graph, make_node(DataTask(0)); track=false, invalidate_cache=false |                         graph, | ||||||
|  |                         make_node( | ||||||
|  |                             DataTask(0), | ||||||
|  |                             "$(_total_index(proc, dir, species, index))_$(_spin_pol_str(spin_pol))", | ||||||
|  |                         ); | ||||||
|  |                         track=false, | ||||||
|  |                         invalidate_cache=false, | ||||||
|                     ) |                     ) | ||||||
|  |  | ||||||
|                     insert_edge!(graph, data_in, compute_base_state) |                     insert_edge!(graph, data_in, compute_base_state) | ||||||
| @@ -202,7 +335,10 @@ function generate_DAG(proc::GenericQEDProcess) | |||||||
|         data_node_name = "pr_$vp_index" |         data_node_name = "pr_$vp_index" | ||||||
|  |  | ||||||
|         data_in = insert_node!( |         data_in = insert_node!( | ||||||
|             graph, make_node(DataTask(0)); track=false, invalidate_cache=false |             graph, | ||||||
|  |             make_node(DataTask(0), data_node_name); | ||||||
|  |             track=false, | ||||||
|  |             invalidate_cache=false, | ||||||
|         ) |         ) | ||||||
|         compute_vp_propagator = insert_node!( |         compute_vp_propagator = insert_node!( | ||||||
|             graph, make_node(ComputeTask_Propagator()); track=false, invalidate_cache=false |             graph, make_node(ComputeTask_Propagator()); track=false, invalidate_cache=false | ||||||
| @@ -214,15 +350,19 @@ function generate_DAG(proc::GenericQEDProcess) | |||||||
|         insert_edge!(graph, data_in, compute_vp_propagator) |         insert_edge!(graph, data_in, compute_vp_propagator) | ||||||
|         insert_edge!(graph, compute_vp_propagator, data_out) |         insert_edge!(graph, compute_vp_propagator, data_out) | ||||||
|  |  | ||||||
|         propagator_task_outputs[data_node_name] = data_out |         propagator_task_outputs[vp] = data_out | ||||||
|     end |     end | ||||||
|  |  | ||||||
|     # -- Pair Tasks -- |     # -- Pair Tasks -- | ||||||
|     pair_task_outputs = Dict() |     pair_task_outputs = Dict{VirtualParticle,Vector{Node}}() | ||||||
|     for (product_particle, input_particle_vector) in pairs |     for (product_particle, input_particle_vector) in pairs | ||||||
|         # for all spins/pols of particles in product_particles do ... |         pair_task_outputs[product_particle] = Vector{Node}() | ||||||
|  |  | ||||||
|         pair_task_outputs[product_particle] = Vector() |         # make a dictionary of vectors to collect the outputs depending on spin/pol configs of the input particles | ||||||
|  |         N = number_contributions(product_particle) | ||||||
|  |         pair_output_nodes_by_spin_pol = Dict{ | ||||||
|  |             NTuple{N,Tuple{Int,AbstractSpinOrPolarization}},Vector{DataTaskNode} | ||||||
|  |         }() | ||||||
|  |  | ||||||
|         for input_particles in input_particle_vector |         for input_particles in input_particle_vector | ||||||
|             particles_data_out_nodes = (Vector(), Vector()) |             particles_data_out_nodes = (Vector(), Vector()) | ||||||
| @@ -237,28 +377,152 @@ function generate_DAG(proc::GenericQEDProcess) | |||||||
|                     ) |                     ) | ||||||
|                 else |                 else | ||||||
|                     # grab from propagated particles |                     # grab from propagated particles | ||||||
|                     push!(particles_date_out_nodes[c], pair_task_outputs[p]) |                     append!(particles_data_out_nodes[c], pair_task_outputs[p]) | ||||||
|                 end |                 end | ||||||
|             end |             end | ||||||
|  |  | ||||||
|             for in_nodes in Iterators.product(input_particles...) |             for in_nodes in Iterators.product(particles_data_out_nodes...) | ||||||
|                 # make the compute pair nodes for every combination of the found input_particle_nodes to get all spin/pol combinations |                 # make the compute pair nodes for every combination of the found input_particle_nodes to get all spin/pol combinations | ||||||
|  |                 compute_pair = insert_node!( | ||||||
|  |                     graph, | ||||||
|  |                     make_node(ComputeTask_Pair()); | ||||||
|  |                     track=false, | ||||||
|  |                     invalidate_cache=false, | ||||||
|  |                 ) | ||||||
|  |                 pair_data_out = insert_node!( | ||||||
|  |                     graph, make_node(DataTask(0)); track=false, invalidate_cache=false | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|                 #insert_node!(graph, ) |                 insert_edge!(graph, in_nodes[1], compute_pair) | ||||||
|  |                 insert_edge!(graph, in_nodes[2], compute_pair) | ||||||
|  |                 insert_edge!(graph, compute_pair, pair_data_out) | ||||||
|  |  | ||||||
|  |                 # get the spin/pol config of the input particles from the data_out names | ||||||
|  |                 index = _parse_node_names(in_nodes[1].name, in_nodes[2].name) | ||||||
|  |  | ||||||
|  |                 if !haskey(pair_output_nodes_by_spin_pol, index) | ||||||
|  |                     pair_output_nodes_by_spin_pol[index] = Vector() | ||||||
|  |                 end | ||||||
|  |                 push!(pair_output_nodes_by_spin_pol[index], pair_data_out) | ||||||
|             end |             end | ||||||
|             # make the collect pair and propagate nodes |  | ||||||
|  |  | ||||||
|         end |         end | ||||||
|  |  | ||||||
|         data_out_propagated = insert_node!( |         propagator_node = propagator_task_outputs[product_particle] | ||||||
|             graph, make_node(DataTask(0)); track=false, invalidate_caches=false |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         pair_task_outputs[p] = data_out_propagated |         for (index, nodes_to_sum) in pair_output_nodes_by_spin_pol | ||||||
|  |             compute_pairs_sum = insert_node!( | ||||||
|  |                 graph, | ||||||
|  |                 make_node(ComputeTask_CollectPairs(length(nodes_to_sum))); | ||||||
|  |                 track=false, | ||||||
|  |                 invalidate_cache=false, | ||||||
|  |             ) | ||||||
|  |             data_pairs_sum = insert_node!( | ||||||
|  |                 graph, make_node(DataTask(0)); track=false, invalidate_cache=false | ||||||
|  |             ) | ||||||
|  |             compute_propagated = insert_node!( | ||||||
|  |                 graph, | ||||||
|  |                 make_node(ComputeTask_PropagatePairs()); | ||||||
|  |                 track=false, | ||||||
|  |                 invalidate_cache=false, | ||||||
|  |             ) | ||||||
|  |             # give this out node the correct name | ||||||
|  |             data_out_propagated = insert_node!( | ||||||
|  |                 graph, | ||||||
|  |                 make_node(DataTask(0), _make_node_name(index)); | ||||||
|  |                 track=false, | ||||||
|  |                 invalidate_cache=false, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         # ... end do |             for node in nodes_to_sum | ||||||
|  |                 insert_edge!(graph, node, compute_pairs_sum) | ||||||
|  |             end | ||||||
|  |  | ||||||
|  |             insert_edge!(graph, compute_pairs_sum, data_pairs_sum) | ||||||
|  |  | ||||||
|  |             insert_edge!(graph, propagator_node, compute_propagated) | ||||||
|  |             insert_edge!(graph, data_pairs_sum, compute_propagated) | ||||||
|  |  | ||||||
|  |             insert_edge!(graph, compute_propagated, data_out_propagated) | ||||||
|  |  | ||||||
|  |             push!(pair_task_outputs[product_particle], data_out_propagated) | ||||||
|  |         end | ||||||
|     end |     end | ||||||
|  |  | ||||||
|  |     # -- Triples -- | ||||||
|  |     triples_results = Dict() | ||||||
|  |     for (ph, el, po) in triples    # for each triple (each "diagram") | ||||||
|  |         photons = if is_external(ph) | ||||||
|  |             getindex.(Ref(base_state_task_outputs), _base_state_name(ph)) | ||||||
|  |         else | ||||||
|  |             pair_task_outputs[ph] | ||||||
|  |         end | ||||||
|  |         electrons = if is_external(el) | ||||||
|  |             getindex.(Ref(base_state_task_outputs), _base_state_name(el)) | ||||||
|  |         else | ||||||
|  |             pair_task_outputs[el] | ||||||
|  |         end | ||||||
|  |         positrons = if is_external(po) | ||||||
|  |             getindex.(Ref(base_state_task_outputs), _base_state_name(po)) | ||||||
|  |         else | ||||||
|  |             pair_task_outputs[po] | ||||||
|  |         end | ||||||
|  |         for (a, b, c) in Iterators.product(photons, electrons, positrons) # for each spin/pol config of each part | ||||||
|  |             compute_triples = insert_node!( | ||||||
|  |                 graph, make_node(ComputeTask_Triple()); track=false, invalidate_cache=false | ||||||
|  |             ) | ||||||
|  |             data_triples = insert_node!( | ||||||
|  |                 graph, make_node(DataTask(0)); track=false, invalidate_cache=false | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             insert_edge!(graph, a, compute_triples) | ||||||
|  |             insert_edge!(graph, b, compute_triples) | ||||||
|  |             insert_edge!(graph, c, compute_triples) | ||||||
|  |  | ||||||
|  |             insert_edge!(graph, compute_triples, data_triples) | ||||||
|  |  | ||||||
|  |             index = _parse_node_names(a.name, b.name, c.name) | ||||||
|  |             if !haskey(triples_results, index) | ||||||
|  |                 triples_results[index] = Vector{DataTaskNode}() | ||||||
|  |             end | ||||||
|  |             push!(triples_results[index], data_triples) | ||||||
|  |         end | ||||||
|  |     end | ||||||
|  |  | ||||||
|  |     # -- Collect Triples -- | ||||||
|  |     collected_triples = Vector{DataTaskNode}() | ||||||
|  |     for (index, results) in triples_results | ||||||
|  |         compute_collect_triples = insert_node!( | ||||||
|  |             graph, | ||||||
|  |             make_node(ComputeTask_CollectTriples(length(results))); | ||||||
|  |             track=false, | ||||||
|  |             invalidate_cache=false, | ||||||
|  |         ) | ||||||
|  |         data_collect_triples = insert_node!( | ||||||
|  |             graph, make_node(DataTask(0)); track=false, invalidate_cache=false | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for triple in results | ||||||
|  |             insert_edge!(graph, triple, compute_collect_triples) | ||||||
|  |         end | ||||||
|  |         insert_edge!(graph, compute_collect_triples, data_collect_triples) | ||||||
|  |  | ||||||
|  |         push!(collected_triples, data_collect_triples) | ||||||
|  |     end | ||||||
|  |  | ||||||
|  |     # Finally, abs2 sum over spin/pol configurations | ||||||
|  |     compute_total_result = insert_node!( | ||||||
|  |         graph, | ||||||
|  |         make_node(ComputeTask_SpinPolCumulation(length(collected_triples))); | ||||||
|  |         track=false, | ||||||
|  |         invalidate_cache=false, | ||||||
|  |     ) | ||||||
|  |     for finished_triple in collected_triples | ||||||
|  |         insert_edge!(graph, finished_triple, compute_total_result) | ||||||
|  |     end | ||||||
|  |  | ||||||
|  |     final_data_out = insert_node!( | ||||||
|  |         graph, make_node(DataTask(0)); track=false, invalidate_cache=false | ||||||
|  |     ) | ||||||
|  |     insert_edge!(graph, compute_total_result, final_data_out) | ||||||
|     return graph |     return graph | ||||||
| end | end | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user