diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index e6b5e1f..d43955e 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -1,5 +1,5 @@
indent = 4
-margin = 80
+margin = 120
always_for_in = true
for_in_replacement = "in"
whitespace_typedefs = true
diff --git a/.gitea/workflows/julia-package-ci.yml b/.gitea/workflows/julia-package-ci.yml
index bcf8481..99ae2f2 100644
--- a/.gitea/workflows/julia-package-ci.yml
+++ b/.gitea/workflows/julia-package-ci.yml
@@ -108,7 +108,7 @@ jobs:
- name: Format check
run: |
- julia --project=./ -e 'using JuliaFormatter; format(".", verbose=true)'
+ julia --project=./ -e 'using JuliaFormatter; format(".", verbose=true, ignore=[".julia/*"])'
julia --project=./ -e '
out = Cmd(`git diff --name-only`) |> read |> String
if out == ""
diff --git a/.gitignore b/.gitignore
index 75c67bc..423c8ef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,3 +26,5 @@ Manifest.toml
# vscode workspace directory
.vscode
+.julia
+**/.ipynb_checkpoints/
diff --git a/Project.toml b/Project.toml
index a6000e3..a15ad52 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,9 +5,15 @@ version = "0.1.0"
[deps]
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+NumaAllocators = "21436f30-1b4a-4f08-87af-e26101bb5379"
+QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[extras]
diff --git a/README.md b/README.md
index 1cb9a8b..f09438a 100644
--- a/README.md
+++ b/README.md
@@ -42,7 +42,7 @@ Problems:
- Lots of testing required because mistakes will propagate and multiply.
## Other TODOs
-- Reduce memory footprint of the graph, are the UUIDs too large?
+- Reduce memory footprint of the graph
- Memory layout of Nodes? They should lie linearly in memory, right now probably on heap?
- Add scaling functions
@@ -53,7 +53,7 @@ For graphs AB->AB^n:
- Number of ComputeTaskS2 should always be (n+1)!
- Number of ComputeTaskU should always be (n+3)
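+
+For example, the `AB->ABBBBBBB` graph (n = 7) has 8! = 40320 ComputeTaskS2 nodes and 7 + 3 = 10 ComputeTaskU nodes, matching the counts printed in `notebooks/abc_model_large.ipynb`.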
-Times are from my home machine: AMD Ryzen 7900X3D, 64GB DDR5 RAM @ 6000MHz
+Times are from my home machine: AMD Ryzen 7900X3D, 64GB DDR5 RAM @ 6000MHz (not necessarily up to date; for current numbers, see the Jupyter notebooks in `notebooks/`)
```
$ julia --project examples/import_bench.jl
diff --git a/docs/make.jl b/docs/make.jl
index 5c81178..8b40fa1 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -27,6 +27,7 @@ makedocs(
"Diff" => "lib/internals/diff.md",
"Utility" => "lib/internals/utility.md",
"Code Generation" => "lib/internals/code_gen.md",
+ "Devices" => "lib/internals/devices.md",
],
"Contribution" => "contribution.md",
],
diff --git a/docs/src/flowchart.drawio b/docs/src/flowchart.drawio
new file mode 100644
index 0000000..321e7a3
--- /dev/null
+++ b/docs/src/flowchart.drawio
@@ -0,0 +1,75 @@
diff --git a/docs/src/lib/internals/devices.md b/docs/src/lib/internals/devices.md
new file mode 100644
index 0000000..8bd8191
--- /dev/null
+++ b/docs/src/lib/internals/devices.md
@@ -0,0 +1,59 @@
+# Devices
+
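+As a brief usage sketch (mirroring the calls used in the notebooks under `notebooks/`), a machine's devices can be queried like this:
+
+```julia
+using MetagraphOptimization
+
+# Detect the local machine; prints the found NUMA nodes and whether CUDA is functional
+machine = get_machine_info()
+
+# The detected devices are accessible through the returned Machine
+first_device = machine.devices[1]
+```
+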
+## Interface
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["devices/interface.jl"]
+Order = [:type, :constant, :function]
+```
+
+## Detect
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["devices/detect.jl"]
+Order = [:function]
+```
+
+## Measure
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["devices/measure.jl"]
+Order = [:function]
+```
+
+## Implementations
+
+### General
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["devices/impl.jl"]
+Order = [:type, :function]
+```
+
+### NUMA
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["devices/numa/impl.jl"]
+Order = [:type, :function]
+```
+
+### CUDA
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["devices/cuda/impl.jl"]
+Order = [:type, :function]
+```
+
+### ROCm
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["devices/rocm/impl.jl"]
+Order = [:type, :function]
+```
+
+### oneAPI
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["devices/oneapi/impl.jl"]
+Order = [:type, :function]
+```
diff --git a/docs/src/lib/internals/models.md b/docs/src/lib/internals/models.md
index 192b91c..a258ce5 100644
--- a/docs/src/lib/internals/models.md
+++ b/docs/src/lib/internals/models.md
@@ -1,5 +1,21 @@
# Models
+## Interface
+
+The interface a model has to implement to be usable is defined in `src/models/interface.jl`.
+
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["models/interface.jl"]
+Order = [:type, :constant, :function]
+```
+
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["models/print.jl"]
+Order = [:function]
+```
+
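+From the user side, a model is driven through this interface roughly as follows (a short sketch following the flow of `notebooks/abc_model_showcase.ipynb`):
+
+```julia
+using MetagraphOptimization
+
+model = ABCModel()                                # choose a model
+process = parse_process("AB->ABBB", model)        # describe a process in that model
+graph = parse_dag("input/AB->ABBB.txt", model)    # read the matching DAG from a file
+machine = get_machine_info()
+input_data = gen_process_input(process)           # generate random input data
+result = execute(graph, process, machine, input_data)
+```
+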
## ABC-Model
### Types
@@ -44,6 +60,13 @@ Pages = ["models/abc/compute.jl"]
Order = [:function]
```
+### Print
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["models/abc/print.jl"]
+Order = [:function]
+```
+
## QED-Model
*To be added*
diff --git a/docs/src/lib/internals/scheduler.md b/docs/src/lib/internals/scheduler.md
new file mode 100644
index 0000000..ec973b0
--- /dev/null
+++ b/docs/src/lib/internals/scheduler.md
@@ -0,0 +1,15 @@
+# Scheduler
+
+## Interface
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["scheduler/interface.jl"]
+Order = [:type, :function]
+```
+
+## Greedy
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["scheduler/greedy.jl"]
+Order = [:type, :function]
+```
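+
+The general idea of a greedy scheduler can be illustrated with a small, generic sketch (an illustration of the greedy strategy only, not the scheduler implemented in `scheduler/greedy.jl`): each task is assigned to the device that currently carries the lowest accumulated estimated cost.
+
+```julia
+# Generic greedy-assignment sketch (illustration only):
+# give each task to the device with the least accumulated estimated cost so far.
+function greedy_assign(task_costs::Vector{Float64}, n_devices::Int)
+    load = zeros(n_devices)
+    assignment = Vector{Int}(undef, length(task_costs))
+    for (i, cost) in enumerate(task_costs)
+        device = argmin(load)
+        assignment[i] = device
+        load[device] += cost
+    end
+    return assignment
+end
+```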
diff --git a/docs/src/lib/internals/task.md b/docs/src/lib/internals/task.md
index e7debbb..b73939f 100644
--- a/docs/src/lib/internals/task.md
+++ b/docs/src/lib/internals/task.md
@@ -21,6 +21,13 @@ Pages = ["task/compare.jl"]
Order = [:function]
```
+## Compute
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["task/compute.jl"]
+Order = [:function]
+```
+
## Properties
```@autodocs
Modules = [MetagraphOptimization]
diff --git a/docs/src/manual.md b/docs/src/manual.md
index 827c6b2..63b4a1b 100644
--- a/docs/src/manual.md
+++ b/docs/src/manual.md
@@ -1,3 +1,7 @@
# Manual
-This will become a manual.
+## Jupyter Notebooks
+
+The `notebooks` directory contains Jupyter notebooks with some examples of how to use this repository.
+
+- `abc_model_showcase`: A simple showcase of the intended usage of the ABC Model implementation.
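+
+A minimal way to open them locally (assuming the [IJulia](https://github.com/JuliaLang/IJulia.jl) kernel is installed):
+
+```julia
+using IJulia
+notebook(dir = "notebooks")
+```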
diff --git a/examples/import_bench.jl b/examples/import_bench.jl
index 64a5000..5143504 100644
--- a/examples/import_bench.jl
+++ b/examples/import_bench.jl
@@ -13,16 +13,15 @@ function bench_txt(filepath::String, bench::Bool = true)
return
end
+ model = ABCModel()
+
println(name, ":")
- g = parse_abc(filepath)
+ g = parse_dag(filepath, model)
print(g)
- println(
- " Graph size in memory: ",
- bytes_to_human_readable(MetagraphOptimization.mem(g)),
- )
+ println(" Graph size in memory: ", bytes_to_human_readable(MetagraphOptimization.mem(g)))
if (bench)
- @btime parse_abc($filepath)
+ @btime parse_dag($filepath, $model)
end
println(" Get Operations: ")
diff --git a/examples/plot_chain.jl b/examples/plot_chain.jl
index 4e9cb1c..3b57044 100644
--- a/examples/plot_chain.jl
+++ b/examples/plot_chain.jl
@@ -12,7 +12,7 @@ function gen_plot(filepath)
return
end
- g = parse_abc(filepath)
+ g = parse_dag(filepath, ABCModel())
Random.seed!(1)
@@ -48,23 +48,10 @@ function gen_plot(filepath)
println("\rDone.")
- plot(
- [x[1], x[2]],
- [y[1], y[2]],
- linestyle = :solid,
- linewidth = 1,
- color = :red,
- legend = false,
- )
+ plot([x[1], x[2]], [y[1], y[2]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
# Create lines connecting the reference point to each data point
for i in 3:length(x)
- plot!(
- [x[i - 1], x[i]],
- [y[i - 1], y[i]],
- linestyle = :solid,
- linewidth = 1,
- color = :red,
- )
+ plot!([x[i - 1], x[i]], [y[i - 1], y[i]], linestyle = :solid, linewidth = 1, color = :red)
end
return gui()
diff --git a/examples/plot_star.jl b/examples/plot_star.jl
index 3f82fb5..8f8ad68 100644
--- a/examples/plot_star.jl
+++ b/examples/plot_star.jl
@@ -12,7 +12,7 @@ function gen_plot(filepath)
return
end
- g = parse_abc(filepath)
+ g = parse_dag(filepath, ABCModel())
Random.seed!(1)
@@ -60,14 +60,7 @@ function gen_plot(filepath)
push!(y, props.computeEffort)
pop_operation!(g)
- push!(
- names,
- "NF: (" *
- string(props.data) *
- ", " *
- string(props.computeEffort) *
- ")",
- )
+ push!(names, "NF: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
end
for op in opt.nodeReductions
push_operation!(g, op)
@@ -76,14 +69,7 @@ function gen_plot(filepath)
push!(y, props.computeEffort)
pop_operation!(g)
- push!(
- names,
- "NR: (" *
- string(props.data) *
- ", " *
- string(props.computeEffort) *
- ")",
- )
+ push!(names, "NR: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
end
for op in opt.nodeSplits
push_operation!(g, op)
@@ -92,33 +78,13 @@ function gen_plot(filepath)
push!(y, props.computeEffort)
pop_operation!(g)
- push!(
- names,
- "NS: (" *
- string(props.data) *
- ", " *
- string(props.computeEffort) *
- ")",
- )
+ push!(names, "NS: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
end
- plot(
- [x0, x[1]],
- [y0, y[1]],
- linestyle = :solid,
- linewidth = 1,
- color = :red,
- legend = false,
- )
+ plot([x0, x[1]], [y0, y[1]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
# Create lines connecting the reference point to each data point
for i in 2:length(x)
- plot!(
- [x0, x[i]],
- [y0, y[i]],
- linestyle = :solid,
- linewidth = 1,
- color = :red,
- )
+ plot!([x0, x[i]], [y0, y[i]], linestyle = :solid, linewidth = 1, color = :red)
end
#scatter!(x, y, label=names)
diff --git a/examples/profiling_utilities.jl b/examples/profiling_utilities.jl
index eb99b88..2400567 100644
--- a/examples/profiling_utilities.jl
+++ b/examples/profiling_utilities.jl
@@ -1,6 +1,6 @@
-function test_random_walk(g::DAG, n::Int64)
- # the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
+function random_walk!(g::DAG, n::Int64)
+ # the purpose here is to do "random" operations on the graph to simulate an optimizer
reset_graph!(g)
properties = get_properties(g)
@@ -32,7 +32,7 @@ function test_random_walk(g::DAG, n::Int64)
end
end
- return reset_graph!(g)
+ return nothing
end
function reduce_all!(g::DAG)
diff --git a/input/AB->ABBBBBBBBB.txt b/input/AB->ABBBBBBBBB.txt
index 80c4c5e..b93004a 100644
Binary files a/input/AB->ABBBBBBBBB.txt and b/input/AB->ABBBBBBBBB.txt differ
diff --git a/notebooks/abc_model_large.ipynb b/notebooks/abc_model_large.ipynb
new file mode 100644
index 0000000..c4d0b99
--- /dev/null
+++ b/notebooks/abc_model_large.ipynb
@@ -0,0 +1,678 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "using MetagraphOptimization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Graph:\n",
+ " Nodes: Total: 438436, ComputeTaskP: 10, ComputeTaskU: 10, \n",
+ " ComputeTaskV: 109600, ComputeTaskSum: 1, ComputeTaskS2: 40320, \n",
+ " ComputeTaskS1: 69272, DataTask: 219223\n",
+ " Edges: 628665\n",
+ " Total Compute Effort: 1.903443e6\n",
+ " Total Data Transfer: 1.8040896e7\n",
+ " Total Compute Intensity: 0.10550712115407128\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = ABCModel()\n",
+ "process_str = \"AB->ABBBBBBB\"\n",
+ "process = parse_process(process_str, model)\n",
+ "graph = parse_dag(\"../input/$process_str.txt\", model)\n",
+ "print(graph)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "351.606942 seconds (1.13 G allocations: 25.949 GiB, 1.33% gc time, 0.72% compilation time)\n",
+ "Graph:\n",
+ " Nodes: Total: 277188, ComputeTaskP: 10, ComputeTaskU: 10, \n",
+ " ComputeTaskV: 69288, ComputeTaskSum: 1, ComputeTaskS2: 40320, \n",
+ " ComputeTaskS1: 28960, DataTask: 138599\n",
+ " Edges: 427105\n",
+ " Total Compute Effort: 1.218139e6\n",
+ " Total Data Transfer: 1.2235968e7\n",
+ " Total Compute Intensity: 0.0995539543745129\n"
+ ]
+ }
+ ],
+ "source": [
+ "include(\"../examples/profiling_utilities.jl\")\n",
+ "@time reduce_all!(graph)\n",
+ "print(graph)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found 1 NUMA nodes\n",
+ "CUDA is non-functional\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Get machine and set dictionary caching strategy\n",
+ "machine = get_machine_info()\n",
+ "MetagraphOptimization.set_cache_strategy(machine.devices[1], MetagraphOptimization.Dictionary())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2315.896312 seconds (87.18 M allocations: 132.726 GiB, 0.11% gc time, 0.04% compilation time)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "compute__8fd7c454_6214_11ee_3616_0f2435e477fe (generic function with 1 method)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "@time compute_AB_AB7 = get_compute_function(graph, process, machine)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " 1.910169 seconds (4.34 M allocations: 278.284 MiB, 6.25% gc time, 99.23% compilation time)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "1000-element Vector{ABCProcessInput}:\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [8.411745173347825, 0.0, 0.0, 8.352092962924948]\n",
+ " B: [8.411745173347825, 0.0, 0.0, -8.352092962924948]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-2.003428483168789, 1.2386385417950023, -0.8321671195319228, 0.8871291535745444]\n",
+ " B: [-2.444326994820653, 1.1775023368116424, -0.9536682034633904, 1.6366855721594777]\n",
+ " B: [-4.289211829680359, -3.7216649121036443, 1.128125248220305, 1.50793959634144]\n",
+ " B: [-1.2727607454602508, 0.07512513775641204, 0.6370236198332677, -0.45659285653208986]\n",
+ " B: [-1.8777156401619268, -1.042329795325101, -0.5508846238377632, -1.0657817573524957]\n",
+ " B: [-1.1322368113474306, 0.0498922458527246, -0.2963537951915457, -0.4377732162313449]\n",
+ " B: [-1.4340705015357569, 0.7798902829682378, 0.144450581630926, -0.6538068364381232]\n",
+ " B: [-2.369739340520482, 1.4429461622447262, 0.7234742923401235, -1.4177996555214083]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [8.262146117199348, 0.0, 0.0, 8.201405883258813]\n",
+ " B: [8.262146117199348, 0.0, 0.0, -8.201405883258813]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-2.022253637967156, 0.040616190652067494, 1.5789161216660899, -0.7712872241073523]\n",
+ " B: [-1.085155894223277, -0.4013306445746292, 0.044561160964560184, -0.12046298778597243]\n",
+ " B: [-2.3099664718736963, -0.6028883246226666, 0.7721426580907682, 1.8374619682515352]\n",
+ " B: [-3.8528592267292674, -1.1057919702708323, -3.154341441424319, -1.6345881470237529]\n",
+ " B: [-1.445065980497648, -0.3803292238069696, -0.9038074225417192, 0.3559459403736899]\n",
+ " B: [-1.637993216461692, 0.18276067729419151, -0.6165325663294264, 1.1267244146927589]\n",
+ " B: [-3.0791604558286254, 1.8666082398498536, 2.1149851082876507, -0.7237684597886623]\n",
+ " B: [-1.091837350817336, 0.4003550554789843, 0.16407638128639515, -0.0700255046122441]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [9.522164300929319, 0.0, 0.0, 9.4695096480173]\n",
+ " B: [9.522164300929319, 0.0, 0.0, -9.4695096480173]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-2.2614545815907876, 0.09596466269330481, -1.680314037563078, -1.1320390202111377]\n",
+ " B: [-2.5164555101345942, 2.0544568173259474, 0.7608284478099104, 0.7299969816600982]\n",
+ " B: [-3.527555187469315, 3.1461533872404055, -0.4998113855480195, 1.1382236350884531]\n",
+ " B: [-1.5843416170605953, -0.649775322646379, 0.6368565466386346, -0.8260412390634552]\n",
+ " B: [-1.0715042390215452, 0.33101538188959895, -0.19275377509309963, -0.037364868271978664]\n",
+ " B: [-1.8269658913133924, -1.2104472444295427, -0.7036857693244948, 0.6143681099517287]\n",
+ " B: [-1.7510547915269752, 0.35168054121444203, 0.408535633181173, -1.3325210378384098]\n",
+ " B: [-4.504996783741433, -4.119048223287777, 1.270344339898973, 0.8453774386847008]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [7.225275339000687, 0.0, 0.0, 7.1557392157883655]\n",
+ " B: [7.225275339000687, 0.0, 0.0, -7.1557392157883655]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.5721586195862234, -0.6346644373772993, 0.7957285133297657, -0.6600756851617959]\n",
+ " B: [-1.0093393293662618, -0.11321130994303012, 0.07324286826550051, -0.024177745030521003]\n",
+ " B: [-2.7355755394886443, 0.2329840388558535, -2.4939308642531, -0.4576033371958622]\n",
+ " B: [-1.618399027736879, -0.47727357006920945, 1.0132042772011558, -0.6040218911217943]\n",
+ " B: [-1.7201610947708947, 0.01110230391313025, 0.8839000043421623, -1.0851505486038107]\n",
+ " B: [-1.792300907703241, 0.8101193095744785, -0.625916307414256, 1.0790171565463333]\n",
+ " B: [-1.5563810656498285, -1.1865287585293671, 0.12019738267353275, -0.004910793671790455]\n",
+ " B: [-2.4462350936994026, 1.3574724235754438, 0.2335741258552372, 1.7569228442392408]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [7.94532861335446, 0.0, 0.0, 7.882147345374172]\n",
+ " B: [7.94532861335446, 0.0, 0.0, -7.882147345374172]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-2.118671714766621, -0.6322452591326608, -1.2236882164873555, -1.2615953852509143]\n",
+ " B: [-2.560753710001491, -1.7412395645571277, -1.5891033163317627, 0.01717533495153369]\n",
+ " B: [-1.5550581087132076, -0.639122838128628, -0.9624327134008909, 0.2888788525193626]\n",
+ " B: [-2.181477133464949, 0.4918918998013713, 1.8559068969600523, -0.2692479016749415]\n",
+ " B: [-1.2628370388798702, -0.4013500667990802, 0.24813196852393224, 0.6100049482124643]\n",
+ " B: [-1.901139724448186, 1.3625293914322611, -0.8176066997802711, 0.2989401174693193]\n",
+ " B: [-2.2302691928842697, -0.1867565668705846, 1.9609184768063308, 0.3066290670808993]\n",
+ " B: [-2.0804506035503256, 1.7462930042544484, 0.5278736037099664, 0.009214966692276028]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [5.597768901835826, 0.0, 0.0, 5.507723366179557]\n",
+ " B: [5.597768901835826, 0.0, 0.0, -5.507723366179557]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.0009073340208385, 0.03522831505376105, -0.010844681575969111, -0.021374049609080487]\n",
+ " B: [-1.3943823799403026, -0.886019044587247, 0.21582726795187737, -0.3356948979730148]\n",
+ " B: [-1.0593061926863385, 0.3261714964515558, -0.10930051701751846, -0.06160488410736567]\n",
+ " B: [-1.0190344437384602, 0.02512063114228613, 0.04379726771854621, -0.18942531709556668]\n",
+ " B: [-1.0919277601624486, -0.39612686480944176, 0.07078221355247243, -0.17429750036714983]\n",
+ " B: [-1.8292258091360047, 1.1565638126055895, 0.329244535677723, 0.9486966026643375]\n",
+ " B: [-1.7379569022732355, 0.6562121276078657, 0.7749535141539342, -0.9946491284065995]\n",
+ " B: [-2.0627969817140217, -0.9171504734643696, -1.3144596004610647, 0.8283491748944392]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [6.860362769879496, 0.0, 0.0, 6.787089017712134]\n",
+ " B: [6.860362769879496, 0.0, 0.0, -6.787089017712134]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-2.1483538194490985, 1.8204047500578164, 0.1342978924269131, -0.532461036694855]\n",
+ " B: [-1.2136825716769264, 0.12932805245115084, -0.43609629710270903, -0.5158678699965871]\n",
+ " B: [-3.3642987422516573, -1.7653207470663739, 0.533955101409256, 2.630026736893018]\n",
+ " B: [-1.053677321951765, 0.11000921943972916, 0.04739423847128557, -0.30965732123337875]\n",
+ " B: [-1.2932387925896982, -0.6843810329952256, 0.045636429012288295, -0.4494513240410521]\n",
+ " B: [-1.1237194151971648, -0.45140047643622017, 0.19994785657222267, -0.13785422959193222]\n",
+ " B: [-1.7619597212239484, 1.3299261857304887, 0.561749934748497, 0.1422512233127988]\n",
+ " B: [-1.7617951554187332, -0.488565951181366, -1.0868851555377534, -0.8269861786480115]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [9.57507915889135, 0.0, 0.0, 9.522717096450755]\n",
+ " B: [9.57507915889135, 0.0, 0.0, -9.522717096450755]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-3.4305207411483516, 2.6682294806816835, -1.883054168339437, -0.3211401453721668]\n",
+ " B: [-2.185574270107571, 1.4558232366821502, 1.2235951792097912, 0.40016050668089054]\n",
+ " B: [-3.0259648593433583, -0.9184166853584697, -0.10930222461665634, -2.7020412923806107]\n",
+ " B: [-3.246659025038245, -2.493839704051011, -1.0189869044243565, 1.5110340975546257]\n",
+ " B: [-1.4247322676315595, 0.05954103854817788, 0.9940897925990366, -0.19519831815252583]\n",
+ " B: [-1.4889906300188005, 0.5912092032645169, -0.19371449043911573, -0.9110650198822441]\n",
+ " B: [-1.1268952499657272, 0.36236812621338876, -0.3636229828302436, 0.07975319340034331]\n",
+ " B: [-3.220821274529085, -1.7249146959804351, 1.350995798840981, 2.1384969781516885]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [8.472852690841874, 0.0, 0.0, 8.413633740584764]\n",
+ " B: [8.472852690841874, 0.0, 0.0, -8.413633740584764]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.1530011327357317, 0.34211475449117323, -0.45923141786607913, -0.03841369149190832]\n",
+ " B: [-2.62915067223017, 1.042431210232047, 0.6288618003426715, -2.1048285595963105]\n",
+ " B: [-1.1265473249385953, -0.4344882737979479, -0.1553035746380426, 0.2370856700921221]\n",
+ " B: [-1.4826889242092416, -0.5889894099544346, -0.45026884678673923, -0.8054290077639529]\n",
+ " B: [-4.118520088756618, -2.101194203160593, -3.0008966741533745, 1.5943054265577095]\n",
+ " B: [-3.9992129109551517, 1.0607252636964415, 3.6847882851419875, 0.539352496783755]\n",
+ " B: [-1.3172538577755006, 0.4084669000294691, -0.6351790575407871, 0.4060296568803221]\n",
+ " B: [-1.1193304700827373, 0.2709337584638445, 0.3872294855003629, 0.17189800853826395]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [5.913538688235051, 0.0, 0.0, 5.828373685450576]\n",
+ " B: [5.913538688235051, 0.0, 0.0, -5.828373685450576]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.6813734506828508, -1.1942921586618185, -0.384476919421686, 0.5028522833318558]\n",
+ " B: [-1.412586238014363, 0.010275442474480664, 0.8780055986304257, -0.4737092609218783]\n",
+ " B: [-1.5338446207986793, 1.1234162145644635, 0.1670274754582306, -0.25043392751132176]\n",
+ " B: [-1.4260274101869397, 0.9023875675844153, -0.4646063309051003, -0.058239245843783906]\n",
+ " B: [-1.1055189977833793, -0.3699146930280028, 0.2809292901965394, -0.08008812803177658]\n",
+ " B: [-1.1926016738662872, 0.4242726765633766, 0.34415633034138016, -0.3519202590308968]\n",
+ " B: [-1.4188061371181722, 0.47356120240959365, 0.33662773751584696, 0.8218469496393668]\n",
+ " B: [-2.0563188480194308, -1.3697062519065082, -1.1576631818156364, -0.1103084116315648]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [6.062750568659298, 0.0, 0.0, 5.979711068085032]\n",
+ " B: [6.062750568659298, 0.0, 0.0, -5.979711068085032]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.1157392140073992, -0.0424317149721654, 0.4662958482482185, -0.16013033799016252]\n",
+ " B: [-2.395340693850968, -1.171776361305547, -1.746409249879336, 0.5609384374776449]\n",
+ " B: [-1.0289722654275464, 0.23139962589771268, 0.07055331234631396, 0.01613586906426155]\n",
+ " B: [-1.212565238145815, -0.6377842504248107, 0.04163119753237706, 0.24862129848767983]\n",
+ " B: [-1.8156755638105053, -0.3987185167288875, 1.2510245302740972, 0.7567290942527487]\n",
+ " B: [-2.003891077687212, 1.2159250459117166, 0.38048599808923245, -1.1799729400359336]\n",
+ " B: [-1.4663599649673638, 0.593985649692284, -0.7733488095969958, -0.44645740391848543]\n",
+ " B: [-1.086957119421786, 0.20940052192969777, 0.3097671729860923, 0.20413598266224653]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [7.088363151833832, 0.0, 0.0, 7.017470496715726]\n",
+ " B: [7.088363151833832, 0.0, 0.0, -7.017470496715726]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-3.1474601133746627, 0.14412280671945385, 2.7364508363525357, 1.1821889028802701]\n",
+ " B: [-1.256451004773104, 0.1153142495225348, -0.7455659837621855, -0.09748392231091944]\n",
+ " B: [-1.4964417911663928, -0.0996845872039782, -0.8492275192498467, 0.7128910421459969]\n",
+ " B: [-3.2499484244824526, -0.8927423628721523, -1.0242747556675866, -2.777775559729678]\n",
+ " B: [-1.0489067674373789, -0.31603136975662793, 0.016268502528308637, -0.008057042333727152]\n",
+ " B: [-1.6957667777105587, 1.0857339287179024, 0.6252297389508089, 0.5530773670555896]\n",
+ " B: [-1.243679438145053, 0.06348629097723194, -0.7145975145476898, 0.17904867473682565]\n",
+ " B: [-1.0380719865780628, -0.10019895610436466, -0.044283304604344965, 0.2561105375556422]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [9.842517855137334, 0.0, 0.0, 9.791586068084028]\n",
+ " B: [9.842517855137334, 0.0, 0.0, -9.791586068084028]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.0081083393933719, 0.09315850477843095, -0.05390640772287413, 0.06854207575149836]\n",
+ " B: [-1.2533776879399583, -0.09567218890986252, -0.022562148977002077, -0.749195175056841]\n",
+ " B: [-4.199102452438099, 3.1204551726062775, 2.23725963921713, 1.3747327844190023]\n",
+ " B: [-5.1018332572388285, -4.999892707918183, 0.09407944148737099, -0.14465321518774693]\n",
+ " B: [-3.7582268429742243, 2.1814891293707577, -1.5410280493623207, -2.4475715991095703]\n",
+ " B: [-1.1792132348986593, 0.6125282131702711, -0.12369433042852651, -0.007263198361168502]\n",
+ " B: [-1.3600169327450258, -0.07835376476887727, -0.6694537001487819, 0.6287594836317273]\n",
+ " B: [-1.8251569626465018, -0.8337123583288142, 0.07930555593500455, 1.2766488439130985]\n",
+ "\n",
+ " ⋮\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [9.861596443743153, 0.0, 0.0, 9.810763702141012]\n",
+ " B: [9.861596443743153, 0.0, 0.0, -9.810763702141012]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.8179384769334697, 0.9572508915748105, -0.9794338269553214, 0.6551949443563104]\n",
+ " B: [-2.1028582035167607, -0.7676665378472812, 0.6218562087985972, -1.5639678917247444]\n",
+ " B: [-3.1263866679666865, 2.3808322573838474, -1.6099851834448586, 0.7168535896041835]\n",
+ " B: [-5.177179415841987, -1.3605325795287053, 4.805481256903438, -0.9270855911989424]\n",
+ " B: [-1.2605754590213083, -0.023284320526100116, -0.14250915308265208, 0.7537900699744495]\n",
+ " B: [-2.712925004518324, -1.4343063146086636, -1.452340398698398, 1.4810249296764189]\n",
+ " B: [-2.3798188172675734, 0.6412170781802653, -1.487389994435021, -1.4283029321979925]\n",
+ " B: [-1.1455108424201939, -0.39351047462817185, 0.24432109091421514, 0.3124928815103169]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [5.611571819338176, 0.0, 0.0, 5.521751378284825]\n",
+ " B: [5.611571819338176, 0.0, 0.0, -5.521751378284825]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.0759150984150232, -0.3903007964405737, 0.045679777762273936, -0.05632002484775736]\n",
+ " B: [-1.021003529021616, -0.07269336486556076, 0.11388411952175649, 0.15554513267817288]\n",
+ " B: [-1.6939705353811365, -0.1440535362616654, -0.25084793375093056, -1.3363607550219565]\n",
+ " B: [-1.185801144621379, -0.31618880274591826, 0.5459120200606805, -0.09016131075324207]\n",
+ " B: [-1.197431131926246, 0.16472462054297168, -0.17198607315407527, -0.6141074056988615]\n",
+ " B: [-1.0089442324730478, -0.12314856400749492, -0.027052115631495212, -0.04550910308256443]\n",
+ " B: [-2.703474424566498, 0.16902217864171518, -0.14049660772763695, 2.502092358533033]\n",
+ " B: [-1.3366035422714058, 0.7126382651365266, -0.11509318708057305, -0.5151788918068239]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [8.775111706253933, 0.0, 0.0, 8.717946171962454]\n",
+ " B: [8.775111706253933, 0.0, 0.0, -8.717946171962454]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-2.2750151423103953, 1.8467170131598, 0.8729070809034145, 0.05799482008261441]\n",
+ " B: [-1.5756212156561644, 1.0377655822554295, 0.3001332912880399, 0.5617337616455574]\n",
+ " B: [-1.6945981163898138, -0.5153714693329569, 0.050834292767083435, 1.2662823142365867]\n",
+ " B: [-2.630307241578496, -0.5126707368632603, 1.3344949978186418, -1.9684532002212756]\n",
+ " B: [-3.0848917600353407, -2.827901193400985, -0.46541663267058264, -0.5503811129833626]\n",
+ " B: [-2.812675339815945, 2.346626876124383, -1.1757879806725677, 0.14834923648401968]\n",
+ " B: [-1.695817659938434, -0.3817827622891304, -0.19598317768122073, 1.3006267920675472]\n",
+ " B: [-1.7812969367832734, -0.9933833096532803, -0.7211818717528079, -0.8161526113116866]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [6.832501783927461, 0.0, 0.0, 6.758925996589395]\n",
+ " B: [6.832501783927461, 0.0, 0.0, -6.758925996589395]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.0114752465345387, -0.11558780230223581, -0.03776248532804595, -0.09108034372406744]\n",
+ " B: [-1.031154612454516, -0.04425244057817861, -0.0789748074180023, -0.23470095032271823]\n",
+ " B: [-2.2555952063288855, 1.7491237654517413, -0.4233804231771479, -0.9214254203222908]\n",
+ " B: [-2.089561973736715, 0.9235335217807571, 1.3477207222453012, -0.8348676128969853]\n",
+ " B: [-1.3199981586264844, -0.6902187266500668, -0.06216816149242132, -0.5119847340063199]\n",
+ " B: [-1.0105028642371863, -0.09317036739551621, -0.041275823376393385, -0.1035935696630954]\n",
+ " B: [-1.2426376312622325, -0.48126859609618416, 0.05225488689293943, -0.5565952280036419]\n",
+ " B: [-3.704077874674367, -1.2481593542103167, -0.7564139083462295, 3.254247858939119]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [8.775903429741401, 0.0, 0.0, 8.718743086485969]\n",
+ " B: [8.775903429741401, 0.0, 0.0, -8.718743086485969]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.7137666526922533, 1.1358800766324049, 0.08268488211087159, 0.7999598750311686]\n",
+ " B: [-1.1669696745288112, -0.04351472671445914, 0.5992401461010018, 0.028912577361687116]\n",
+ " B: [-3.5481649603318184, 0.4490928742123019, 1.0371640968528058, -3.21124287656006]\n",
+ " B: [-1.276578701414564, -0.08287623449031867, -0.6317118623642547, -0.47299559576203803]\n",
+ " B: [-4.955351547203613, -2.6459981607514886, 0.5026315754882429, 4.037519558961317]\n",
+ " B: [-2.3130557250521284, 1.4242375193555785, -1.5228161303749386, 0.05296516521446809]\n",
+ " B: [-1.4353464814836179, 0.25997106791735547, -0.029309860840599063, -0.9958792586507745]\n",
+ " B: [-1.1425731167759967, -0.4967924161613736, -0.03788284697312998, -0.23923944559576807]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [8.907102929629284, 0.0, 0.0, 8.850789942090511]\n",
+ " B: [8.907102929629284, 0.0, 0.0, -8.850789942090511]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-2.946046511363992, -0.9439001466724447, 2.1873638734369836, 1.4155146927582347]\n",
+ " B: [-3.7848309582649415, -2.22832689875391, -0.18756115269295068, -2.885190709282662]\n",
+ " B: [-1.0159875652570234, 0.04172671107403079, -0.15271016054388648, 0.08467125371989566]\n",
+ " B: [-2.0867601165869685, -1.8155383548303043, -0.021995043965926685, -0.24063350631004576]\n",
+ " B: [-4.34790862339958, 3.6266859724946396, -1.8990793068549607, 1.0700261868843775]\n",
+ " B: [-1.1578951917200673, 0.35622580432348594, 0.23734793715600985, 0.3968506117802061]\n",
+ " B: [-1.4421363377447174, 1.0156020669389267, -0.20020339434090184, -0.0907097523285523]\n",
+ " B: [-1.0326405549212787, -0.052475154574424254, 0.03683724780563263, 0.24947122277854633]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [6.294285658794556, 0.0, 0.0, 6.214340830249562]\n",
+ " B: [6.294285658794556, 0.0, 0.0, -6.214340830249562]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.06844272609547, -0.2848922847204133, 0.15179083391454987, -0.19330232226393051]\n",
+ " B: [-2.114647837734541, -1.6956804594706658, -0.38950327120442063, 0.6668511518515798]\n",
+ " B: [-1.494217345848325, 0.7529614584695401, -0.5432224448027106, -0.6088053006963738]\n",
+ " B: [-1.3783311635115514, 0.9215501628423943, 0.0395584401371469, -0.2213079833313275]\n",
+ " B: [-1.7816982863175768, 0.5393674002906785, 0.38766524831377364, 1.316528482874748]\n",
+ " B: [-1.659172767477475, 0.17135237894801714, -1.2297516401309854, -0.45956886117628726]\n",
+ " B: [-1.55277617510909, -0.23319042207457166, 1.041131562383322, 0.522284545863997]\n",
+ " B: [-1.5392850154950812, -0.17146823428497893, 0.5423312713893238, -1.022679713122405]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [6.965556009635571, 0.0, 0.0, 6.8934005050751415]\n",
+ " B: [6.965556009635571, 0.0, 0.0, -6.8934005050751415]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.0775179795487104, -0.05690318568456522, -0.2919638065794134, 0.269377354945329]\n",
+ " B: [-3.216279237662679, -2.600571207682032, 0.23217633942174215, 1.5898351096286563]\n",
+ " B: [-1.9852997763312183, 1.2696870590322706, -0.6412445999499571, -0.9581833525279955]\n",
+ " B: [-1.9885313318262752, 0.8019078287339996, 1.2060162608136897, 0.9255946577864792]\n",
+ " B: [-1.4288503016026572, 0.2805632486843285, 0.07929023042776773, -0.9780646743628009]\n",
+ " B: [-1.3652585458391595, -0.12810083240879516, 0.7809145290728301, -0.4875382774777694]\n",
+ " B: [-1.8158888731893035, 0.7439741257624499, -1.2924797037897653, -0.2710186621991885]\n",
+ " B: [-1.0534859732711408, -0.3105570364376559, -0.07270924941689365, -0.0900021557927108]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [6.43062328219917, 0.0, 0.0, 6.352394493225528]\n",
+ " B: [6.43062328219917, 0.0, 0.0, -6.352394493225528]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-2.125364788369443, -1.214725294501684, 0.4075454777366224, 1.369497946736289]\n",
+ " B: [-1.1032249572940587, -0.2977536437640783, 0.35819035202044425, 0.012155070594697458]\n",
+ " B: [-2.225917349319406, 1.3039585629995813, -0.8668848261688078, 1.2259326287114942]\n",
+ " B: [-2.717025897056506, -0.9721840017189309, 0.6274004665152297, -2.2457641565164295]\n",
+ " B: [-1.000557419196324, 0.013685057618434337, 0.015873673340379625, 0.025997976872664537]\n",
+ " B: [-1.1652637249339481, 0.20750251779397902, -0.05219673300317853, -0.5586212982154317]\n",
+ " B: [-1.4667402310584912, 0.9160649085291783, -0.533306342231441, -0.16654228923208916]\n",
+ " B: [-1.057152197170161, 0.043451893043520345, 0.0433779317907512, 0.3373441210488047]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [8.156196486154876, 0.0, 0.0, 8.094661272762755]\n",
+ " B: [8.156196486154876, 0.0, 0.0, -8.094661272762755]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.4617318374080812, 0.10404421660552193, -0.19476289320497314, -1.0430254938944576]\n",
+ " B: [-2.745518719911882, 2.0283487429720055, -0.01415841484271091, -1.556751090431481]\n",
+ " B: [-1.193795120882441, -0.223211890483827, 0.20666745479885903, 0.5767250694363129]\n",
+ " B: [-1.0771186742980503, 0.29121400254582763, -0.18584437613704033, 0.20209134345899718]\n",
+ " B: [-2.9756813564276348, 0.7747616688600099, 0.31071107817153876, 2.6754219325851647]\n",
+ " B: [-1.8605025819101852, -0.3441559100391822, 0.5570133470539003, 1.4257498722017754]\n",
+ " B: [-3.3546424693401353, -1.4228183303706836, -0.7768040014609222, -2.7614832317390525]\n",
+ " B: [-1.6434022121313414, -1.2081825000896715, 0.0971778056213486, 0.48127159838273986]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [9.631814348202784, 0.0, 0.0, 9.579762399884718]\n",
+ " B: [9.631814348202784, 0.0, 0.0, -9.579762399884718]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-2.4271747113709625, -0.9216752449526319, 0.35006248470601437, 1.9796838331313595]\n",
+ " B: [-1.926574191117535, -0.6155920425308834, -0.36855158619622796, 1.4821957628346814]\n",
+ " B: [-2.809711053334662, 0.053095841327541846, -2.415282611989454, -1.0286238083410733]\n",
+ " B: [-2.069340346061984, 0.0706218659128716, 1.6880494984307581, 0.6539655271821153]\n",
+ " B: [-1.600891859223819, 0.522182956459051, 1.0136801062226364, -0.5124766796267364]\n",
+ " B: [-2.3653602811566903, 0.7359929823506941, 2.003935313635875, 0.19361520696286152]\n",
+ " B: [-4.134587420071929, 0.11270979705086029, -1.0448676862999513, -3.871738776569513]\n",
+ " B: [-1.9299888340679847, 0.04266384438249662, -1.2270255185096508, 1.1033789344263045]\n",
+ "\n",
+ " Input for ABC Process: 'AB->ABBBBBBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [7.383091586636561, 0.0, 0.0, 7.31505580133628]\n",
+ " B: [7.383091586636561, 0.0, 0.0, -7.31505580133628]\n",
+ " 8 Outgoing Particles:\n",
+ " A: [-1.0026822379766207, 0.02425303574920085, -0.0683120173174935, 0.010813366763733786]\n",
+ " B: [-3.2851307251831745, -2.830568076855887, -0.9156122597784988, 0.9703723169846757]\n",
+ " B: [-2.028220232462834, 1.6810294384373135, 0.4923274291375999, -0.21314558638988076]\n",
+ " B: [-1.5191535227395792, -0.17123543395193966, -1.1293131485074372, -0.05619309939470401]\n",
+ " B: [-1.1059696544762567, 0.2375361941082015, -0.40208228112542477, -0.07124094550113935]\n",
+ " B: [-1.371740281577803, -0.2278482692103191, -0.6986437390927988, -0.5845113276468179]\n",
+ " B: [-1.2867512190171768, 0.6015837296464805, -0.16735271525316733, -0.5155761675681034]\n",
+ " B: [-3.166535299839676, 0.6852493820769491, 2.888988731937221, 0.4594814427522358]\n"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "@time inputs = [gen_process_input(process) for _ in 1:1000]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Internal error: stack overflow in type inference of materialize(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(MetagraphOptimization.compute__8fd7c454_6214_11ee_3616_0f2435e477fe), Tuple{Array{MetagraphOptimization.ABCProcessInput, 1}}}).\n",
+ "This might be caused by recursion over very long tuples or argument lists.\n"
+ ]
+ },
+ {
+ "ename": "LoadError",
+ "evalue": "StackOverflowError:",
+ "output_type": "error",
+ "traceback": [
+ "StackOverflowError:",
+ "",
+ "Stacktrace:",
+ " [1] get",
+ " @ ./iddict.jl:102 [inlined]",
+ " [2] in",
+ " @ ./iddict.jl:189 [inlined]",
+ " [3] haskey",
+ " @ ./abstractdict.jl:17 [inlined]",
+ " [4] findall(sig::Type, table::Core.Compiler.CachedMethodTable{Core.Compiler.InternalMethodTable}; limit::Int64)",
+ " @ Core.Compiler ./compiler/methodtable.jl:120",
+ " [5] findall",
+ " @ ./compiler/methodtable.jl:114 [inlined]",
+ " [6] find_matching_methods(argtypes::Vector{Any}, atype::Any, method_table::Core.Compiler.CachedMethodTable{Core.Compiler.InternalMethodTable}, union_split::Int64, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:336",
+ " [7] abstract_call_gf_by_type(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:80",
+ " [8] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1949",
+ " [9] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
+ " [10] abstract_apply(interp::Core.Compiler.NativeInterpreter, argtypes::Vector{Any}, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1566",
+ " [11] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1855",
+ " [12] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Nothing)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
+ " [13] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1999",
+ " [14] abstract_eval_statement_expr(interp::Core.Compiler.NativeInterpreter, e::Expr, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState, mi::Nothing)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2183",
+ " [15] abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2396",
+ " [16] abstract_eval_basic_statement(interp::Core.Compiler.NativeInterpreter, stmt::Any, pc_vartable::Vector{Core.Compiler.VarState}, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2682",
+ " [17] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2867",
+ " [18] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2955",
+ " [19] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:246",
+ " [20] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:216",
+ " [21] typeinf_edge(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector, caller::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:932",
+ " [22] abstract_call_method(interp::Core.Compiler.NativeInterpreter, method::Method, sig::Any, sparams::Core.SimpleVector, hardlimit::Bool, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:611",
+ " [23] abstract_call_gf_by_type(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:152",
+ "--- the last 16 lines are repeated 413 more times ---",
+ " [6632] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1949",
+ " [6633] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
+ " [6634] abstract_apply(interp::Core.Compiler.NativeInterpreter, argtypes::Vector{Any}, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1566",
+ " [6635] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1855",
+ " [6636] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Nothing)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
+ " [6637] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1999",
+ " [6638] abstract_eval_statement_expr(interp::Core.Compiler.NativeInterpreter, e::Expr, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState, mi::Nothing)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2183",
+ " [6639] abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2396",
+ " [6640] abstract_eval_basic_statement(interp::Core.Compiler.NativeInterpreter, stmt::Any, pc_vartable::Vector{Core.Compiler.VarState}, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2658",
+ " [6641] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2867",
+ " [6642] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2955",
+ " [6643] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:246",
+ " [6644] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:216",
+ " [6645] typeinf_edge(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector, caller::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:932",
+ " [6646] abstract_call_method(interp::Core.Compiler.NativeInterpreter, method::Method, sig::Any, sparams::Core.SimpleVector, hardlimit::Bool, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:611",
+ " [6647] abstract_call_gf_by_type(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:152",
+ " [6648] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1949",
+ " [6649] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Nothing)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
+ " [6650] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:1999",
+ " [6651] abstract_eval_statement_expr(interp::Core.Compiler.NativeInterpreter, e::Expr, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState, mi::Nothing)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2183",
+ " [6652] abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2396",
+ " [6653] abstract_eval_basic_statement(interp::Core.Compiler.NativeInterpreter, stmt::Any, pc_vartable::Vector{Core.Compiler.VarState}, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2682",
+ " [6654] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2867",
+ " [6655] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/abstractinterpretation.jl:2955",
+ " [6656] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:246",
+ " [6657] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:216",
+ " [6658] typeinf",
+ " @ ./compiler/typeinfer.jl:12 [inlined]",
+ " [6659] typeinf_type(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:1079",
+ " [6660] return_type(interp::Core.Compiler.NativeInterpreter, t::DataType)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:1140",
+ " [6661] return_type(f::Any, t::DataType)",
+ " @ Core.Compiler ./compiler/typeinfer.jl:1112",
+ " [6662] combine_eltypes(f::Function, args::Tuple{Vector{ABCProcessInput}})",
+ " @ Base.Broadcast ./broadcast.jl:730",
+ " [6663] copy(bc::Base.Broadcast.Broadcasted{Style}) where Style",
+ " @ Base.Broadcast ./broadcast.jl:895",
+ " [6664] materialize(bc::Base.Broadcast.Broadcasted)",
+ " @ Base.Broadcast ./broadcast.jl:873",
+ " [6665] var\"##core#293\"()",
+ " @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:489",
+ " [6666] var\"##sample#294\"(::Tuple{}, __params::BenchmarkTools.Parameters)",
+ " @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:495",
+ " [6667] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})",
+ " @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:99",
+ " [6668] #invokelatest#2",
+ " @ ./essentials.jl:821 [inlined]",
+ " [6669] invokelatest",
+ " @ ./essentials.jl:816 [inlined]",
+ " [6670] #run_result#45",
+ " @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]",
+ " [6671] run_result",
+ " @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]",
+ " [6672] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})",
+ " @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117",
+ " [6673] run (repeats 2 times)",
+ " @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117 [inlined]",
+ " [6674] #warmup#54",
+ " @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:169 [inlined]",
+ " [6675] warmup(item::BenchmarkTools.Benchmark)",
+ " @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:168"
+ ]
+ }
+ ],
+ "source": [
+ "using BenchmarkTools\n",
+ "@benchmark compute_AB_AB7.(inputs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Julia 1.9.3",
+ "language": "julia",
+ "name": "julia-1.9"
+ },
+ "language_info": {
+ "file_extension": ".jl",
+ "mimetype": "application/julia",
+ "name": "julia",
+ "version": "1.9.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/abc_model_showcase.ipynb b/notebooks/abc_model_showcase.ipynb
new file mode 100644
index 0000000..cd72498
--- /dev/null
+++ b/notebooks/abc_model_showcase.ipynb
@@ -0,0 +1,409 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "20768e45-df62-4638-ba33-b0ccf239f1aa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "using Revise\n",
+ "using MetagraphOptimization\n",
+ "using BenchmarkTools"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ff5f4a49",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found 1 NUMA nodes\n",
+ "CUDA is non-functional\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "Machine(MetagraphOptimization.AbstractDevice[MetagraphOptimization.NumaNode(0x0000, 0x0001, MetagraphOptimization.LocalVariables(), -1.0, UUID(\"a89974f6-6212-11ee-0866-0f591a3b69ea\"))], [-1.0;;])"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Get our machine's info\n",
+ "machine = get_machine_info()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "9df482a4-ca44-44c5-9ea7-7a2977d529be",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ABCModel()"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Create a model identifier\n",
+ "model = ABCModel()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "30b16872-07f7-4d47-8ff8-8c3a849c9d4e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ABC Process: 'AB->ABBB'"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Create a process in our model\n",
+ "process_str = \"AB->ABBB\"\n",
+ "process = parse_process(process_str, model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "043bd9e2-f89a-4362-885a-8c89d4cdd76f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total: 280, ComputeTaskP"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "Graph:\n",
+ " Nodes: \n",
+ " Edges: 385\n",
+ " Total Compute Effort: 1075.0\n",
+ " Total Data Transfer: 10944.0\n",
+ " Total Compute Intensity: 0.09822733918128655\n"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ ": 6, ComputeTaskU: 6, \n",
+ " ComputeTaskV: 64, ComputeTaskSum: 1, ComputeTaskS2: 24, \n",
+ " ComputeTaskS1: 36, DataTask: 143"
+ ]
+ }
+ ],
+ "source": [
+ "# Read the graph (of the same process) from a file\n",
+ "graph = parse_dag(\"../input/$process_str.txt\", model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "02f01ad3-fd10-48d5-a0e0-c03dc83c80a4",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Input for ABC Process: 'AB->ABBB':\n",
+ " 2 Incoming particles:\n",
+ " A: [5.77986599979293, 0.0, 0.0, 5.692701553354288]\n",
+ " B: [5.77986599979293, 0.0, 0.0, -5.692701553354288]\n",
+ " 4 Outgoing Particles:\n",
+ " A: [-3.8835293143673746, -1.4292027910861678, 2.8576090179942106, 1.968057422378813]\n",
+ " B: [-1.1554024905063585, -0.1464656500147254, -0.2082400426692148, 0.5197487980391896]\n",
+ " B: [-2.849749730594798, -1.0177034035100576, -2.464951858896686, -0.09677625137882176]\n",
+ " B: [-3.6710504641173287, 2.5933718446109513, -0.1844171164283155, -2.391029969039186]\n"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Generate some random input data for our process\n",
+ "input_data = gen_process_input(process)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "083fb1be-ce2a-47f9-afb9-60a6fdfaed0b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "compute__af4450a2_6212_11ee_2601_cde7cf2aedc1 (generic function with 1 method)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Get the function computing the result of the process from a ProcessInput\n",
+ "AB_AB3_compute = get_compute_function(graph, process, machine)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "a40c9500-8f79-4f04-b3c5-59b72a6b7ba9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "-1.8924431710735022e-13"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Actually compute a result using the generated function and the input data\n",
+ "result = AB_AB3_compute(input_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "80c70010",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "include(\"../examples/profiling_utilities.jl\")\n",
+ "\n",
+    "# We can also mutate the graph by applying some operations to it\n",
+ "reduce_all!(graph)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "5b192b44",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The result should be the same as before (we can use execute to save having to generate the function ourselves)\n",
+ "@assert result ≈ execute(graph, process, machine, input_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "9b2f4a3f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1000-element Vector{Float64}:\n",
+ " -2.1491995259940396e-11\n",
+ " -1.04995646459455e-11\n",
+ " 5.821760691187782e-15\n",
+ " -6.556969485683705e-14\n",
+ " -1.3588086164373753e-14\n",
+ " -1.8789662441593694e-13\n",
+ " -2.131973301835892e-13\n",
+ " -5.3359759072004825e-12\n",
+ " -9.053914191490223e-13\n",
+ " -5.61107901706923e-13\n",
+ " -5.063492275603428e-11\n",
+ " 2.9168508985811397e-15\n",
+ " -1.6420151378194157e-13\n",
+ " ⋮\n",
+ " 1.0931677247833436e-13\n",
+ " -7.704755306462797e-16\n",
+ " -1.8385907037491397e-12\n",
+ " -6.036215596560059e-14\n",
+ " -9.98872401400362e-12\n",
+ " 3.4861755637292935e-13\n",
+ " -1.1051119822969222e-10\n",
+ " -2.496572513216201e-12\n",
+ " -3.8682427847201926e-11\n",
+ " 7.904149696653438e-15\n",
+ " -7.606811743178716e-11\n",
+ " -5.100594937480292e-13"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Now we can generate a function and use it on lots of inputs\n",
+ "inputs = [gen_process_input(process) for _ in 1:1000]\n",
+ "AB_AB3_reduced_compute = get_compute_function(graph, process, machine)\n",
+ "\n",
+ "results = AB_AB3_reduced_compute.(inputs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "d43e4ff0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BenchmarkTools.Trial: 879 samples with 1 evaluation.\n",
+ " Range \u001b[90m(\u001b[39m\u001b[36m\u001b[1mmin\u001b[22m\u001b[39m … \u001b[35mmax\u001b[39m\u001b[90m): \u001b[39m\u001b[36m\u001b[1m4.567 ms\u001b[22m\u001b[39m … \u001b[35m14.334 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmin … max\u001b[90m): \u001b[39m0.00% … 54.51%\n",
+ " Time \u001b[90m(\u001b[39m\u001b[34m\u001b[1mmedian\u001b[22m\u001b[39m\u001b[90m): \u001b[39m\u001b[34m\u001b[1m4.998 ms \u001b[22m\u001b[39m\u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmedian\u001b[90m): \u001b[39m0.00%\n",
+ " Time \u001b[90m(\u001b[39m\u001b[32m\u001b[1mmean\u001b[22m\u001b[39m ± \u001b[32mσ\u001b[39m\u001b[90m): \u001b[39m\u001b[32m\u001b[1m5.686 ms\u001b[22m\u001b[39m ± \u001b[32m 1.414 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmean ± σ\u001b[90m): \u001b[39m9.09% ± 14.49%\n",
+ "\n",
+ " \u001b[39m \u001b[39m \u001b[39m▃\u001b[39m▇\u001b[39m█\u001b[34m▅\u001b[39m\u001b[39m▄\u001b[39m▁\u001b[39m \u001b[39m▁\u001b[39m \u001b[39m \u001b[32m \u001b[39m\u001b[39m \u001b[39m▁\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m▁\u001b[39m▁\u001b[39m \u001b[39m▁\u001b[39m▁\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \n",
+ " \u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[34m█\u001b[39m\u001b[39m█\u001b[39m█\u001b[39m▇\u001b[39m█\u001b[39m▇\u001b[39m▇\u001b[32m█\u001b[39m\u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m▆\u001b[39m▆\u001b[39m▇\u001b[39m▅\u001b[39m▅\u001b[39m▄\u001b[39m▁\u001b[39m▄\u001b[39m▅\u001b[39m▅\u001b[39m▆\u001b[39m▅\u001b[39m▅\u001b[39m▄\u001b[39m▁\u001b[39m▄\u001b[39m▄\u001b[39m▁\u001b[39m▅\u001b[39m▄\u001b[39m▄\u001b[39m▆\u001b[39m▇\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m▄\u001b[39m▅\u001b[39m▆\u001b[39m▅\u001b[39m▅\u001b[39m▅\u001b[39m▁\u001b[39m▅\u001b[39m▄\u001b[39m▄\u001b[39m▅\u001b[39m▁\u001b[39m▄\u001b[39m \u001b[39m▇\n",
+ " 4.57 ms\u001b[90m \u001b[39m\u001b[90mHistogram: \u001b[39m\u001b[90m\u001b[1mlog(\u001b[22m\u001b[39m\u001b[90mfrequency\u001b[39m\u001b[90m\u001b[1m)\u001b[22m\u001b[39m\u001b[90m by time\u001b[39m 10 ms \u001b[0m\u001b[1m<\u001b[22m\n",
+ "\n",
+ " Memory estimate\u001b[90m: \u001b[39m\u001b[33m6.17 MiB\u001b[39m, allocs estimate\u001b[90m: \u001b[39m\u001b[33m143006\u001b[39m."
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "@benchmark results = AB_AB3_compute.($inputs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "e18d9546",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BenchmarkTools.Trial: 1089 samples with 1 evaluation.\n",
+ " Range \u001b[90m(\u001b[39m\u001b[36m\u001b[1mmin\u001b[22m\u001b[39m … \u001b[35mmax\u001b[39m\u001b[90m): \u001b[39m\u001b[36m\u001b[1m3.637 ms\u001b[22m\u001b[39m … \u001b[35m10.921 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmin … max\u001b[90m): \u001b[39m 0.00% … 59.52%\n",
+ " Time \u001b[90m(\u001b[39m\u001b[34m\u001b[1mmedian\u001b[22m\u001b[39m\u001b[90m): \u001b[39m\u001b[34m\u001b[1m4.098 ms \u001b[22m\u001b[39m\u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmedian\u001b[90m): \u001b[39m 0.00%\n",
+ " Time \u001b[90m(\u001b[39m\u001b[32m\u001b[1mmean\u001b[22m\u001b[39m ± \u001b[32mσ\u001b[39m\u001b[90m): \u001b[39m\u001b[32m\u001b[1m4.587 ms\u001b[22m\u001b[39m ± \u001b[32m 1.334 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmean ± σ\u001b[90m): \u001b[39m10.21% ± 15.77%\n",
+ "\n",
+ " \u001b[39m \u001b[39m▂\u001b[39m▆\u001b[39m▆\u001b[39m▇\u001b[34m█\u001b[39m\u001b[39m▆\u001b[39m▂\u001b[39m \u001b[39m \u001b[39m \u001b[32m \u001b[39m\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m▁\u001b[39m▁\u001b[39m \u001b[39m▁\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \n",
+ " \u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[34m█\u001b[39m\u001b[39m█\u001b[39m█\u001b[39m▇\u001b[39m█\u001b[39m▇\u001b[32m▆\u001b[39m\u001b[39m▅\u001b[39m▇\u001b[39m▅\u001b[39m▅\u001b[39m▅\u001b[39m▄\u001b[39m▆\u001b[39m▄\u001b[39m▅\u001b[39m▅\u001b[39m▅\u001b[39m▅\u001b[39m▆\u001b[39m▄\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▄\u001b[39m▆\u001b[39m▆\u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m▇\u001b[39m█\u001b[39m█\u001b[39m▆\u001b[39m▆\u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m▇\u001b[39m▆\u001b[39m▄\u001b[39m▄\u001b[39m \u001b[39m█\n",
+ " 3.64 ms\u001b[90m \u001b[39m\u001b[90mHistogram: \u001b[39m\u001b[90m\u001b[1mlog(\u001b[22m\u001b[39m\u001b[90mfrequency\u001b[39m\u001b[90m\u001b[1m)\u001b[22m\u001b[39m\u001b[90m by time\u001b[39m 8.78 ms \u001b[0m\u001b[1m<\u001b[22m\n",
+ "\n",
+ " Memory estimate\u001b[90m: \u001b[39m\u001b[33m5.26 MiB\u001b[39m, allocs estimate\u001b[90m: \u001b[39m\u001b[33m123006\u001b[39m."
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "@benchmark results = AB_AB3_reduced_compute.($inputs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "13efed12-3547-400b-a7a2-5dfae9a973a2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set a different caching strategy\n",
+ "MetagraphOptimization.set_cache_strategy(machine.devices[1], MetagraphOptimization.Dictionary())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "ef62716b-a219-4f6e-9150-f984d3734839",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BenchmarkTools.Trial: 331 samples with 1 evaluation.\n",
+ " Range \u001b[90m(\u001b[39m\u001b[36m\u001b[1mmin\u001b[22m\u001b[39m … \u001b[35mmax\u001b[39m\u001b[90m): \u001b[39m\u001b[36m\u001b[1m12.148 ms\u001b[22m\u001b[39m … \u001b[35m24.164 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmin … max\u001b[90m): \u001b[39m 0.00% … 13.35%\n",
+ " Time \u001b[90m(\u001b[39m\u001b[34m\u001b[1mmedian\u001b[22m\u001b[39m\u001b[90m): \u001b[39m\u001b[34m\u001b[1m15.412 ms \u001b[22m\u001b[39m\u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmedian\u001b[90m): \u001b[39m17.47%\n",
+ " Time \u001b[90m(\u001b[39m\u001b[32m\u001b[1mmean\u001b[22m\u001b[39m ± \u001b[32mσ\u001b[39m\u001b[90m): \u001b[39m\u001b[32m\u001b[1m15.117 ms\u001b[22m\u001b[39m ± \u001b[32m 2.194 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmean ± σ\u001b[90m): \u001b[39m12.31% ± 8.95%\n",
+ "\n",
+ " \u001b[39m \u001b[39m▄\u001b[39m█\u001b[39m▄\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[32m▄\u001b[39m\u001b[39m▄\u001b[34m▂\u001b[39m\u001b[39m \u001b[39m▂\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \n",
+ " \u001b[39m▅\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m▅\u001b[39m▃\u001b[39m▃\u001b[39m▂\u001b[39m▃\u001b[39m▂\u001b[39m▅\u001b[39m▂\u001b[39m▃\u001b[39m▁\u001b[39m▂\u001b[39m▂\u001b[39m▂\u001b[39m▃\u001b[39m▂\u001b[39m▃\u001b[32m█\u001b[39m\u001b[39m█\u001b[34m█\u001b[39m\u001b[39m▇\u001b[39m█\u001b[39m▄\u001b[39m▆\u001b[39m▄\u001b[39m▆\u001b[39m▄\u001b[39m▄\u001b[39m▆\u001b[39m▅\u001b[39m▄\u001b[39m▃\u001b[39m▄\u001b[39m▂\u001b[39m▂\u001b[39m▃\u001b[39m▃\u001b[39m▄\u001b[39m▃\u001b[39m▂\u001b[39m▂\u001b[39m▁\u001b[39m▂\u001b[39m▂\u001b[39m▃\u001b[39m▂\u001b[39m▂\u001b[39m▁\u001b[39m▂\u001b[39m▁\u001b[39m▃\u001b[39m▃\u001b[39m▂\u001b[39m▂\u001b[39m▁\u001b[39m▂\u001b[39m \u001b[39m▃\n",
+ " 12.1 ms\u001b[90m Histogram: frequency by time\u001b[39m 21 ms \u001b[0m\u001b[1m<\u001b[22m\n",
+ "\n",
+ " Memory estimate\u001b[90m: \u001b[39m\u001b[33m27.46 MiB\u001b[39m, allocs estimate\u001b[90m: \u001b[39m\u001b[33m118013\u001b[39m."
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# ... and bench again\n",
+ "AB_AB3_reduced_dict_compute = get_compute_function(graph, process, machine)\n",
+ "@benchmark results = AB_AB3_reduced_dict_compute.($inputs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5461ffd4-6a0e-4f1f-b1f1-3a2854a8ae88",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Julia 1.9.3",
+ "language": "julia",
+ "name": "julia-1.9"
+ },
+ "language_info": {
+ "file_extension": ".jl",
+ "mimetype": "application/julia",
+ "name": "julia",
+ "version": "1.9.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/profiling.ipynb b/notebooks/profiling.ipynb
new file mode 100644
index 0000000..f782184
--- /dev/null
+++ b/notebooks/profiling.ipynb
@@ -0,0 +1,70 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "using Revise; using MetagraphOptimization; using BenchmarkTools; using ProfileView\n",
+ "using Base.Threads\n",
+ "nthreads()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ABCModel()\n",
+ "process_str = \"AB->ABBBBB\"\n",
+ "process = parse_process(process_str, model)\n",
+ "graph = parse_dag(\"../input/$process_str.txt\", model)\n",
+ "print(graph)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "include(\"../examples/profiling_utilities.jl\")\n",
+ "@ProfileView.profview reduce_all!(graph)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "@ProfileView.profview comp_func = get_compute_function(graph, process, get_machine_info())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Julia 1.9.3",
+ "language": "julia",
+ "name": "julia-1.9"
+ },
+ "language_info": {
+ "file_extension": ".jl",
+ "mimetype": "application/julia",
+ "name": "julia",
+ "version": "1.9.3"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/scripts/bench_threads.fish b/scripts/bench_threads.fish
index 28df8c3..5e1ce95 100755
--- a/scripts/bench_threads.fish
+++ b/scripts/bench_threads.fish
@@ -6,20 +6,20 @@ julia --project=./examples -t 4 -e 'import Pkg; Pkg.instantiate()'
#for i in $(seq $minthreads $maxthreads)
# printf "(AB->AB, $i) "
-# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->AB.txt"))'
+# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_dag("input/AB->AB.txt", ABCModel()))'
#end
#for i in $(seq $minthreads $maxthreads)
# printf "(AB->ABBB, $i) "
-# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBB.txt"))'
+# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_dag("input/AB->ABBB.txt", ABCModel()))'
#end
#for i in $(seq $minthreads $maxthreads)
# printf "(AB->ABBBBB, $i) "
-# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBB.txt"))'
+# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_dag("input/AB->ABBBBB.txt", ABCModel()))'
#end
for i in $(seq $minthreads $maxthreads)
printf "(AB->ABBBBBBB, $i) "
- julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBBBB.txt"))'
+    julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_dag("input/AB->ABBBBBBB.txt", ABCModel()))'
end
diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl
index 093536a..ac0b1ef 100644
--- a/src/MetagraphOptimization.jl
+++ b/src/MetagraphOptimization.jl
@@ -29,7 +29,7 @@ export children
export compute
export get_properties
export get_exit_node
-export is_valid
+export is_valid, is_scheduled
export Operation
export AppliedOperation
@@ -42,7 +42,6 @@ export can_pop
export reset_graph!
export get_operations
-export parse_abc
export ComputeTaskP
export ComputeTaskS1
export ComputeTaskS2
@@ -51,9 +50,15 @@ export ComputeTaskU
export ComputeTaskSum
export execute
-export gen_particles
+export parse_dag, parse_process
+export gen_process_input
+export get_compute_function
export ParticleValue
-export Particle
+export ParticleA, ParticleB, ParticleC
+export ABCProcessDescription, ABCProcessInput, ABCModel
+
+export Machine
+export get_machine_info
export ==, in, show, isempty, delete!, length
@@ -72,6 +77,7 @@ import Base.insert!
import Base.collect
+include("devices/interface.jl")
include("task/type.jl")
include("node/type.jl")
include("diff/type.jl")
@@ -111,15 +117,34 @@ include("properties/utility.jl")
include("task/create.jl")
include("task/compare.jl")
+include("task/compute.jl")
include("task/print.jl")
include("task/properties.jl")
+include("models/interface.jl")
+include("models/print.jl")
+
include("models/abc/types.jl")
include("models/abc/particle.jl")
include("models/abc/compute.jl")
include("models/abc/create.jl")
include("models/abc/properties.jl")
include("models/abc/parse.jl")
+include("models/abc/print.jl")
+
+include("devices/measure.jl")
+include("devices/detect.jl")
+include("devices/impl.jl")
+
+include("devices/numa/impl.jl")
+include("devices/cuda/impl.jl")
+# cannot currently use AMDGPU because of an incompatibility with the newest ROCm drivers
+# include("devices/rocm/impl.jl")
+# oneAPI also seems broken for now
+# include("devices/oneapi/impl.jl")
+
+include("scheduler/interface.jl")
+include("scheduler/greedy.jl")
include("code_gen/main.jl")
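For orientation, a minimal sketch of how the exports and includes above fit together end to end, mirroring the example notebook; the input file path `input/AB->ABBB.txt` is an assumption and must exist in the repository:

```julia
using MetagraphOptimization

machine = get_machine_info()                     # detect NUMA nodes and (optionally) GPUs
model   = ABCModel()                             # model identifier
process = parse_process("AB->ABBB", model)       # process description in that model
graph   = parse_dag("input/AB->ABBB.txt", model) # load the DAG of the same process

input   = gen_process_input(process)             # random, momentum-conserving input
compute = get_compute_function(graph, process, machine)
result  = compute(input)
```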
diff --git a/src/code_gen/main.jl b/src/code_gen/main.jl
index 3bbb5a3..13406bb 100644
--- a/src/code_gen/main.jl
+++ b/src/code_gen/main.jl
@@ -1,126 +1,157 @@
-using DataStructures
-
"""
gen_code(graph::DAG)
-Generate the code for a given graph. The return value is a tuple of:
+Generate the code for a given graph. The return value is a named tuple of:
- `code::Expr`: The julia expression containing the code for the whole graph.
-- `inputSymbols::Dict{String, Symbol}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
+- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
- `outputSymbol::Symbol`: The symbol of the final calculated value
See also: [`execute`](@ref)
"""
-function gen_code(graph::DAG)
- code = Vector{Expr}()
- sizehint!(code, length(graph.nodes))
+function gen_code(graph::DAG, machine::Machine)
+ sched = schedule_dag(GreedyScheduler(), graph, machine)
- nodeQueue = PriorityQueue{Node, Int}()
- inputSyms = Dict{String, Symbol}()
+ codeAcc = Vector{Expr}()
+ sizehint!(codeAcc, length(graph.nodes))
- # use a priority equal to the number of unseen children -> 0 are nodes that can be added
+ for node in sched
+ # TODO: this is kind of ugly, should init nodes be scheduled differently from the rest?
+ if (node isa DataTaskNode && length(node.children) == 0)
+ push!(codeAcc, get_init_expression(node, entry_device(machine)))
+ continue
+ end
+ push!(codeAcc, get_expression(node))
+ end
+
+ # get inSymbols
+ inputSyms = Dict{String, Vector{Symbol}}()
for node in get_entry_nodes(graph)
- enqueue!(nodeQueue, node => 0)
- push!(inputSyms, node.name => Symbol("data_$(to_var_name(node.id))_in"))
+ if !haskey(inputSyms, node.name)
+ inputSyms[node.name] = Vector{Symbol}()
+ end
+
+ push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
end
- node = nothing
- while !isempty(nodeQueue)
- @assert peek(nodeQueue)[2] == 0
- node = dequeue!(nodeQueue)
+ # get outSymbol
+ outSym = Symbol(to_var_name(get_exit_node(graph).id))
- push!(code, get_expression(node))
- for parent in node.parents
- # reduce the priority of all parents by one
- if (!haskey(nodeQueue, parent))
- enqueue!(nodeQueue, parent => length(parent.children) - 1)
- else
- nodeQueue[parent] = nodeQueue[parent] - 1
- end
+ return (code = Expr(:block, codeAcc...), inputSymbols = inputSyms, outputSymbol = outSym)
+end
+
+function gen_cache_init_code(machine::Machine)
+ initializeCaches = Vector{Expr}()
+
+ for device in machine.devices
+ push!(initializeCaches, gen_cache_init_code(device))
+ end
+
+ return Expr(:block, initializeCaches...)
+end
+
+function gen_input_assignment_code(
+ inputSymbols::Dict{String, Vector{Symbol}},
+ processDescription::AbstractProcessDescription,
+ machine::Machine,
+ processInputSymbol::Symbol = :input,
+)
+ @assert length(inputSymbols) >=
+ sum(values(in_particles(processDescription))) + sum(values(out_particles(processDescription))) "Number of input Symbols is smaller than the number of particles in the process description"
+
+ assignInputs = Vector{Expr}()
+ for (name, symbols) in inputSymbols
+ type = type_from_name(name)
+ index = parse(Int, name[2:end])
+
+ p = nothing
+
+ if (index > in_particles(processDescription)[type])
+ index -= in_particles(processDescription)[type]
+ @assert index <= out_particles(processDescription)[type] "Too few particles of type $type in input particles for this process"
+
+ p = "filter(x -> typeof(x) <: $type, out_particles($(processInputSymbol)))[$(index)]"
+ else
+ p = "filter(x -> typeof(x) <: $type, in_particles($(processInputSymbol)))[$(index)]"
+ end
+
+ for symbol in symbols
+ # TODO: how to get the "default" cpu device?
+ device = entry_device(machine)
+ evalExpr = eval(gen_access_expr(device, symbol))
+ push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue($p, 1.0)"))
end
end
- # node is now the last node we looked at -> the output node
- outSym = Symbol("data_$(to_var_name(node.id))")
+ return Expr(:block, assignInputs...)
+end
- return (
- code = Expr(:block, code...),
- inputSymbols = inputSyms,
- outputSymbol = outSym,
+"""
+ get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
+
+Return a function of signature `compute_(input::AbstractProcessInput)`, which will return the result of the DAG computation on the given input.
+"""
+function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
+ (code, inputSymbols, outputSymbol) = gen_code(graph, machine)
+
+ initCaches = gen_cache_init_code(machine)
+ assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
+
+ functionId = to_var_name(UUIDs.uuid1(rng[1]))
+ resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
+ expr = Meta.parse(
+ "function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
)
+ func = eval(expr)
+
+ return func
end
"""
- execute(generated_code, input::Dict{ParticleType, Vector{Particle}})
+ execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
-Execute the given `generated_code` (as returned by [`gen_code`](@ref)) on the given input particles.
+Execute the code of the given `graph` on the given input particles.
+
+This is essentially shorthand for
+ ```julia
+    compute_graph = get_compute_function(graph, process, machine)
+    result = compute_graph(input)
+ ```
+
+See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
-function execute(generated_code, input::Dict{ParticleType, Vector{Particle}})
- (code, inputSymbols, outputSymbol) = generated_code
+function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
+ (code, inputSymbols, outputSymbol) = gen_code(graph, machine)
- assignInputs = Vector{Expr}()
- for (name, symbol) in inputSymbols
- type = nothing
- if startswith(name, "A")
- type = A
- elseif startswith(name, "B")
- type = B
- else
- type = C
+ initCaches = gen_cache_init_code(machine)
+ assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
+
+
+ functionId = to_var_name(UUIDs.uuid1(rng[1]))
+ resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
+ expr = Meta.parse(
+ "function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
+ )
+ func = eval(expr)
+
+ result = 0
+ try
+ result = @eval $func($input)
+ catch e
+ println("Error while evaluating: $e")
+
+ # if we find a uuid in the exception we can color it in so it's easier to spot
+ uuidRegex = r"[0-9a-f]{8}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{12}"
+ m = match(uuidRegex, string(e))
+
+ functionStr = string(expr)
+ if (isa(m, RegexMatch))
+ functionStr = replace(functionStr, m.match => "\033[31m$(m.match)\033[0m")
end
- index = parse(Int, name[2:end])
- push!(
- assignInputs,
- Meta.parse(
- "$(symbol) = ParticleValue(Particle($(input[type][index]).P0, $(input[type][index]).P1, $(input[type][index]).P2, $(input[type][index]).P3, $(type)), 1.0)",
- ),
- )
+ println("Function:\n$functionStr")
+ @assert false
end
- assignInputs = Expr(:block, assignInputs...)
- eval(assignInputs)
- eval(code)
-
- eval(Meta.parse("result = $outputSymbol"))
- return result
-end
-
-"""
- execute(graph::DAG, input::Dict{ParticleType, Vector{Particle}})
-
-Execute the given `generated_code` (as returned by [`gen_code`](@ref)) on the given input particles.
-The input particles should be sorted correctly into the dictionary to their according [`ParticleType`](@ref)s.
-
-See also: [`gen_particles`](@ref)
-"""
-function execute(graph::DAG, input::Dict{ParticleType, Vector{Particle}})
- (code, inputSymbols, outputSymbol) = gen_code(graph)
-
- assignInputs = Vector{Expr}()
- for (name, symbol) in inputSymbols
- type = nothing
- if startswith(name, "A")
- type = A
- elseif startswith(name, "B")
- type = B
- else
- type = C
- end
- index = parse(Int, name[2:end])
-
- push!(
- assignInputs,
- Meta.parse(
- "$(symbol) = ParticleValue(Particle($(input[type][index]).P0, $(input[type][index]).P1, $(input[type][index]).P2, $(input[type][index]).P3, $(type)), 1.0)",
- ),
- )
- end
-
- assignInputs = Expr(:block, assignInputs...)
- eval(assignInputs)
- eval(code)
-
- eval(Meta.parse("result = $outputSymbol"))
return result
end
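As the docstring above notes, `execute` regenerates and evaluates the graph code on every call, while `get_compute_function` builds the function once. A short sketch of the equivalence, reusing the placeholder names `graph`, `process`, `machine` and `input` from the sketch further up (they are not defined by this patch):

```julia
# one-off: generates and evaluates the code internally on each call
r1 = execute(graph, process, machine, input)

# repeated use: generate the function once, then broadcast it over many inputs
compute = get_compute_function(graph, process, machine)
inputs  = [gen_process_input(process) for _ in 1:1000]
results = compute.(inputs)

@assert r1 ≈ compute(input)
```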
diff --git a/src/devices/cuda/impl.jl b/src/devices/cuda/impl.jl
new file mode 100644
index 0000000..30fe231
--- /dev/null
+++ b/src/devices/cuda/impl.jl
@@ -0,0 +1,53 @@
+using CUDA
+
+"""
+ CUDAGPU <: AbstractGPU
+
+Representation of a specific CUDA GPU that code can run on. Implements the [`AbstractDevice`](@ref) interface.
+"""
+mutable struct CUDAGPU <: AbstractGPU
+ device::Any # TODO: what's the cuda device type?
+ cacheStrategy::CacheStrategy
+ FLOPS::Float64
+end
+
+push!(DEVICE_TYPES, CUDAGPU)
+
+CACHE_STRATEGIES[CUDAGPU] = [LocalVariables()]
+
+default_strategy(::Type{T}) where {T <: CUDAGPU} = LocalVariables()
+
+function measure_device!(device::CUDAGPU; verbose::Bool)
+ if verbose
+ println("Measuring CUDA GPU $(device.device)")
+ end
+
+ # TODO implement
+ return nothing
+end
+
+"""
+ get_devices(deviceType::Type{T}; verbose::Bool) where {T <: CUDAGPU}
+
+Return a Vector of [`CUDAGPU`](@ref)s available on the current machine. If `verbose` is true, print some additional information.
+"""
+function get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: CUDAGPU}
+ devices = Vector{AbstractDevice}()
+
+ if !CUDA.functional()
+ if verbose
+ println("CUDA is non-functional")
+ end
+ return devices
+ end
+
+ CUDADevices = CUDA.devices()
+ if verbose
+ println("Found $(length(CUDADevices)) CUDA devices")
+ end
+ for device in CUDADevices
+ push!(devices, CUDAGPU(device, default_strategy(CUDAGPU), -1))
+ end
+
+ return devices
+end
diff --git a/src/devices/detect.jl b/src/devices/detect.jl
new file mode 100644
index 0000000..6de0f48
--- /dev/null
+++ b/src/devices/detect.jl
@@ -0,0 +1,23 @@
+
+"""
+    get_machine_info(; verbose::Bool)
+
+Return the [`Machine`](@ref) that the program is currently running on. The parameter `verbose` defaults to true when running interactively.
+"""
+function get_machine_info(; verbose::Bool = Base.isinteractive())
+ devices = Vector{AbstractDevice}()
+
+ for device in device_types()
+ devs = get_devices(device, verbose = verbose)
+ for dev in devs
+ push!(devices, dev)
+ end
+ end
+
+ noDevices = length(devices)
+ @assert noDevices > 0 "No devices were found, but at least one NUMA node should always be available!"
+
+ transferRates = Matrix{Float64}(undef, noDevices, noDevices)
+ fill!(transferRates, -1)
+ return Machine(devices, transferRates)
+end
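A brief usage sketch of `get_machine_info`; as the code above shows, the transfer-rate matrix is currently only filled with `-1.0` placeholders:

```julia
using MetagraphOptimization

machine = get_machine_info(verbose = true)   # prints the detected NUMA nodes / GPUs
length(machine.devices)                      # at least one NumaNode
MetagraphOptimization.entry_device(machine)  # devices[1], receives inputs and holds the result
machine.transferRates                        # all -1.0 until measured
```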
diff --git a/src/devices/impl.jl b/src/devices/impl.jl
new file mode 100644
index 0000000..aaa60aa
--- /dev/null
+++ b/src/devices/impl.jl
@@ -0,0 +1,52 @@
+"""
+ device_types()
+
+Return a vector of available and implemented device types.
+
+See also: [`DEVICE_TYPES`](@ref)
+"""
+function device_types()
+ return DEVICE_TYPES
+end
+
+"""
+ entry_device(machine::Machine)
+
+Return the "entry" device, i.e., the device that starts CPU threads and GPU kernels, and takes input values and returns the output value.
+"""
+function entry_device(machine::Machine)
+ return machine.devices[1]
+end
+
+"""
+ strategies(t::Type{T}) where {T <: AbstractDevice}
+
+Return a vector of available [`CacheStrategy`](@ref)s for the given [`AbstractDevice`](@ref).
+The caching strategies are used in code generation.
+"""
+function strategies(t::Type{T}) where {T <: AbstractDevice}
+ if !haskey(CACHE_STRATEGIES, t)
+ error("Trying to get strategies for $T, but it has no strategies defined!")
+ end
+
+ return CACHE_STRATEGIES[t]
+end
+
+"""
+ cache_strategy(device::AbstractDevice)
+
+Returns the cache strategy set for this device.
+"""
+function cache_strategy(device::AbstractDevice)
+ return device.cacheStrategy
+end
+
+"""
+ set_cache_strategy(device::AbstractDevice, cacheStrategy::CacheStrategy)
+
+Sets the device's cache strategy. After this call, [`cache_strategy`](@ref) should return `cacheStrategy` on the given device.
+"""
+function set_cache_strategy(device::AbstractDevice, cacheStrategy::CacheStrategy)
+ device.cacheStrategy = cacheStrategy
+ return nothing
+end
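A sketch of querying and switching a device's cache strategy with these helpers, as the example notebook does before re-benchmarking; the names are accessed through the module here, the same way the notebook qualifies them:

```julia
using MetagraphOptimization
using MetagraphOptimization: strategies, cache_strategy, set_cache_strategy, Dictionary, NumaNode

machine = get_machine_info()
dev = machine.devices[1]

strategies(NumaNode)                  # available strategies for this device type
cache_strategy(dev)                   # the currently set strategy
set_cache_strategy(dev, Dictionary()) # switch to the dictionary-based cache
cache_strategy(dev)                   # now Dictionary()
```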
diff --git a/src/devices/interface.jl b/src/devices/interface.jl
new file mode 100644
index 0000000..bb65297
--- /dev/null
+++ b/src/devices/interface.jl
@@ -0,0 +1,108 @@
+"""
+ AbstractDevice
+
+Abstract base type for every device, like GPUs, CPUs or any other compute devices.
+Every implementation needs to implement various functions and needs a member `cacheStrategy`.
+"""
+abstract type AbstractDevice end
+
+abstract type AbstractCPU <: AbstractDevice end
+
+abstract type AbstractGPU <: AbstractDevice end
+
+"""
+ Machine
+
+A representation of a machine to execute on. Contains information about its architecture (CPUs, GPUs, maybe more). This representation can be used to make a more accurate cost prediction of a [`DAG`](@ref) state.
+
+See also: [`Scheduler`](@ref)
+"""
+struct Machine
+ devices::Vector{AbstractDevice}
+
+ transferRates::Matrix{Float64}
+end
+
+"""
+ CacheStrategy
+
+Abstract base type for caching strategies.
+
+See also: [`strategies`](@ref)
+"""
+abstract type CacheStrategy end
+
+"""
+ LocalVariables <: CacheStrategy
+
+A caching strategy relying solely on local variables for every input and output.
+
+Implements the [`CacheStrategy`](@ref) interface.
+"""
+struct LocalVariables <: CacheStrategy end
+
+"""
+ Dictionary <: CacheStrategy
+
+A caching strategy relying on a dictionary of Symbols to store every input and output.
+
+Implements the [`CacheStrategy`](@ref) interface.
+"""
+struct Dictionary <: CacheStrategy end
+
+"""
+ DEVICE_TYPES::Vector{Type}
+
+Global vector of available and implemented device types. Each implementation of a [`AbstractDevice`](@ref) should add its concrete type to this vector.
+
+See also: [`device_types`](@ref), [`get_devices`](@ref)
+"""
+DEVICE_TYPES = Vector{Type}()
+
+"""
+    CACHE_STRATEGIES::Dict{Type, Vector{CacheStrategy}}
+
+Global dictionary of available caching strategies per device. Each implementation of [`AbstractDevice`](@ref) should add its available strategies to the dictionary.
+
+See also: [`strategies`](@ref)
+"""
+CACHE_STRATEGIES = Dict{Type, Vector{CacheStrategy}}()
+
+"""
+ default_strategy(deviceType::Type{T}) where {T <: AbstractDevice}
+
+Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref). Returns the default [`CacheStrategy`](@ref) to use on the given device type.
+See also: [`cache_strategy`](@ref), [`set_cache_strategy`](@ref)
+"""
+function default_strategy end
+
+"""
+ get_devices(t::Type{T}; verbose::Bool) where {T <: AbstractDevice}
+
+Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref). Returns a `Vector` of all devices of the given [`AbstractDevice`](@ref) type that are available on the current machine.
+"""
+function get_devices end
+
+"""
+ measure_device!(device::AbstractDevice; verbose::Bool)
+
+Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref). Measures the compute speed of the given device and writes into it.
+"""
+function measure_device! end
+
+"""
+ gen_cache_init_code(device::AbstractDevice)
+
+Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one [`CacheStrategy`](@ref). Returns an `Expr` initializing this device's variable cache.
+
+The [`CacheStrategy`](@ref) to initialize is read from the device's `cacheStrategy` member.
+"""
+function gen_cache_init_code end
+
+"""
+ gen_access_expr(device::AbstractDevice, symbol::Symbol)
+
+Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one [`CacheStrategy`](@ref).
+Return an `Expr` or `QuoteNode` accessing the variable identified by `symbol`.
+"""
+function gen_access_expr end
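To make the interface concrete, here is a hypothetical device implementation following the pattern of the NUMA and GPU devices below, written as it would appear in a file included from the module. `MyAccelerator` and everything in this block is made up purely for illustration and is not part of this patch:

```julia
# Hypothetical sketch of a new device type implementing the interface above.
mutable struct MyAccelerator <: AbstractGPU   # illustrative only
    device::Any
    cacheStrategy::CacheStrategy
    FLOPS::Float64
end

push!(DEVICE_TYPES, MyAccelerator)
CACHE_STRATEGIES[MyAccelerator] = [LocalVariables()]

default_strategy(::Type{T}) where {T <: MyAccelerator} = LocalVariables()

function get_devices(::Type{T}; verbose::Bool = false) where {T <: MyAccelerator}
    # detect and return the available devices; nothing to detect in this sketch
    return Vector{AbstractDevice}()
end

function measure_device!(device::MyAccelerator; verbose::Bool)
    # measure and store FLOPS etc.; left unimplemented, like the real devices for now
    return nothing
end
```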
diff --git a/src/devices/measure.jl b/src/devices/measure.jl
new file mode 100644
index 0000000..b9a8e26
--- /dev/null
+++ b/src/devices/measure.jl
@@ -0,0 +1,22 @@
+"""
+    measure_devices!(machine::Machine; verbose::Bool)
+
+Measure FLOPS, RAM, cache sizes, and any other properties that can be extracted for the devices in the given machine.
+"""
+function measure_devices!(machine::Machine; verbose::Bool = Base.isinteractive())
+ for device in machine.devices
+ measure_device!(device; verbose = verbose)
+ end
+
+ return nothing
+end
+
+"""
+    measure_transfer_rates!(machine::Machine; verbose::Bool)
+
+Measure the transfer rates between the devices in the machine.
+"""
+function measure_transfer_rates!(machine::Machine; verbose::Bool = Base.isinteractive())
+ # TODO implement
+ return nothing
+end
diff --git a/src/devices/numa/impl.jl b/src/devices/numa/impl.jl
new file mode 100644
index 0000000..01ac8cd
--- /dev/null
+++ b/src/devices/numa/impl.jl
@@ -0,0 +1,96 @@
+using NumaAllocators
+
+"""
+ NumaNode <: AbstractCPU
+
+Representation of a specific CPU that code can run on. Implements the [`AbstractDevice`](@ref) interface.
+"""
+mutable struct NumaNode <: AbstractCPU
+ numaId::UInt16
+ threads::UInt16
+ cacheStrategy::CacheStrategy
+ FLOPS::Float64
+ id::UUID
+end
+
+push!(DEVICE_TYPES, NumaNode)
+
+CACHE_STRATEGIES[NumaNode] = [LocalVariables()]
+
+default_strategy(::Type{T}) where {T <: NumaNode} = LocalVariables()
+
+function measure_device!(device::NumaNode; verbose::Bool)
+ if verbose
+ println("Measuring Numa Node $(device.numaId)")
+ end
+
+ # TODO implement
+ return nothing
+end
+
+"""
+ get_devices(deviceType::Type{T}; verbose::Bool) where {T <: NumaNode}
+
+Return a Vector of [`NumaNode`](@ref)s available on the current machine. If `verbose` is true, print some additional information.
+"""
+function get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: NumaNode}
+ devices = Vector{AbstractDevice}()
+ noNumaNodes = highest_numa_node()
+
+ if (verbose)
+ println("Found $(noNumaNodes + 1) NUMA nodes")
+ end
+ for i in 0:noNumaNodes
+ push!(devices, NumaNode(i, 1, default_strategy(NumaNode), -1, UUIDs.uuid1(rng[1])))
+ end
+
+ return devices
+end
+
+"""
+ gen_cache_init_code(device::NumaNode)
+
+Generate code initializing the set [`CacheStrategy`](@ref) on a [`NumaNode`](@ref). Currently supports [`LocalVariables`](@ref) and [`Dictionary`](@ref).
+"""
+function gen_cache_init_code(device::NumaNode)
+ if typeof(device.cacheStrategy) <: LocalVariables
+ # don't need to initialize anything
+ return Expr(:block)
+ elseif typeof(device.cacheStrategy) <: Dictionary
+ return Meta.parse("cache_$(to_var_name(device.id)) = Dict{Symbol, Any}()")
+ # TODO: sizehint?
+ end
+
+ return error("Unimplemented cache strategy \"$(device.cacheStrategy)\" for device \"$(device)\"")
+end
+
+"""
+ gen_access_expr(device::NumaNode, symbol::Symbol)
+
+Generate code to access the variable designated by `symbol` on a [`NumaNode`](@ref), using the [`CacheStrategy`](@ref) set in the device.
+"""
+function gen_access_expr(device::NumaNode, symbol::Symbol)
+ return _gen_access_expr(device, device.cacheStrategy, symbol)
+end
+
+"""
+ _gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol)
+
+Internal function for dispatch, used in [`gen_access_expr`](@ref).
+"""
+function _gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol)
+ s = Symbol("data_$symbol")
+ quoteNode = Meta.parse(":($s)")
+ return quoteNode
+end
+
+"""
+ _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol)
+
+Internal function for dispatch, used in [`gen_access_expr`](@ref).
+"""
+function _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol)
+ accessStr = ":(cache_$(to_var_name(device.id))[:$symbol])"
+ quoteNode = Meta.parse(accessStr)
+ return quoteNode
+end
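Roughly what the two strategies produce for a `NumaNode`; the node constructed here and the symbol `:abc` are only for illustration, and the `cache_...` name contains the node's UUID:

```julia
using UUIDs
using MetagraphOptimization:
    NumaNode, LocalVariables, Dictionary, set_cache_strategy, gen_cache_init_code, gen_access_expr

node = NumaNode(0, 1, LocalVariables(), -1.0, uuid1())

gen_cache_init_code(node)    # empty block, local variables need no setup
gen_access_expr(node, :abc)  # roughly :(data_abc)

set_cache_strategy(node, Dictionary())
gen_cache_init_code(node)    # roughly :(cache_<id> = Dict{Symbol, Any}())
gen_access_expr(node, :abc)  # roughly :(cache_<id>[:abc])
```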
diff --git a/src/devices/oneapi/impl.jl b/src/devices/oneapi/impl.jl
new file mode 100644
index 0000000..9c18e5a
--- /dev/null
+++ b/src/devices/oneapi/impl.jl
@@ -0,0 +1,53 @@
+using oneAPI
+
+"""
+ oneAPIGPU <: AbstractGPU
+
+Representation of a specific Intel GPU that code can run on. Implements the [`AbstractDevice`](@ref) interface.
+"""
+mutable struct oneAPIGPU <: AbstractGPU
+ device::Any
+ cacheStrategy::CacheStrategy
+ FLOPS::Float64
+end
+
+push!(DEVICE_TYPES, oneAPIGPU)
+
+CACHE_STRATEGIES[oneAPIGPU] = [LocalVariables()]
+
+default_strategy(::Type{T}) where {T <: oneAPIGPU} = LocalVariables()
+
+function measure_device!(device::oneAPIGPU; verbose::Bool)
+ if verbose
+ println("Measuring oneAPI GPU $(device.device)")
+ end
+
+ # TODO implement
+ return nothing
+end
+
+"""
+ get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: oneAPIGPU}
+
+Return a Vector of [`oneAPIGPU`](@ref)s available on the current machine. If `verbose` is true, print some additional information.
+"""
+function get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: oneAPIGPU}
+ devices = Vector{AbstractDevice}()
+
+ if !oneAPI.functional()
+ if verbose
+ println("oneAPI is non-functional")
+ end
+ return devices
+ end
+
+ oneAPIDevices = oneAPI.devices()
+ if verbose
+ println("Found $(length(oneAPIDevices)) oneAPI devices")
+ end
+ for device in oneAPIDevices
+ push!(devices, oneAPIGPU(device, default_strategy(oneAPIGPU), -1))
+ end
+
+ return devices
+end
diff --git a/src/devices/rocm/impl.jl b/src/devices/rocm/impl.jl
new file mode 100644
index 0000000..c0fe5c2
--- /dev/null
+++ b/src/devices/rocm/impl.jl
@@ -0,0 +1,53 @@
+using AMDGPU
+
+"""
+ ROCmGPU <: AbstractGPU
+
+Representation of a specific AMD GPU that code can run on. Implements the [`AbstractDevice`](@ref) interface.
+"""
+mutable struct ROCmGPU <: AbstractGPU
+ device::Any
+ cacheStrategy::CacheStrategy
+ FLOPS::Float64
+end
+
+push!(DEVICE_TYPES, ROCmGPU)
+
+CACHE_STRATEGIES[ROCmGPU] = [LocalVariables()]
+
+default_strategy(::Type{T}) where {T <: ROCmGPU} = LocalVariables()
+
+function measure_device!(device::ROCmGPU; verbose::Bool)
+ if verbose
+ println("Measuring ROCm GPU $(device.device)")
+ end
+
+ # TODO implement
+ return nothing
+end
+
+"""
+ get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: ROCmGPU}
+
+Return a Vector of [`ROCmGPU`](@ref)s available on the current machine. If `verbose` is true, print some additional information.
+"""
+function get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: ROCmGPU}
+ devices = Vector{AbstractDevice}()
+
+ if !AMDGPU.functional()
+ if verbose
+ println("AMDGPU is non-functional")
+ end
+ return devices
+ end
+
+ AMDDevices = AMDGPU.devices()
+ if verbose
+ println("Found $(length(AMDDevices)) AMD devices")
+ end
+ for device in AMDDevices
+ push!(devices, ROCmGPU(device, default_strategy(ROCmGPU), -1))
+ end
+
+ return devices
+end
diff --git a/src/diff/print.jl b/src/diff/print.jl
index 5f6c5ff..5214e6f 100644
--- a/src/diff/print.jl
+++ b/src/diff/print.jl
@@ -6,6 +6,6 @@ Pretty-print a [`Diff`](@ref). Called via print, println and co.
function show(io::IO, diff::Diff)
print(io, "Nodes: ")
print(io, length(diff.addedNodes) + length(diff.removedNodes))
- print(io, " Edges: ")
+ print(io, ", Edges: ")
return print(io, length(diff.addedEdges) + length(diff.removedEdges))
end
diff --git a/src/diff/type.jl b/src/diff/type.jl
index b3b7de3..be6d8b9 100644
--- a/src/diff/type.jl
+++ b/src/diff/type.jl
@@ -4,8 +4,8 @@
A named tuple representing a difference of added and removed nodes and edges on a [`DAG`](@ref).
"""
const Diff = NamedTuple{
- (:addedNodes, :removedNodes, :addedEdges, :removedEdges),
- Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}},
+ (:addedNodes, :removedNodes, :addedEdges, :removedEdges, :updatedChildren),
+ Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, AbstractTask}}},
}
function Diff()
@@ -14,5 +14,8 @@ function Diff()
removedNodes = Vector{Node}(),
addedEdges = Vector{Edge}(),
removedEdges = Vector{Edge}(),
+
+ # children were updated in the task, updatedChildren[x][2] is the task before the update
+ updatedChildren = Vector{Tuple{Node, AbstractTask}}(),
)::Diff
end
diff --git a/src/graph/interface.jl b/src/graph/interface.jl
index 450419c..0fa74cc 100644
--- a/src/graph/interface.jl
+++ b/src/graph/interface.jl
@@ -38,8 +38,7 @@ end
Return `true` if [`pop_operation!`](@ref) is possible, `false` otherwise.
"""
-can_pop(graph::DAG) =
- !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations)
+can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations)
"""
reset_graph!(graph::DAG)
diff --git a/src/graph/mute.jl b/src/graph/mute.jl
index 77fd83d..d23611f 100644
--- a/src/graph/mute.jl
+++ b/src/graph/mute.jl
@@ -15,12 +15,7 @@ Insert the node into the graph.
See also: [`remove_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
-function insert_node!(
- graph::DAG,
- node::Node,
- track = true,
- invalidate_cache = true,
-)
+function insert_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
# 1: mute
push!(graph.nodes, node)
@@ -50,14 +45,8 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
"""
-function insert_edge!(
- graph::DAG,
- node1::Node,
- node2::Node,
- track = true,
- invalidate_cache = true,
-)
- # @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
+function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
+ @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
# 1: mute
# edge points from child to parent
@@ -95,13 +84,8 @@ Remove the node from the graph.
See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
-function remove_node!(
- graph::DAG,
- node::Node,
- track = true,
- invalidate_cache = true,
-)
- # @assert node in graph.nodes "Trying to remove a node that's not in the graph"
+function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
+ @assert node in graph.nodes "Trying to remove a node that's not in the graph"
# 1: mute
delete!(graph.nodes, node)
@@ -134,13 +118,7 @@ Remove the edge between node1 (child) and node2 (parent) into the graph.
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`insert_edge!`](@ref)
"""
-function remove_edge!(
- graph::DAG,
- node1::Node,
- node2::Node,
- track = true,
- invalidate_cache = true,
-)
+function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
# 1: mute
pre_length1 = length(node1.parents)
pre_length2 = length(node2.children)
@@ -149,15 +127,15 @@ function remove_edge!(
filter!(x -> x != node2, node1.parents)
filter!(x -> x != node1, node2.children)
- #=@assert begin
- removed = pre_length1 - length(node1.parents)
- removed <= 1
- end "removed more than one node from node1's parents"=#
+ @assert begin
+ removed = pre_length1 - length(node1.parents)
+ removed <= 1
+ end "removed more than one node from node1's parents"
- #=@assert begin
- removed = pre_length2 - length(node2.children)
- removed <= 1
- end "removed more than one node from node2's children"=#
+ @assert begin
+ removed = pre_length2 - length(node2.children)
+ removed <= 1
+ end "removed more than one node from node2's children"
# 2: keep track
if (track)
@@ -181,6 +159,66 @@ function remove_edge!(
return nothing
end
+function replace_children!(task::FusedComputeTask, before, after)
+ replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
+ replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
+
+ @assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
+
+ replace!(task.t1_inputs, before => after)
+ replace!(task.t2_inputs, before => after)
+
+ # recursively descend down the tree, but only in the tasks where we're replacing things
+ if replacedIn1 > 0
+ replace_children!(task.first_task, before, after)
+ end
+ if replacedIn2 > 0
+ replace_children!(task.second_task, before, after)
+ end
+
+ return nothing
+end
+
+function replace_children!(task::AbstractTask, before, after)
+ return nothing
+end
+
+function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
+ # only need to update fused compute tasks
+ if !(typeof(n.task) <: FusedComputeTask)
+ return nothing
+ end
+
+ taskBefore = copy(n.task)
+
+ if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
+ println("------------------ Nothing to replace!! ------------------")
+ child_ids = Vector{String}()
+ for child in n.children
+ push!(child_ids, "$(child.id)")
+ end
+ println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
+ @assert false
+ end
+
+ replace_children!(n.task, child_before, child_after)
+
+ if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
+ println("------------------ Did not replace anything!! ------------------")
+ child_ids = Vector{String}()
+ for child in n.children
+ push!(child_ids, "$(child.id)")
+ end
+ println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
+ @assert false
+ end
+
+ # keep track
+ if (track)
+ push!(graph.diff.updatedChildren, (n, taskBefore))
+ end
+end
+
"""
get_snapshot_diff(graph::DAG)
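The mutation helpers now take `track` and `invalidate_cache` as keyword arguments instead of positional flags; a small sketch of a call site under the new signatures. `relink!` is a made-up helper for illustration, and the nodes passed in must belong to an actual DAG:

```julia
using MetagraphOptimization: DAG, Node, insert_edge!, remove_edge!

# Made-up helper illustrating the new keyword-argument signatures.
function relink!(graph::DAG, child::Node, parent::Node)
    remove_edge!(graph, child, parent)   # defaults: track = true, invalidate_cache = true
    insert_edge!(graph, child, parent; track = false, invalidate_cache = false)
    return nothing
end
```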
diff --git a/src/graph/print.jl b/src/graph/print.jl
index c7e66e4..5b130e7 100644
--- a/src/graph/print.jl
+++ b/src/graph/print.jl
@@ -62,9 +62,5 @@ function show(io::IO, graph::DAG)
properties = get_properties(graph)
println(io, " Total Compute Effort: ", properties.computeEffort)
println(io, " Total Data Transfer: ", properties.data)
- return println(
- io,
- " Total Compute Intensity: ",
- properties.computeIntensity,
- )
+ return println(io, " Total Compute Intensity: ", properties.computeIntensity)
end
diff --git a/src/graph/properties.jl b/src/graph/properties.jl
index 2fd89db..7458c13 100644
--- a/src/graph/properties.jl
+++ b/src/graph/properties.jl
@@ -34,6 +34,7 @@ end
Return a vector of the graph's entry nodes.
"""
function get_entry_nodes(graph::DAG)
+ apply_all!(graph)
result = Vector{Node}()
for node in graph.nodes
if (is_entry_node(node))
diff --git a/src/graph/type.jl b/src/graph/type.jl
index 64ef860..6aa8585 100644
--- a/src/graph/type.jl
+++ b/src/graph/type.jl
@@ -17,7 +17,7 @@ end
The representation of the graph as a set of [`Node`](@ref)s.
-A DAG can be loaded using the appropriate parse function, e.g. [`parse_abc`](@ref).
+A DAG can be loaded using the appropriate parse function, e.g. [`parse_dag`](@ref).
[`Operation`](@ref)s can be applied on it using [`push_operation!`](@ref) and reverted using [`pop_operation!`](@ref) like a stack.
To get the set of possible operations, use [`get_operations`](@ref).
@@ -52,11 +52,7 @@ end
Construct and return an empty [`PossibleOperations`](@ref) object.
"""
function PossibleOperations()
- return PossibleOperations(
- Set{NodeFusion}(),
- Set{NodeReduction}(),
- Set{NodeSplit}(),
- )
+ return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
end
"""
diff --git a/src/graph/validate.jl b/src/graph/validate.jl
index c033d2f..d731906 100644
--- a/src/graph/validate.jl
+++ b/src/graph/validate.jl
@@ -59,3 +59,19 @@ function is_valid(graph::DAG)
return true
end
+
+"""
+ is_scheduled(graph::DAG)
+
+Validate that the entire graph has been scheduled, i.e., every [`ComputeTaskNode`](@ref) has its `.device` set.
+"""
+function is_scheduled(graph::DAG)
+ for node in graph.nodes
+ if (node isa DataTaskNode)
+ continue
+ end
+ @assert !ismissing(node.device)
+ end
+
+ return true
+end
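`is_scheduled` complements `is_valid` as a post-scheduling sanity check. A sketch, reusing the placeholder names `graph` and `machine` from the earlier sketches and assuming, as `gen_code` does, that the greedy scheduler assigns every compute node's device:

```julia
using MetagraphOptimization: schedule_dag, GreedyScheduler

@assert is_valid(graph)
schedule_dag(GreedyScheduler(), graph, machine)
@assert is_scheduled(graph)
```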
diff --git a/src/models/abc/compute.jl b/src/models/abc/compute.jl
index 03fc82a..8d14d01 100644
--- a/src/models/abc/compute.jl
+++ b/src/models/abc/compute.jl
@@ -45,6 +45,12 @@ For valid inputs, both input particles should have the same momenta at this poin
12 FLOP.
"""
function compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)
+ #=
+ @assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
+ @assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
+ @assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001, atol = sqrt(eps())) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)"
+ @assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001, atol = sqrt(eps())) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)"
+ =#
return data1.v * inner_edge(data1.p) * data2.v
end
@@ -71,186 +77,78 @@ function compute(::ComputeTaskSum, data::Vector{Float64})
end
"""
- compute(t::FusedComputeTask, data)
+ get_expression(::ComputeTaskP, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
-Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
+Generate and return code evaluating [`ComputeTaskP`](@ref) on `inExprs`, providing the output on `outExpr`.
"""
-function compute(t::FusedComputeTask, data)
- @assert false "This is not implemented and should never be called"
+function get_expression(::ComputeTaskP, device::AbstractDevice, inExprs::Vector, outExpr)
+ in = [eval(inExprs[1])]
+ out = eval(outExpr)
+
+ return Meta.parse("$out = compute(ComputeTaskP(), $(in[1]))")
end
"""
- get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
+ get_expression(::ComputeTaskU, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
-Generate and return code evaluating [`ComputeTaskP`](@ref) on `inSymbol`, providing the output on `outSymbol`.
+Generate code evaluating [`ComputeTaskU`](@ref) on `inExprs`, providing the output on `outExpr`.
+`inExprs` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
"""
-function get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
- return Meta.parse("$outSymbol = compute(ComputeTaskP(), $inSymbol)")
+function get_expression(::ComputeTaskU, device::AbstractDevice, inExprs::Vector, outExpr)
+ in = [eval(inExprs[1])]
+ out = eval(outExpr)
+
+ return Meta.parse("$out = compute(ComputeTaskU(), $(in[1]))")
end
"""
- get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
+ get_expression(::ComputeTaskV, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
-Generate code evaluating [`ComputeTaskU`](@ref) on `inSymbol`, providing the output on `outSymbol`.
-`inSymbol` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
+Generate code evaluating [`ComputeTaskV`](@ref) on `inExprs`, providing the output on `outExpr`.
+`inExprs[1]` and `inExprs[2]` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
"""
-function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
- return Meta.parse("$outSymbol = compute(ComputeTaskU(), $inSymbol)")
+function get_expression(::ComputeTaskV, device::AbstractDevice, inExprs::Vector, outExpr)
+ in = [eval(inExprs[1]), eval(inExprs[2])]
+ out = eval(outExpr)
+
+ return Meta.parse("$out = compute(ComputeTaskV(), $(in[1]), $(in[2]))")
end
"""
- get_expression(::ComputeTaskV, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
+ get_expression(::ComputeTaskS2, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
-Generate code evaluating [`ComputeTaskV`](@ref) on `inSymbol1` and `inSymbol2`, providing the output on `outSymbol`.
-`inSymbol1` and `inSymbol2` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
+Generate code evaluating [`ComputeTaskS2`](@ref) on `inExprs`, providing the output on `outExpr`.
+`inExprs[1]` and `inExprs[2]` should be of type [`ParticleValue`](@ref), `outExpr` will be of type `Float64`.
"""
-function get_expression(
- ::ComputeTaskV,
- inSymbol1::Symbol,
- inSymbol2::Symbol,
- outSymbol::Symbol,
-)
- return Meta.parse(
- "$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)",
- )
+function get_expression(::ComputeTaskS2, device::AbstractDevice, inExprs::Vector, outExpr)
+ in = [eval(inExprs[1]), eval(inExprs[2])]
+ out = eval(outExpr)
+
+ return Meta.parse("$out = compute(ComputeTaskS2(), $(in[1]), $(in[2]))")
end
"""
- get_expression(::ComputeTaskS2, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
+ get_expression(::ComputeTaskS1, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
-Generate code evaluating [`ComputeTaskS2`](@ref) on `inSymbol1` and `inSymbol2`, providing the output on `outSymbol`.
-`inSymbol1` and `inSymbol2` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type `Float64`.
+Generate code evaluating [`ComputeTaskS1`](@ref) on `inExprs`, providing the output on `outExpr`.
+`inExprs` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
"""
-function get_expression(
- ::ComputeTaskS2,
- inSymbol1::Symbol,
- inSymbol2::Symbol,
- outSymbol::Symbol,
-)
- return Meta.parse(
- "$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)",
- )
+function get_expression(::ComputeTaskS1, device::AbstractDevice, inExprs::Vector, outExpr)
+ in = [eval(inExprs[1])]
+ out = eval(outExpr)
+
+ return Meta.parse("$out = compute(ComputeTaskS1(), $(in[1]))")
end
"""
- get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
+ get_expression(::ComputeTaskSum, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
-Generate code evaluating [`ComputeTaskS1`](@ref) on `inSymbol`, providing the output on `outSymbol`.
-`inSymbol` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
+Generate code evaluating [`ComputeTaskSum`](@ref) on `inExprs`, providing the output on `outExpr`.
+`inExprs` should be of type `Float64`, `outExpr` will be of type `Float64`.
"""
-function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
- return Meta.parse("$outSymbol = compute(ComputeTaskS1(), $inSymbol)")
-end
-
-"""
- get_expression(::ComputeTaskSum, inSymbols::Vector{Symbol}, outSymbol::Symbol)
-
-Generate code evaluating [`ComputeTaskSum`](@ref) on `inSymbols`, providing the output on `outSymbol`.
-`inSymbols` should be of type [`Float64`], `outSymbol` will be of type [`Float64`].
-"""
-function get_expression(
- ::ComputeTaskSum,
- inSymbols::Vector{Symbol},
- outSymbol::Symbol,
-)
- return quote
- $outSymbol = compute(ComputeTaskSum(), [$(inSymbols...)])
- end
-end
-
-"""
- get_expression(t::FusedComputeTask, inSymbols::Vector{Symbol}, outSymbol::Symbol)
-
-Generate code evaluating a [`FusedComputeTask`](@ref) on `inSymbols`, providing the output on `outSymbol`.
-`inSymbols` should be of the correct types and may be heterogeneous. `outSymbol` will be of the type of the output of `T2` of t.
-"""
-function get_expression(
- t::FusedComputeTask,
- inSymbols::Vector{Symbol},
- outSymbol::Symbol,
-)
- (T1, T2) = get_types(t)
- c1 = children(T1())
- c2 = children(T2())
-
- expr1 = nothing
- expr2 = nothing
-
- # TODO need to figure out how to know which inputs belong to which subtask
- # since we order the vectors with the child nodes we can't just split
- if (c1 == 1)
- expr1 = get_expression(T1(), inSymbols[begin], :intermediate)
- elseif (c1 == 2)
- expr1 =
- get_expression(T1(), inSymbols[begin], inSymbols[2], :intermediate)
- else
- expr1 = get_expression(T1(), inSymbols[begin:c1], :intermediate)
- end
-
- if (c2 == 1)
- expr2 = get_expression(T2(), :intermediate, outSymbol)
- elseif c2 == 2
- expr2 =
- get_expression(T2(), :intermediate, inSymbols[c1 + 1], outSymbol)
- else
- expr2 = get_expression(
- T2(),
- :intermediate * inSymbols[(c1 + 1):end],
- outSymbol,
- )
- end
-
- return Expr(:block, expr1, expr2)
-end
-
-"""
- get_expression(node::ComputeTaskNode)
-
-Generate and return code for a given [`ComputeTaskNode`](@ref).
-"""
-function get_expression(node::ComputeTaskNode)
- t = typeof(node.task)
- @assert length(node.children) == children(node.task) || t <: ComputeTaskSum
-
- if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
- symbolIn = Symbol("data_$(to_var_name(node.children[1].id))")
- symbolOut = Symbol("data_$(to_var_name(node.id))")
- return get_expression(t(), symbolIn, symbolOut)
- elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
- symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))")
- symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))")
- symbolOut = Symbol("data_$(to_var_name(node.id))")
- return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
- elseif (t <: ComputeTaskSum || t <: FusedComputeTask) # vector input
- inSymbols = Vector{Symbol}()
- for child in node.children
- push!(inSymbols, Symbol("data_$(to_var_name(child.id))"))
- end
- outSymbol = Symbol("data_$(to_var_name(node.id))")
- return get_expression(t(), inSymbols, outSymbol)
- else
- error("Unknown compute task")
- end
-end
-
-"""
- get_expression(node::DataTaskNode)
-
-Generate and return code for a given [`DataTaskNode`](@ref).
-"""
-function get_expression(node::DataTaskNode)
- # TODO: do things to transport data from/to gpu, between numa nodes, etc.
- @assert length(node.children) <= 1
-
- inSymbol = nothing
- if (length(node.children) == 1)
- inSymbol = Symbol("data_$(to_var_name(node.children[1].id))")
- else
- inSymbol = Symbol("data_$(to_var_name(node.id))_in")
- end
- outSymbol = Symbol("data_$(to_var_name(node.id))")
-
- dataTransportExp = Meta.parse("$outSymbol = $inSymbol")
-
- return dataTransportExp
+function get_expression(::ComputeTaskSum, device::AbstractDevice, inExprs::Vector, outExpr)
+ in = eval.(inExprs)
+ out = eval(outExpr)
+
+ return Meta.parse("$out = compute(ComputeTaskSum(), [$(unroll_symbol_vector(in))])")
end
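# A minimal sketch of the expression shape these functions produce, using hypothetical
# symbol names (_res1, _res2, _out); the real names are UUID-derived via `to_var_name`.
# Each node yields one plain assignment calling `compute`, and the sum task unrolls its
# input symbols into a literal vector:
expr = Meta.parse("_out = compute(ComputeTaskSum(), [_res1, _res2])")
# The code generator collects one such assignment per node, in scheduled order.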
diff --git a/src/models/abc/create.jl b/src/models/abc/create.jl
index 865b18c..d33ec4a 100644
--- a/src/models/abc/create.jl
+++ b/src/models/abc/create.jl
@@ -1,74 +1,198 @@
+using QEDbase
+using Random
+using Roots
+using ForwardDiff
+
+ComputeTaskSum() = ComputeTaskSum(0)
"""
- Particle(rng)
+ gen_process_input(processDescription::ABCProcessDescription)
-Return a randomly generated particle.
+Return an [`ABCProcessInput`](@ref) of randomly generated [`ABCParticle`](@ref)s from an [`ABCProcessDescription`](@ref). The process description can be created manually or parsed from a string using [`parse_process`](@ref).
+
+Note: This uses RAMBO to create a valid process with conservation of momentum and energy.
"""
-function Particle(rng, type::ParticleType)
+function gen_process_input(processDescription::ABCProcessDescription)
+ inParticleTypes = keys(processDescription.inParticles)
+ outParticleTypes = keys(processDescription.outParticles)
- p1 = rand(rng, Float64)
- p2 = rand(rng, Float64)
- p3 = rand(rng, Float64)
- m = mass(type)
-
- # keep the momenta of the particles on-shell
- p4 = sqrt(p1^2 + p2^2 + p3^2 + m^2)
-
- return Particle(p1, p2, p3, p4, type)
-end
-
-"""
- gen_particles(n::Int)
-
-Return a Vector of `n` randomly generated [`Particle`](@ref)s.
-
-Note: This does not take into account the preservation of momenta required for an actual valid process!
-"""
-function gen_particles(ns::Dict{ParticleType, Int})
- particles = Dict{ParticleType, Vector{Particle}}()
- rng = MersenneTwister(0)
-
-
- if ns == Dict((A => 2), (B => 2))
- rho = 1.0
-
- omega = rand(rng, Float64)
- theta = rand(rng, Float64) * π
- phi = rand(rng, Float64) * π
-
- particles[A] = Vector{Particle}()
- particles[B] = Vector{Particle}()
-
- push!(particles[A], Particle(omega, 0, 0, omega, A))
- push!(particles[B], Particle(omega, 0, 0, -omega, B))
- push!(
- particles[A],
- Particle(
- omega,
- rho * cos(theta) * cos(phi),
- rho * cos(theta) * sin(phi),
- rho * sin(theta),
- A,
- ),
- )
- push!(
- particles[B],
- Particle(
- omega,
- -rho * cos(theta) * cos(phi),
- -rho * cos(theta) * sin(phi),
- -rho * sin(theta),
- B,
- ),
- )
- return particles
- end
-
- for (type, n) in ns
- particles[type] = Vector{Particle}()
- for i in 1:n
- push!(particles[type], Particle(rng, type))
+ massSum = 0
+ inputMasses = Vector{Float64}()
+ for (particle, n) in processDescription.inParticles
+ for _ in 1:n
+ massSum += mass(particle)
+ push!(inputMasses, mass(particle))
end
end
- return particles
+ outputMasses = Vector{Float64}()
+ for (particle, n) in processDescription.outParticles
+ for _ in 1:n
+ massSum += mass(particle)
+ push!(outputMasses, mass(particle))
+ end
+ end
+
+ # add some extra random mass to allow for some momentum
+ massSum += rand(rng[threadid()]) * (length(inputMasses) + length(outputMasses))
+
+
+ inputParticles = Vector{ABCParticle}()
+ initialMomenta = generate_initial_moms(massSum, inputMasses)
+ index = 1
+ for (particle, n) in processDescription.inParticles
+ for _ in 1:n
+ mom = initialMomenta[index]
+ push!(inputParticles, particle(mom))
+ index += 1
+ end
+ end
+
+ outputParticles = Vector{ABCParticle}()
+ final_momenta = generate_physical_massive_moms(rng[threadid()], massSum, outputMasses)
+ index = 1
+ for (particle, n) in processDescription.outParticles
+ for _ in 1:n
+ mom = final_momenta[index]
+ push!(outputParticles, particle(SFourMomentum(-mom.E, mom.px, mom.py, mom.pz)))
+ index += 1
+ end
+ end
+
+ processInput = ABCProcessInput(processDescription, inputParticles, outputParticles)
+
+    return processInput
+end
+
+####################
+# CODE FROM HERE BORROWED FROM SOURCE: https://codebase.helmholtz.cloud/qedsandbox/QEDphasespaces.jl/
+# use qedphasespaces directly once released
+#
+# quick and dirty implementation of the RAMBO algorithm
+#
+# reference:
+# * https://cds.cern.ch/record/164736/files/198601282.pdf
+# * https://www.sciencedirect.com/science/article/pii/0010465586901190
+####################
+
+function generate_initial_moms(ss, masses)
+ E1 = (ss^2 + masses[1]^2 - masses[2]^2) / (2 * ss)
+ E2 = (ss^2 + masses[2]^2 - masses[1]^2) / (2 * ss)
+
+ rho1 = sqrt(E1^2 - masses[1]^2)
+ rho2 = sqrt(E2^2 - masses[2]^2)
+
+ return [SFourMomentum(E1, 0, 0, rho1), SFourMomentum(E2, 0, 0, -rho2)]
+end
+
+
+Random.rand(rng::AbstractRNG, ::Random.SamplerType{SFourMomentum}) = SFourMomentum(rand(rng, 4))
+Random.rand(rng::AbstractRNG, ::Random.SamplerType{NTuple{N, Float64}}) where {N} = Tuple(rand(rng, N))
+
+
+function _transform_uni_to_mom(u1, u2, u3, u4)
+ cth = 2 * u1 - 1
+ sth = sqrt(1 - cth^2)
+ phi = 2 * pi * u2
+ q0 = -log(u3 * u4)
+ qx = q0 * sth * cos(phi)
+ qy = q0 * sth * sin(phi)
+ qz = q0 * cth
+
+ return SFourMomentum(q0, qx, qy, qz)
+end
+
+function _transform_uni_to_mom!(uni_mom, dest)
+ u1, u2, u3, u4 = Tuple(uni_mom)
+ cth = 2 * u1 - 1
+ sth = sqrt(1 - cth^2)
+ phi = 2 * pi * u2
+ q0 = -log(u3 * u4)
+ qx = q0 * sth * cos(phi)
+ qy = q0 * sth * sin(phi)
+ qz = q0 * cth
+
+ return dest = SFourMomentum(q0, qx, qy, qz)
+end
+
+_transform_uni_to_mom(u1234::Tuple) = _transform_uni_to_mom(u1234...)
+_transform_uni_to_mom(u1234::SFourMomentum) = _transform_uni_to_mom(Tuple(u1234))
+
+function generate_massless_moms(rng, n::Int)
+ a = Vector{SFourMomentum}(undef, n)
+ rand!(rng, a)
+ return map(_transform_uni_to_mom, a)
+end
+
+function generate_physical_massless_moms(rng, ss, n)
+ r_moms = generate_massless_moms(rng, n)
+ Q = sum(r_moms)
+ M = sqrt(Q * Q)
+ fac = -1 / M
+ Qx = getX(Q)
+ Qy = getY(Q)
+ Qz = getZ(Q)
+ bx = fac * Qx
+ by = fac * Qy
+ bz = fac * Qz
+ gamma = getT(Q) / M
+ a = 1 / (1 + gamma)
+ x = ss / M
+
+ i = 1
+ while i <= n
+ mom = r_moms[i]
+ mom0 = getT(mom)
+ mom1 = getX(mom)
+ mom2 = getY(mom)
+ mom3 = getZ(mom)
+
+ bq = bx * mom1 + by * mom2 + bz * mom3
+
+ p0 = x * (gamma * mom0 + bq)
+ px = x * (mom1 + bx * mom0 + a * bq * bx)
+ py = x * (mom2 + by * mom0 + a * bq * by)
+ pz = x * (mom3 + bz * mom0 + a * bq * bz)
+
+ r_moms[i] = SFourMomentum(p0, px, py, pz)
+ i += 1
+ end
+ return r_moms
+end
+
+function _to_be_solved(xi, masses, p0s, ss)
+ sum = 0.0
+ for (i, E) in enumerate(p0s)
+ sum += sqrt(masses[i]^2 + xi^2 * E^2)
+ end
+ return sum - ss
+end
+
+function _build_massive_momenta(xi, masses, massless_moms)
+ vec = SFourMomentum[]
+ i = 1
+ while i <= length(massless_moms)
+ massless_mom = massless_moms[i]
+ k0 = sqrt(getT(massless_mom)^2 * xi^2 + masses[i]^2)
+
+ kx = xi * getX(massless_mom)
+ ky = xi * getY(massless_mom)
+ kz = xi * getZ(massless_mom)
+
+ push!(vec, SFourMomentum(k0, kx, ky, kz))
+
+ i += 1
+ end
+ return vec
+end
+
+first_derivative(func) = x -> ForwardDiff.derivative(func, float(x))
+
+
+function generate_physical_massive_moms(rng, ss, masses; x0 = 0.1)
+ n = length(masses)
+ massless_moms = generate_physical_massless_moms(rng, ss, n)
+ energies = getT.(massless_moms)
+ f = x -> _to_be_solved(x, masses, energies, ss)
+ xi = find_zero((f, first_derivative(f)), x0, Roots.Newton())
+ return _build_massive_momenta(xi, masses, massless_moms)
end
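# A small sanity sketch, assuming `input::ABCProcessInput` was produced by
# `gen_process_input` above: RAMBO generates the outgoing momenta in the
# center-of-momentum frame, so their spatial components should sum to (almost) zero.
p_out = sum(p.momentum for p in input.outParticles)
@assert isapprox(getX(p_out), 0.0; atol = 1e-8)
@assert isapprox(getY(p_out), 0.0; atol = 1e-8)
@assert isapprox(getZ(p_out), 0.0; atol = 1e-8)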
diff --git a/src/models/abc/parse.jl b/src/models/abc/parse.jl
index 94e62f9..5ef890b 100644
--- a/src/models/abc/parse.jl
+++ b/src/models/abc/parse.jl
@@ -32,13 +32,13 @@ function parse_edges(input::AbstractString)
end
"""
- parse_abc(filename::String; verbose::Bool = false)
+ parse_dag(filename::String, model::ABCModel; verbose::Bool = false)
Read an abc-model process from the given file. If `verbose` is set to true, print some progress information to stdout.
Returns a valid [`DAG`](@ref).
"""
-function parse_abc(filename::String, verbose::Bool = false)
+function parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = false)
file = open(filename, "r")
if (verbose)
@@ -63,10 +63,9 @@ function parse_abc(filename::String, verbose::Bool = false)
end
sizehint!(graph.nodes, estimate_no_nodes)
- sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false)
- global_data_out =
- insert_node!(graph, make_node(DataTask(FLOAT_SIZE)), false, false)
- insert_edge!(graph, sum_node, global_data_out, false, false)
+ sum_node = insert_node!(graph, make_node(ComputeTaskSum(0)), track = false, invalidate_cache = false)
+ global_data_out = insert_node!(graph, make_node(DataTask(FLOAT_SIZE)), track = false, invalidate_cache = false)
+ insert_edge!(graph, sum_node, global_data_out, track = false, invalidate_cache = false)
# remember the data out nodes for connection
dataOutNodes = Dict()
@@ -81,10 +80,7 @@ function parse_abc(filename::String, verbose::Bool = false)
noNodes += 1
if (noNodes % 100 == 0)
if (verbose)
- percent = string(
- round(100.0 * noNodes / nodesToRead, digits = 2),
- "%",
- )
+ percent = string(round(100.0 * noNodes / nodesToRead, digits = 2), "%")
print("\rReading Nodes... $percent")
end
end
@@ -93,30 +89,20 @@ function parse_abc(filename::String, verbose::Bool = false)
data_in = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE), string(node)),
- false,
- false,
+ track = false,
+ invalidate_cache = false,
) # read particle data node
- compute_P =
- insert_node!(graph, make_node(ComputeTaskP()), false, false) # compute P node
- data_Pu = insert_node!(
- graph,
- make_node(DataTask(PARTICLE_VALUE_SIZE)),
- false,
- false,
- ) # transfer data from P to u (one ParticleValue object)
- compute_u =
- insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node
- data_out = insert_node!(
- graph,
- make_node(DataTask(PARTICLE_VALUE_SIZE)),
- false,
- false,
- ) # transfer data out from u (one ParticleValue object)
+ compute_P = insert_node!(graph, make_node(ComputeTaskP()), track = false, invalidate_cache = false) # compute P node
+ data_Pu =
+ insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false) # transfer data from P to u (one ParticleValue object)
+ compute_u = insert_node!(graph, make_node(ComputeTaskU()), track = false, invalidate_cache = false) # compute U node
+ data_out =
+ insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false) # transfer data out from u (one ParticleValue object)
- insert_edge!(graph, data_in, compute_P, false, false)
- insert_edge!(graph, compute_P, data_Pu, false, false)
- insert_edge!(graph, data_Pu, compute_u, false, false)
- insert_edge!(graph, compute_u, data_out, false, false)
+ insert_edge!(graph, data_in, compute_P, track = false, invalidate_cache = false)
+ insert_edge!(graph, compute_P, data_Pu, track = false, invalidate_cache = false)
+ insert_edge!(graph, data_Pu, compute_u, track = false, invalidate_cache = false)
+ insert_edge!(graph, compute_u, data_out, track = false, invalidate_cache = false)
# remember the data_out node for future edges
dataOutNodes[node] = data_out
@@ -126,63 +112,48 @@ function parse_abc(filename::String, verbose::Bool = false)
in1 = capt.captures[1]
in2 = capt.captures[2]
- compute_v =
- insert_node!(graph, make_node(ComputeTaskV()), false, false)
- data_out = insert_node!(
- graph,
- make_node(DataTask(PARTICLE_VALUE_SIZE)),
- false,
- false,
- )
+ compute_v = insert_node!(graph, make_node(ComputeTaskV()), track = false, invalidate_cache = false)
+ data_out =
+ insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false)
if (occursin(regex_c, in1))
# put an S node after this input
- compute_S = insert_node!(
- graph,
- make_node(ComputeTaskS1()),
- false,
- false,
- )
+ compute_S = insert_node!(graph, make_node(ComputeTaskS1()), track = false, invalidate_cache = false)
data_S_v = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
- false,
- false,
+ track = false,
+ invalidate_cache = false,
)
- insert_edge!(graph, dataOutNodes[in1], compute_S, false, false)
- insert_edge!(graph, compute_S, data_S_v, false, false)
+ insert_edge!(graph, dataOutNodes[in1], compute_S, track = false, invalidate_cache = false)
+ insert_edge!(graph, compute_S, data_S_v, track = false, invalidate_cache = false)
- insert_edge!(graph, data_S_v, compute_v, false, false)
+ insert_edge!(graph, data_S_v, compute_v, track = false, invalidate_cache = false)
else
- insert_edge!(graph, dataOutNodes[in1], compute_v, false, false)
+ insert_edge!(graph, dataOutNodes[in1], compute_v, track = false, invalidate_cache = false)
end
if (occursin(regex_c, in2))
# i think the current generator only puts the combined particles in the first space, so this case might never be entered
# put an S node after this input
- compute_S = insert_node!(
- graph,
- make_node(ComputeTaskS1()),
- false,
- false,
- )
+ compute_S = insert_node!(graph, make_node(ComputeTaskS1()), track = false, invalidate_cache = false)
data_S_v = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
- false,
- false,
+ track = false,
+ invalidate_cache = false,
)
- insert_edge!(graph, dataOutNodes[in2], compute_S, false, false)
- insert_edge!(graph, compute_S, data_S_v, false, false)
+ insert_edge!(graph, dataOutNodes[in2], compute_S, track = false, invalidate_cache = false)
+ insert_edge!(graph, compute_S, data_S_v, track = false, invalidate_cache = false)
- insert_edge!(graph, data_S_v, compute_v, false, false)
+ insert_edge!(graph, data_S_v, compute_v, track = false, invalidate_cache = false)
else
- insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
+ insert_edge!(graph, dataOutNodes[in2], compute_v, track = false, invalidate_cache = false)
end
- insert_edge!(graph, compute_v, data_out, false, false)
+ insert_edge!(graph, compute_v, data_out, track = false, invalidate_cache = false)
dataOutNodes[node] = data_out
elseif occursin(regex_m, node)
@@ -193,43 +164,31 @@ function parse_abc(filename::String, verbose::Bool = false)
in3 = capt.captures[3]
# in2 + in3 with a v
- compute_v =
- insert_node!(graph, make_node(ComputeTaskV()), false, false)
- data_v = insert_node!(
- graph,
- make_node(DataTask(PARTICLE_VALUE_SIZE)),
- false,
- false,
- )
+ compute_v = insert_node!(graph, make_node(ComputeTaskV()), track = false, invalidate_cache = false)
+ data_v =
+ insert_node!(graph, make_node(DataTask(PARTICLE_VALUE_SIZE)), track = false, invalidate_cache = false)
- insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
- insert_edge!(graph, dataOutNodes[in3], compute_v, false, false)
- insert_edge!(graph, compute_v, data_v, false, false)
+ insert_edge!(graph, dataOutNodes[in2], compute_v, track = false, invalidate_cache = false)
+ insert_edge!(graph, dataOutNodes[in3], compute_v, track = false, invalidate_cache = false)
+ insert_edge!(graph, compute_v, data_v, track = false, invalidate_cache = false)
# combine with the v of the combined other input
- compute_S2 =
- insert_node!(graph, make_node(ComputeTaskS2()), false, false)
- data_out = insert_node!(
- graph,
- make_node(DataTask(FLOAT_SIZE)),
- false,
- false,
- ) # output of a S2 task is only a float
+ compute_S2 = insert_node!(graph, make_node(ComputeTaskS2()), track = false, invalidate_cache = false)
+ data_out = insert_node!(graph, make_node(DataTask(FLOAT_SIZE)), track = false, invalidate_cache = false) # output of a S2 task is only a float
- insert_edge!(graph, data_v, compute_S2, false, false)
- insert_edge!(graph, dataOutNodes[in1], compute_S2, false, false)
- insert_edge!(graph, compute_S2, data_out, false, false)
+ insert_edge!(graph, data_v, compute_S2, track = false, invalidate_cache = false)
+ insert_edge!(graph, dataOutNodes[in1], compute_S2, track = false, invalidate_cache = false)
+ insert_edge!(graph, compute_S2, data_out, track = false, invalidate_cache = false)
- insert_edge!(graph, data_out, sum_node, false, false)
+ insert_edge!(graph, data_out, sum_node, track = false, invalidate_cache = false)
+ add_child!(sum_node.task)
elseif occursin(regex_plus, node)
if (verbose)
println("\rReading Nodes Complete ")
println("Added ", length(graph.nodes), " nodes")
end
else
- @assert false (
- "Unknown node '$node' while reading from file $filename"
- )
+ @assert false ("Unknown node '$node' while reading from file $filename")
end
end
@@ -244,6 +203,46 @@ function parse_abc(filename::String, verbose::Bool = false)
if (verbose)
println("Done")
end
+
# don't actually need to read the edges
return graph
end
+
+"""
+ parse_process(string::AbstractString, model::ABCModel)
+
+Parse a string representation of a process, such as "AB->ABBB" into the corresponding [`ABCProcessDescription`](@ref).
+"""
+function parse_process(str::AbstractString, model::ABCModel)
+ inParticles = Dict{Type, Int}()
+ outParticles = Dict{Type, Int}()
+
+ if !(contains(str, "->"))
+ throw("Did not find -> while parsing process \"$str\"")
+ end
+
+ (inStr, outStr) = split(str, "->")
+
+ if (isempty(inStr) || isempty(outStr))
+ throw("Process (\"$str\") input or output part is empty!")
+ end
+
+ for t in types(model)
+ inCount = count(x -> x == String(t)[1], inStr)
+ outCount = count(x -> x == String(t)[1], outStr)
+ if inCount != 0
+ inParticles[t] = inCount
+ end
+ if outCount != 0
+ outParticles[t] = outCount
+ end
+ end
+
+ if length(inStr) != sum(values(inParticles))
+ throw("Encountered unknown characters in the input part of process \"$str\"")
+ elseif length(outStr) != sum(values(outParticles))
+ throw("Encountered unknown characters in the output part of process \"$str\"")
+ end
+
+ return ABCProcessDescription(inParticles, outParticles)
+end
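# Usage sketch for the new parsing entry points (the DAG file path is hypothetical):
using MetagraphOptimization

model = ABCModel()
process = parse_process("AB->ABBB", model)        # ABCProcessDescription: 1xA + 1xB -> 1xA + 3xB
input = gen_process_input(process)                # random, momentum-conserving ABCProcessInput
graph = parse_dag("input/AB->ABBB.txt", model)    # read the generated diagram DAG from a file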
diff --git a/src/models/abc/particle.jl b/src/models/abc/particle.jl
index b6807b5..368165b 100644
--- a/src/models/abc/particle.jl
+++ b/src/models/abc/particle.jl
@@ -1,99 +1,140 @@
-"""
- ParticleType
+using QEDbase
-A Particle Type in the ABC Model as an enum, with types `A`, `B` and `C`.
"""
-@enum ParticleType A = 1 B = 2 C = 3
+ ABCModel <: AbstractPhysicsModel
+
+Singleton definition for identification of the ABC-Model.
+"""
+struct ABCModel <: AbstractPhysicsModel end
+
+"""
+ ABCParticle
+
+Base type for all particles in the [`ABCModel`](@ref).
+"""
+abstract type ABCParticle <: AbstractParticle end
+
+"""
+ ParticleA <: ABCParticle
+
+An 'A' particle in the ABC Model.
+"""
+struct ParticleA <: ABCParticle
+ momentum::SFourMomentum
+end
+
+"""
+ ParticleB <: ABCParticle
+
+A 'B' particle in the ABC Model.
+"""
+struct ParticleB <: ABCParticle
+ momentum::SFourMomentum
+end
+
+"""
+ ParticleC <: ABCParticle
+
+A 'C' particle in the ABC Model.
+"""
+struct ParticleC <: ABCParticle
+ momentum::SFourMomentum
+end
+
+"""
+ ABCProcessDescription <: AbstractProcessDescription
+
+A description of a process in the ABC-Model. Contains the input and output particles.
+
+See also: [`in_particles`](@ref), [`out_particles`](@ref), [`parse_process`](@ref)
+"""
+struct ABCProcessDescription <: AbstractProcessDescription
+ inParticles::Dict{Type, Int}
+ outParticles::Dict{Type, Int}
+end
+
+"""
+ ABCProcessInput <: AbstractProcessInput
+
+Input for an ABC process. Contains the [`ABCProcessDescription`](@ref) of the process it is an input for, and the values of the incoming and outgoing particles.
+
+See also: [`gen_process_input`](@ref)
+"""
+struct ABCProcessInput <: AbstractProcessInput
+ process::ABCProcessDescription
+ inParticles::Vector{ABCParticle}
+ outParticles::Vector{ABCParticle}
+end
"""
PARTICLE_MASSES
-A constant dictionary containing the masses of the different [`ParticleType`](@ref)s.
+A constant dictionary containing the masses of the different [`ABCParticle`](@ref)s.
"""
-const PARTICLE_MASSES =
- Dict{ParticleType, Float64}(A => 1.0, B => 1.0, C => 0.0)
+const PARTICLE_MASSES = Dict{Type, Float64}(ParticleA => 1.0, ParticleB => 1.0, ParticleC => 0.0)
"""
- Particle
-
-A struct describing a particle of the ABC-Model. It has the 4 momentum parts P0...P3 and a [`ParticleType`](@ref).
-
-`sizeof(Particle())` = 40 Byte
-"""
-struct Particle
- P0::Float64
- P1::Float64
- P2::Float64
- P3::Float64
-
- type::ParticleType
-end
-
-"""
- ParticleValue
-
-A struct describing a particle during a calculation of a Feynman Diagram, together with the value that's being calculated.
-
-`sizeof(ParticleValue())` = 48 Byte
-"""
-struct ParticleValue
- p::Particle
- v::Float64
-end
-
-"""
- mass(t::ParticleType)
+ mass(t::Type{T}) where {T <: ABCParticle}
Return the mass (at rest) of the given particle type.
"""
-mass(t::ParticleType) = PARTICLE_MASSES[t]
+mass(t::Type{T}) where {T <: ABCParticle} = PARTICLE_MASSES[t]
"""
- remaining_type(t1::ParticleType, t2::ParticleType)
+ interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ABCParticle, T2 <: ABCParticle}
For 2 given (non-equal) particle types, return the third of ABC.
"""
-function remaining_type(t1::ParticleType, t2::ParticleType)
+function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ABCParticle, T2 <: ABCParticle}
@assert t1 != t2
- if t1 != A && t2 != A
- return A
- elseif t1 != B && t2 != B
- return B
+    if t1 != ParticleA && t2 != ParticleA
+ return ParticleA
+    elseif t1 != ParticleB && t2 != ParticleB
+ return ParticleB
else
- return C
+ return ParticleC
end
end
"""
- square(p::Particle)
+ types(::ABCModel)
+
+Return a Vector of the possible types of particle in the [`ABCModel`](@ref).
+"""
+function types(::ABCModel)
+ return [ParticleA, ParticleB, ParticleC]
+end
+
+"""
+ square(p::ABCParticle)
Return the square of the particle's momentum as a `Float` value.
Takes 7 effective FLOP.
"""
-function square(p::Particle)
- return p.P0 * p.P0 - p.P1 * p.P1 - p.P2 * p.P2 - p.P3 * p.P3
+function square(p::ABCParticle)
+ return getMass2(p.momentum)
end
"""
- inner_edge(p::Particle)
+ inner_edge(p::ABCParticle)
Return the factor of the inner edge with the given (virtual) particle.
-Takes 10 effective FLOP. (3 here + 10 in square(p))
+Takes 10 effective FLOP. (3 here + 7 in square(p))
"""
-function inner_edge(p::Particle)
- return 1.0 / (square(p) - mass(p.type) * mass(p.type))
+function inner_edge(p::ABCParticle)
+ return 1.0 / (square(p) - mass(typeof(p)) * mass(typeof(p)))
end
"""
- outer_edge(p::Particle)
+ outer_edge(p::ABCParticle)
Return the factor of the outer edge with the given (real) particle.
Takes 0 effective FLOP.
"""
-function outer_edge(p::Particle)
+function outer_edge(p::ABCParticle)
return 1.0
end
@@ -111,20 +152,58 @@ function vertex()
end
"""
- preserve_momentum(p1::Particle, p2::Particle)
+ preserve_momentum(p1::ABCParticle, p2::ABCParticle)
Calculate and return a new particle from two given interacting ones at a vertex.
Takes 4 effective FLOP.
"""
-function preserve_momentum(p1::Particle, p2::Particle)
- p3 = Particle(
- p1.P0 + p2.P0,
- p1.P1 + p2.P1,
- p1.P2 + p2.P2,
- p1.P3 + p2.P3,
- remaining_type(p1.type, p2.type),
- )
+function preserve_momentum(p1::ABCParticle, p2::ABCParticle)
+ t3 = interaction_result(typeof(p1), typeof(p2))
+ p3 = t3(p1.momentum + p2.momentum)
return p3
end
+
+"""
+ type_from_name(name::String)
+
+For the name of a particle, return the corresponding particle `Type` of the ABC model.
+"""
+function type_from_name(name::String)
+ if startswith(name, "A")
+ return ParticleA
+ elseif startswith(name, "B")
+ return ParticleB
+ elseif startswith(name, "C")
+ return ParticleC
+ else
+ throw("Invalid name for a particle in the ABC model")
+ end
+end
+
+function String(::Type{ParticleA})
+ return "A"
+end
+function String(::Type{ParticleB})
+ return "B"
+end
+function String(::Type{ParticleC})
+ return "C"
+end
+
+function in_particles(process::ABCProcessDescription)
+ return process.inParticles
+end
+
+function in_particles(input::ABCProcessInput)
+ return input.inParticles
+end
+
+function out_particles(process::ABCProcessDescription)
+ return process.outParticles
+end
+
+function out_particles(input::ABCProcessInput)
+ return input.outParticles
+end
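# Usage sketch for the new particle types (momenta invented for illustration):
a = ParticleA(SFourMomentum(1.0, 0.0, 0.0, 0.0))
b = ParticleB(SFourMomentum(1.0, 0.0, 0.0, 0.0))

interaction_result(ParticleA, ParticleB)    # ParticleC, the remaining type of the ABC model
c = preserve_momentum(a, b)                 # a ParticleC carrying the summed four-momentum
mass(ParticleC)                             # 0.0, from PARTICLE_MASSES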
diff --git a/src/models/abc/print.jl b/src/models/abc/print.jl
new file mode 100644
index 0000000..26bfb39
--- /dev/null
+++ b/src/models/abc/print.jl
@@ -0,0 +1,58 @@
+
+"""
+ show(io::IO, process::ABCProcessDescription)
+
+Pretty print an [`ABCProcessDescription`](@ref) (no newlines).
+
+```jldoctest
+julia> using MetagraphOptimization
+
+julia> print(parse_process("AB->ABBB", ABCModel()))
+ABC Process: 'AB->ABBB'
+```
+"""
+function show(io::IO, process::ABCProcessDescription)
+ # types() gives the types in order (ABC) instead of random like keys() would
+ print(io, "ABC Process: \'")
+ for type in types(ABCModel())
+ for _ in 1:get(process.inParticles, type, 0)
+ print(io, String(type))
+ end
+ end
+ print(io, "->")
+ for type in types(ABCModel())
+ for _ in 1:get(process.outParticles, type, 0)
+ print(io, String(type))
+ end
+ end
+ print(io, "'")
+ return nothing
+end
+
+"""
+ show(io::IO, processInput::ABCProcessInput)
+
+Pretty print an [`ABCProcessInput`](@ref) (with newlines).
+"""
+function show(io::IO, processInput::ABCProcessInput)
+ println(io, "Input for $(processInput.process):")
+ println(io, " $(length(processInput.inParticles)) Incoming particles:")
+ for particle in processInput.inParticles
+ println(io, " $particle")
+ end
+ println(io, " $(length(processInput.outParticles)) Outgoing Particles:")
+ for particle in processInput.outParticles
+ println(io, " $particle")
+ end
+ return nothing
+end
+
+"""
+ show(io::IO, particle::T) where {T <: ABCParticle}
+
+Pretty print an [`ABCParticle`](@ref) (no newlines).
+"""
+function show(io::IO, particle::T) where {T <: ABCParticle}
+ print(io, "$(String(typeof(particle))): $(particle.momentum)")
+ return nothing
+end
diff --git a/src/models/abc/properties.jl b/src/models/abc/properties.jl
index 8e08e97..e21df0d 100644
--- a/src/models/abc/properties.jl
+++ b/src/models/abc/properties.jl
@@ -57,42 +57,42 @@ end
Print the S1 task to io.
"""
-show(io::IO, t::ComputeTaskS1) = print("ComputeS1")
+show(io::IO, t::ComputeTaskS1) = print(io, "ComputeS1")
"""
show(io::IO, t::ComputeTaskS2)
Print the S2 task to io.
"""
-show(io::IO, t::ComputeTaskS2) = print("ComputeS2")
+show(io::IO, t::ComputeTaskS2) = print(io, "ComputeS2")
"""
show(io::IO, t::ComputeTaskP)
Print the P task to io.
"""
-show(io::IO, t::ComputeTaskP) = print("ComputeP")
+show(io::IO, t::ComputeTaskP) = print(io, "ComputeP")
"""
show(io::IO, t::ComputeTaskU)
Print the U task to io.
"""
-show(io::IO, t::ComputeTaskU) = print("ComputeU")
+show(io::IO, t::ComputeTaskU) = print(io, "ComputeU")
"""
show(io::IO, t::ComputeTaskV)
Print the V task to io.
"""
-show(io::IO, t::ComputeTaskV) = print("ComputeV")
+show(io::IO, t::ComputeTaskV) = print(io, "ComputeV")
"""
show(io::IO, t::ComputeTaskSum)
Print the sum task to io.
"""
-show(io::IO, t::ComputeTaskSum) = print("ComputeSum")
+show(io::IO, t::ComputeTaskSum) = print(io, "ComputeSum")
"""
copy(t::DataTask)
@@ -147,19 +147,20 @@ children(::ComputeTaskV) = 2
"""
children(::ComputeTaskSum)
-Return the number of children of a ComputeTaskSum, since this is variable and the task doesn't know
-how many children it will sum over, return a wildcard -1.
-
-TODO: this is kind of bad because it means we can't fuse with a sum task
+Return the number of children of a ComputeTaskSum.
"""
-children(::ComputeTaskSum) = -1
+children(t::ComputeTaskSum) = t.children_number
"""
children(t::FusedComputeTask)
-Return the number of children of a FusedComputeTask. It's the sum of the children of both tasks minus one.
+Return the number of children of a FusedComputeTask.
"""
function children(t::FusedComputeTask)
- (T1, T2) = get_types(t)
- return children(T1()) + children(T2()) - 1 # one of the inputs is the output of T1 and thus not a child of the node
+ return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
+end
+
+function add_child!(t::ComputeTaskSum)
+ t.children_number += 1
+ return nothing
end
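# Sketch of the now-mutable sum task: the child count starts at zero and is increased
# once per data edge inserted into the sum node while parsing the DAG.
s = ComputeTaskSum(0)
add_child!(s)
add_child!(s)
children(s)    # 2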
diff --git a/src/models/abc/types.jl b/src/models/abc/types.jl
index a160128..e9e6ee9 100644
--- a/src/models/abc/types.jl
+++ b/src/models/abc/types.jl
@@ -47,19 +47,13 @@ struct ComputeTaskU <: AbstractComputeTask end
Task that sums all its inputs, n children.
"""
-struct ComputeTaskSum <: AbstractComputeTask end
+mutable struct ComputeTaskSum <: AbstractComputeTask
+ children_number::Int
+end
"""
ABC_TASKS
Constant vector of all tasks of the ABC-Model.
"""
-ABC_TASKS = [
- DataTask,
- ComputeTaskS1,
- ComputeTaskS2,
- ComputeTaskP,
- ComputeTaskV,
- ComputeTaskU,
- ComputeTaskSum,
-]
+ABC_TASKS = [DataTask, ComputeTaskS1, ComputeTaskS2, ComputeTaskP, ComputeTaskV, ComputeTaskU, ComputeTaskSum]
diff --git a/src/models/interface.jl b/src/models/interface.jl
new file mode 100644
index 0000000..dfb2c9f
--- /dev/null
+++ b/src/models/interface.jl
@@ -0,0 +1,109 @@
+
+"""
+ AbstractPhysicsModel
+
+Base type for a model, e.g. ABC-Model or QED. This is used to dispatch many functions.
+"""
+abstract type AbstractPhysicsModel end
+
+"""
+ AbstractParticle
+
+Base type for particles belonging to a certain [`AbstractPhysicsModel`](@ref).
+"""
+abstract type AbstractParticle end
+
+"""
+ ParticleValue{ParticleType <: AbstractParticle}
+
+A struct describing a particle during a calculation of a Feynman Diagram, together with the value that's being calculated.
+
+`sizeof(ParticleValue())` = 48 Byte
+"""
+struct ParticleValue{ParticleType <: AbstractParticle}
+ p::ParticleType
+ v::Float64
+end
+
+"""
+ AbstractProcessDescription
+
+Base type for process descriptions. An object of a subtype, belonging to a corresponding [`AbstractPhysicsModel`](@ref), should uniquely identify a process in that model.
+
+See also: [`parse_process`](@ref)
+"""
+abstract type AbstractProcessDescription end
+
+"""
+ AbstractProcessInput
+
+Base type for process inputs. An object of this type contains the input values (e.g. momenta) of the particles in a process.
+
+See also: [`gen_process_input`](@ref)
+"""
+abstract type AbstractProcessInput end
+
+"""
+ mass(t::Type{T}) where {T <: AbstractParticle}
+
+Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the particle's mass at rest.
+"""
+function mass end
+
+"""
+ interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: AbstractParticle, T2 <: AbstractParticle}
+
+Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the result particle type when the two given particles interact.
+"""
+function interaction_result end
+
+"""
+ types(::AbstractPhysicsModel)
+
+Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref), returning a `Vector` of the available particle types in the model.
+"""
+function types end
+
+"""
+ in_particles(::AbstractProcessDescription)
+
+Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
+Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of incoming particles for the process per particle type.
+
+
+ in_particles(::AbstractProcessInput)
+
+Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
+Returns a `<: Vector{AbstractParticle}` object with the values of all incoming particles for the corresponding `ProcessDescription`.
+"""
+function in_particles end
+
+"""
+ out_particles(::AbstractProcessDescription)
+
+Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
+Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of outgoing particles for the process per particle type.
+
+
+ out_particles(::AbstractProcessInput)
+
+Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
+Returns a `<: Vector{AbstractParticle}` object with the values of all outgoing particles for the corresponding `ProcessDescription`.
+"""
+function out_particles end
+
+"""
+ parse_process(::AbstractString, ::AbstractPhysicsModel)
+
+Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref).
+Returns a `ProcessDescription` object.
+"""
+function parse_process end
+
+"""
+ gen_process_input(::AbstractProcessDescription)
+
+Interface function that must be implemented for every specific [`AbstractProcessDescription`](@ref).
+Returns a randomly generated and valid corresponding `ProcessInput`.
+"""
+function gen_process_input end
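# A minimal sketch of what implementing this interface looks like for a hypothetical
# toy model (none of these names exist in the package):
struct ToyModel <: AbstractPhysicsModel end

struct ToyParticle <: AbstractParticle
    momentum::Float64
end

mass(::Type{ToyParticle}) = 1.0
interaction_result(::Type{ToyParticle}, ::Type{ToyParticle}) = ToyParticle
types(::ToyModel) = [ToyParticle]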
diff --git a/src/models/print.jl b/src/models/print.jl
new file mode 100644
index 0000000..00b8489
--- /dev/null
+++ b/src/models/print.jl
@@ -0,0 +1,10 @@
+
+"""
+ show(io::IO, particleValue::ParticleValue)
+
+Pretty print a [`ParticleValue`](@ref), no newlines.
+"""
+function show(io::IO, particleValue::ParticleValue)
+ print(io, "($(particleValue.p), value: $(particleValue.v))")
+ return nothing
+end
diff --git a/src/node/create.jl b/src/node/create.jl
index 1d28169..84a8501 100644
--- a/src/node/create.jl
+++ b/src/node/create.jl
@@ -1,44 +1,20 @@
-DataTaskNode(t::AbstractDataTask, name = "") = DataTaskNode(
- t,
- Vector{Node}(),
- Vector{Node}(),
- UUIDs.uuid1(rng[threadid()]),
- missing,
- missing,
- missing,
- name,
-)
+DataTaskNode(t::AbstractDataTask, name = "") =
+ DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
- t,
- Vector{Node}(),
- Vector{Node}(),
- UUIDs.uuid1(rng[threadid()]),
- missing,
- missing,
- Vector{NodeFusion}(),
+ t, # task
+ Vector{Node}(), # parents
+ Vector{Node}(), # children
+ UUIDs.uuid1(rng[threadid()]), # id
+ missing, # node reduction
+ missing, # node split
+ Vector{NodeFusion}(), # node fusions
+ missing, # device
)
copy(m::Missing) = missing
-copy(n::ComputeTaskNode) = ComputeTaskNode(
- copy(n.task),
- copy(n.parents),
- copy(n.children),
- UUIDs.uuid1(rng[threadid()]),
- copy(n.nodeReduction),
- copy(n.nodeSplit),
- copy(n.nodeFusions),
-)
-copy(n::DataTaskNode) = DataTaskNode(
- copy(n.task),
- copy(n.parents),
- copy(n.children),
- UUIDs.uuid1(rng[threadid()]),
- copy(n.nodeReduction),
- copy(n.nodeSplit),
- copy(n.nodeFusion),
- n.name,
-)
+copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task))
+copy(n::DataTaskNode) = DataTaskNode(copy(n.task), n.name)
"""
make_node(t::AbstractTask)
diff --git a/src/node/print.jl b/src/node/print.jl
index c39c1b5..61200a9 100644
--- a/src/node/print.jl
+++ b/src/node/print.jl
@@ -22,5 +22,6 @@ end
Return the uuid as a string usable as a variable name in code generation.
"""
function to_var_name(id::UUID)
- return replace(string(id), "-" => "_")
+ str = "_" * replace(string(id), "-" => "_")
+ return str
end
diff --git a/src/node/type.jl b/src/node/type.jl
index 06bf308..980962a 100644
--- a/src/node/type.jl
+++ b/src/node/type.jl
@@ -24,13 +24,14 @@ abstract type Operation end
Any node that transfers data and does no computation.
# Fields
-`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
-`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
-`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
-`.id`: The node's id. Improves the speed of comparisons.\\
-`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
-`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
-`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.
+`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
+`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
+`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
+`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
+`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
+`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
+`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
+`.name`: The name of this node. Only used for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.
"""
mutable struct DataTaskNode <: Node
task::AbstractDataTask
@@ -60,16 +61,17 @@ end
"""
ComputeTaskNode <: Node
-Any node that transfers data and does no computation.
+Any node that computes a result from inputs using an [`AbstractComputeTask`](@ref).
# Fields
-`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
-`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
-`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
-`.id`: The node's id. Improves the speed of comparisons.\\
-`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
-`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
-`.nodeFusion`: A vector of this node's [`NodeFusion`](@ref)s. For a ComputeTaskNode there can be any number of these, unlike the DataTaskNodes.
+`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\
+`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
+`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
+`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
+`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
+`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
+`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
+`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
"""
mutable struct ComputeTaskNode <: Node
task::AbstractComputeTask
@@ -82,6 +84,9 @@ mutable struct ComputeTaskNode <: Node
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
nodeFusions::Vector{Operation}
+
+ # the device this node is assigned to execute on
+ device::Union{AbstractDevice, Missing}
end
"""
@@ -95,8 +100,5 @@ The child is the prerequisite node of the parent.
"""
struct Edge
# edge points from child to parent
- edge::Union{
- Tuple{DataTaskNode, ComputeTaskNode},
- Tuple{ComputeTaskNode, DataTaskNode},
- }
+ edge::Union{Tuple{DataTaskNode, ComputeTaskNode}, Tuple{ComputeTaskNode, DataTaskNode}}
end
diff --git a/src/node/validate.jl b/src/node/validate.jl
index a16df17..d7ad4dd 100644
--- a/src/node/validate.jl
+++ b/src/node/validate.jl
@@ -22,12 +22,24 @@ function is_valid_node(graph::DAG, node::Node)
@assert node in child.parents "Node is not a parent of its child!"
end
- if !ismissing(node.nodeReduction)
+ #=if !ismissing(node.nodeReduction)
@assert is_valid(graph, node.nodeReduction)
end
if !ismissing(node.nodeSplit)
@assert is_valid(graph, node.nodeSplit)
+ end=#
+
+ if !(typeof(node.task) <: FusedComputeTask)
+ # the remaining checks are only necessary for fused compute tasks
+ return true
end
+
+ # every child must be in some input of the task
+ for child in node.children
+ str = Symbol(to_var_name(child.id))
+ @assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
+ end
+
return true
end
@@ -41,9 +53,9 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::ComputeTaskNode)
@assert is_valid_node(graph, node)
- for nf in node.nodeFusions
+ #=for nf in node.nodeFusions
@assert is_valid(graph, nf)
- end
+ end=#
return true
end
@@ -57,8 +69,8 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::DataTaskNode)
@assert is_valid_node(graph, node)
- if !ismissing(node.nodeFusion)
+ #=if !ismissing(node.nodeFusion)
@assert is_valid(graph, node.nodeFusion)
- end
+ end=#
return true
end
diff --git a/src/operation/apply.jl b/src/operation/apply.jl
index 2974289..dfe9a0b 100644
--- a/src/operation/apply.jl
+++ b/src/operation/apply.jl
@@ -34,12 +34,7 @@ Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node
Return an [`AppliedNodeFusion`](@ref) object generated from the graph's [`Diff`](@ref).
"""
function apply_operation!(graph::DAG, operation::NodeFusion)
- diff = node_fusion!(
- graph,
- operation.input[1],
- operation.input[2],
- operation.input[3],
- )
+ diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])
graph.properties += GraphProperties(diff)
@@ -124,17 +119,24 @@ function revert_diff!(graph::DAG, diff::Diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
- remove_edge!(graph, edge.edge[1], edge.edge[2], false)
+ remove_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for node in diff.addedNodes
- remove_node!(graph, node, false)
+ remove_node!(graph, node, track = false)
end
for node in diff.removedNodes
- insert_node!(graph, node, false)
+ insert_node!(graph, node, track = false)
end
for edge in diff.removedEdges
- insert_edge!(graph, edge.edge[1], edge.edge[2], false)
+ insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
+ end
+
+ for (node, task) in diff.updatedChildren
+ # node must be fused compute task at this point
+ @assert typeof(node.task) <: FusedComputeTask
+
+ node.task = task
end
graph.properties -= GraphProperties(diff)
@@ -149,21 +151,24 @@ Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference
For details see [`NodeFusion`](@ref).
"""
-function node_fusion!(
- graph::DAG,
- n1::ComputeTaskNode,
- n2::DataTaskNode,
- n3::ComputeTaskNode,
-)
- # @assert is_valid_node_fusion_input(graph, n1, n2, n3)
+function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
+ @assert is_valid_node_fusion_input(graph, n1, n2, n3)
# clear snapshot
get_snapshot_diff(graph)
# save children and parents
- n1_children = children(n1)
- n3_parents = parents(n3)
- n3_children = children(n3)
+ n1Children = children(n1)
+ n3Parents = parents(n3)
+
+ n1Task = copy(n1.task)
+ n3Task = copy(n3.task)
+
+ # assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
+ n1Inputs = Vector{Symbol}()
+ for child in n1Children
+ push!(n1Inputs, Symbol(to_var_name(child.id)))
+ end
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, n1, n2)
@@ -172,29 +177,38 @@ function node_fusion!(
remove_node!(graph, n2)
# get n3's children now so it automatically excludes n2
- n3_children = children(n3)
+ n3Children = children(n3)
+
+ n3Inputs = Vector{Symbol}()
+ for child in n3Children
+ push!(n3Inputs, Symbol(to_var_name(child.id)))
+ end
+
remove_node!(graph, n3)
# create new node with the fused compute task
- new_node =
- ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
- insert_node!(graph, new_node)
+ newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
+ insert_node!(graph, newNode)
- for child in n1_children
+ for child in n1Children
remove_edge!(graph, child, n1)
- insert_edge!(graph, child, new_node)
+ insert_edge!(graph, child, newNode)
end
- for child in n3_children
+ for child in n3Children
remove_edge!(graph, child, n3)
- if !(child in n1_children)
- insert_edge!(graph, child, new_node)
+ if !(child in n1Children)
+ insert_edge!(graph, child, newNode)
end
end
- for parent in n3_parents
+ for parent in n3Parents
remove_edge!(graph, n3, parent)
- insert_edge!(graph, new_node, parent)
+ insert_edge!(graph, newNode, parent)
+
+ # important! update the parent node's child names in case they are fused compute tasks
+ # needed for compute generation so the fused compute task can correctly match inputs to its component tasks
+ update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(newNode.id)))
end
return get_snapshot_diff(graph)
@@ -208,21 +222,26 @@ Reduce the given nodes together into one node, return the applied difference to
For details see [`NodeReduction`](@ref).
"""
function node_reduction!(graph::DAG, nodes::Vector{Node})
- # @assert is_valid_node_reduction_input(graph, nodes)
+ @assert is_valid_node_reduction_input(graph, nodes)
# clear snapshot
get_snapshot_diff(graph)
n1 = nodes[1]
- n1_children = children(n1)
+ n1Children = children(n1)
- n1_parents = Set(n1.parents)
- new_parents = Set{Node}()
+ n1Parents = Set(n1.parents)
+
+ # set of the new parents of n1
+ newParents = Set{Node}()
+
+ # names of the previous children that n1 now replaces per parent
+ newParentsChildNames = Dict{Node, Symbol}()
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
n = nodes[i]
- for child in n1_children
+ for child in n1Children
remove_edge!(graph, child, n)
end
@@ -230,17 +249,23 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
remove_edge!(graph, n, parent)
# collect all parents
- push!(new_parents, parent)
+ push!(newParents, parent)
+ newParentsChildNames[parent] = Symbol(to_var_name(n.id))
end
remove_node!(graph, n)
end
- setdiff!(new_parents, n1_parents)
-
- for parent in new_parents
+ for parent in newParents
# now add parents of all input nodes to n1 without duplicates
- insert_edge!(graph, n1, parent)
+ if !(parent in n1Parents)
+ # don't double insert edges
+ insert_edge!(graph, n1, parent)
+ end
+
+ # this has to be done for all parents, even the ones of n1 because they can be duplicate
+ prevChild = newParentsChildNames[parent]
+ update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
end
return get_snapshot_diff(graph)
@@ -254,30 +279,33 @@ Split the given node into one node per parent, return the applied difference to
For details see [`NodeSplit`](@ref).
"""
function node_split!(graph::DAG, n1::Node)
- # @assert is_valid_node_split_input(graph, n1)
+ @assert is_valid_node_split_input(graph, n1)
# clear snapshot
get_snapshot_diff(graph)
- n1_parents = parents(n1)
- n1_children = children(n1)
+ n1Parents = parents(n1)
+ n1Children = children(n1)
- for parent in n1_parents
+ for parent in n1Parents
remove_edge!(graph, n1, parent)
end
- for child in n1_children
+ for child in n1Children
remove_edge!(graph, child, n1)
end
remove_node!(graph, n1)
- for parent in n1_parents
- n_copy = copy(n1)
- insert_node!(graph, n_copy)
- insert_edge!(graph, n_copy, parent)
+ for parent in n1Parents
+ nCopy = copy(n1)
- for child in n1_children
- insert_edge!(graph, child, n_copy)
+ insert_node!(graph, nCopy)
+ insert_edge!(graph, nCopy, parent)
+
+ for child in n1Children
+ insert_edge!(graph, child, nCopy)
end
+
+ update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(nCopy.id)))
end
return get_snapshot_diff(graph)
diff --git a/src/operation/find.jl b/src/operation/find.jl
index 89acc3a..f6d6218 100644
--- a/src/operation/find.jl
+++ b/src/operation/find.jl
@@ -7,10 +7,7 @@ using Base.Threads
Insert the given node fusion into its input nodes' operation caches. For the compute nodes, locking via the given `locks` is employed to have safe multi-threading. For a large set of nodes, contention on the locks should be very small.
"""
-function insert_operation!(
- nf::NodeFusion,
- locks::Dict{ComputeTaskNode, SpinLock},
-)
+function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
n1 = nf.input[1]
n2 = nf.input[2]
n3 = nf.input[3]
@@ -52,10 +49,7 @@ end
Insert the node reductions into the graph and the nodes' caches. Employs multithreading for speedup.
"""
-function nr_insertion!(
- operations::PossibleOperations,
- nodeReductions::Vector{Vector{NodeReduction}},
-)
+function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
total_len = 0
for vec in nodeReductions
total_len += length(vec)
@@ -83,11 +77,7 @@ end
Insert the node fusions into the graph and the nodes' caches. Employs multithreading for speedup.
"""
-function nf_insertion!(
- graph::DAG,
- operations::PossibleOperations,
- nodeFusions::Vector{Vector{NodeFusion}},
-)
+function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
total_len = 0
for vec in nodeFusions
total_len += length(vec)
@@ -122,10 +112,7 @@ end
Insert the node splits into the graph and the nodes' caches. Employs multithreading for speedup.
"""
-function ns_insertion!(
- operations::PossibleOperations,
- nodeSplits::Vector{Vector{NodeSplit}},
-)
+function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
total_len = 0
for vec in nodeSplits
total_len += length(vec)
@@ -231,16 +218,12 @@ function generate_operations(graph::DAG)
continue
end
- push!(
- generatedFusions[threadid()],
- NodeFusion((child_node, node, parent_node)),
- )
+ push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
end
end
# launch thread for node fusion insertion
- nf_task =
- @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
+ nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
schedule(nf_task)
# find possible node splits
diff --git a/src/operation/utility.jl b/src/operation/utility.jl
index 2c1bae5..b7f874a 100644
--- a/src/operation/utility.jl
+++ b/src/operation/utility.jl
@@ -4,9 +4,7 @@
Return whether `operations` is empty, i.e. all of its fields are empty.
"""
function isempty(operations::PossibleOperations)
- return isempty(operations.nodeFusions) &&
- isempty(operations.nodeReductions) &&
- isempty(operations.nodeSplits)
+ return isempty(operations.nodeFusions) && isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
end
"""
@@ -63,9 +61,7 @@ function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
return false
end
- if length(n2.parents) != 1 ||
- length(n2.children) != 1 ||
- length(n1.parents) != 1
+ if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
return false
end
diff --git a/src/operation/validate.jl b/src/operation/validate.jl
index 5d41e87..0fe3218 100644
--- a/src/operation/validate.jl
+++ b/src/operation/validate.jl
@@ -9,24 +9,12 @@ Assert for a gven node fusion input whether the nodes can be fused. For the requ
Intended for use with `@assert` or `@test`.
"""
-function is_valid_node_fusion_input(
- graph::DAG,
- n1::ComputeTaskNode,
- n2::DataTaskNode,
- n3::ComputeTaskNode,
-)
+function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
- throw(
- AssertionError(
- "[Node Fusion] The given nodes are not part of the given graph",
- ),
- )
+ throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
end
- if !is_child(n1, n2) ||
- !is_child(n2, n3) ||
- !is_parent(n3, n2) ||
- !is_parent(n2, n1)
+ if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
throw(
AssertionError(
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
@@ -35,27 +23,19 @@ function is_valid_node_fusion_input(
end
if length(n2.parents) > 1
- throw(
- AssertionError(
- "[Node Fusion] The given data node has more than one parent",
- ),
- )
+ throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
end
if length(n2.children) > 1
- throw(
- AssertionError(
- "[Node Fusion] The given data node has more than one child",
- ),
- )
+ throw(AssertionError("[Node Fusion] The given data node has more than one child"))
end
if length(n1.parents) > 1
- throw(
- AssertionError(
- "[Node Fusion] The given n1 has more than one parent",
- ),
- )
+ throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
end
+ @assert is_valid(graph, n1)
+ @assert is_valid(graph, n2)
+ @assert is_valid(graph, n3)
+
return true
end
@@ -69,22 +49,21 @@ Intended for use with `@assert` or `@test`.
function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
for n in nodes
if n ∉ graph
- throw(
- AssertionError(
- "[Node Reduction] The given nodes are not part of the given graph",
- ),
- )
+ throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
end
+ @assert is_valid(graph, n)
end
t = typeof(nodes[1].task)
for n in nodes
if typeof(n.task) != t
- throw(
- AssertionError(
- "[Node Reduction] The given nodes are not of the same type",
- ),
- )
+ throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
+ end
+
+ if (typeof(n) <: DataTaskNode)
+ if (n.name != nodes[1].name)
+ throw(AssertionError("[Node Reduction] The given nodes do not have the same name"))
+ end
end
end
@@ -111,11 +90,7 @@ Intended for use with `@assert` or `@test`.
"""
function is_valid_node_split_input(graph::DAG, n1::Node)
if n1 ∉ graph
- throw(
- AssertionError(
- "[Node Split] The given node is not part of the given graph",
- ),
- )
+ throw(AssertionError("[Node Split] The given node is not part of the given graph"))
end
if length(n1.parents) <= 1
@@ -126,6 +101,8 @@ function is_valid_node_split_input(graph::DAG, n1::Node)
)
end
+ @assert is_valid(graph, n1)
+
return true
end
@@ -163,12 +140,7 @@ Assert for a given [`NodeFusion`](@ref) whether it is a valid operation in the g
Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, nf::NodeFusion)
- @assert is_valid_node_fusion_input(
- graph,
- nf.input[1],
- nf.input[2],
- nf.input[3],
- )
+ @assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
return true
end
diff --git a/src/properties/utility.jl b/src/properties/utility.jl
index bf936db..3aa9def 100644
--- a/src/properties/utility.jl
+++ b/src/properties/utility.jl
@@ -11,8 +11,7 @@ function -(prop1::GraphProperties, prop2::GraphProperties)
computeIntensity = if (prop1.data - prop2.data == 0)
0.0
else
- (prop1.computeEffort - prop2.computeEffort) /
- (prop1.data - prop2.data)
+ (prop1.computeEffort - prop2.computeEffort) / (prop1.data - prop2.data)
end,
cost = prop1.cost - prop2.cost,
noNodes = prop1.noNodes - prop2.noNodes,
@@ -33,8 +32,7 @@ function +(prop1::GraphProperties, prop2::GraphProperties)
computeIntensity = if (prop1.data + prop2.data == 0)
0.0
else
- (prop1.computeEffort + prop2.computeEffort) /
- (prop1.data + prop2.data)
+ (prop1.computeEffort + prop2.computeEffort) / (prop1.data + prop2.data)
end,
cost = prop1.cost + prop2.cost,
noNodes = prop1.noNodes + prop2.noNodes,
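
The zero check guards the compute-intensity ratio against division by zero when the data volumes cancel. A minimal, self-contained sketch of the same arithmetic with made-up numbers:

```julia
# compute intensity = compute effort per unit of data, guarded against Δdata == 0
Δcompute = 28.0
Δdata = 62.0
computeIntensity = Δdata == 0 ? 0.0 : Δcompute / Δdata   # ≈ 0.452
```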
diff --git a/src/scheduler/greedy.jl b/src/scheduler/greedy.jl
new file mode 100644
index 0000000..7ab77e9
--- /dev/null
+++ b/src/scheduler/greedy.jl
@@ -0,0 +1,50 @@
+
+"""
+ GreedyScheduler
+
+A greedy implementation of a scheduler, creating a topological ordering of nodes and naively balancing them onto the different devices.
+"""
+struct GreedyScheduler end
+
+function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
+ nodeQueue = PriorityQueue{Node, Int}()
+
+    # priority = number of not-yet-scheduled children; nodes with priority 0 are ready to be scheduled
+ for node in get_entry_nodes(graph)
+ enqueue!(nodeQueue, node => 0)
+ end
+
+ schedule = Vector{Node}()
+ sizehint!(schedule, length(graph.nodes))
+
+ # keep an accumulated cost of things scheduled to this device so far
+ deviceAccCost = PriorityQueue{AbstractDevice, Int}()
+ for device in machine.devices
+ enqueue!(deviceAccCost, device => 0)
+ end
+
+ node = nothing
+ while !isempty(nodeQueue)
+ @assert peek(nodeQueue)[2] == 0
+ node = dequeue!(nodeQueue)
+
+ # assign the device with lowest accumulated cost to the node (if it's a compute node)
+ if (isa(node, ComputeTaskNode))
+ lowestDevice = peek(deviceAccCost)[1]
+ node.device = lowestDevice
+            deviceAccCost[lowestDevice] = deviceAccCost[lowestDevice] + compute_effort(node.task)
+ end
+
+ push!(schedule, node)
+ for parent in node.parents
+ # reduce the priority of all parents by one
+ if (!haskey(nodeQueue, parent))
+ enqueue!(nodeQueue, parent => length(parent.children) - 1)
+ else
+ nodeQueue[parent] = nodeQueue[parent] - 1
+ end
+ end
+ end
+
+ return schedule
+end
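
A hedged usage sketch of the greedy scheduler; `parse_dag`, `ABCModel`, and `get_machine_info` appear elsewhere in this changeset, and the input file path is only an example:

```julia
using MetagraphOptimization

graph = parse_dag(joinpath("input", "AB->AB.txt"), ABCModel())
machine = get_machine_info()

# returns a Vector{Node} in topological order, with every compute node assigned a device
schedule = schedule_dag(GreedyScheduler(), graph, machine)
```

Because the per-device priority queue tracks accumulated compute effort, the greedy scheduler spreads work roughly evenly across devices but ignores data-transfer costs.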
diff --git a/src/scheduler/interface.jl b/src/scheduler/interface.jl
new file mode 100644
index 0000000..8d5f6cd
--- /dev/null
+++ b/src/scheduler/interface.jl
@@ -0,0 +1,18 @@
+
+"""
+ Scheduler
+
+Abstract base type for scheduler implementations. A scheduler assigns each compute node to a device and creates a topological ordering of tasks.
+"""
+abstract type Scheduler end
+
+"""
+ schedule_dag(::Scheduler, ::DAG, ::Machine)
+
+Interface function that must be implemented by implementations of [`Scheduler`](@ref).
+
+The function assigns each [`ComputeTaskNode`](@ref) of the [`DAG`](@ref) to one of the devices in the given [`Machine`](@ref) and returns a `Vector{Node}` representing a topological ordering.
+
+[`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`AbstractDevice`](@ref) of its child to the [`AbstractDevice`](@ref)s of its parents.
+"""
+function schedule_dag end
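
A minimal sketch of the dispatch pattern for a custom scheduler. `MyScheduler` is hypothetical and simply delegates to the greedy implementation; a real implementation must itself produce a topological ordering and set `node.device` on every `ComputeTaskNode`:

```julia
struct MyScheduler <: Scheduler end

function schedule_dag(::MyScheduler, graph::DAG, machine::Machine)
    # placeholder body: fall back to the greedy scheduler
    return schedule_dag(GreedyScheduler(), graph, machine)
end
```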
diff --git a/src/task/compute.jl b/src/task/compute.jl
new file mode 100644
index 0000000..beb4e52
--- /dev/null
+++ b/src/task/compute.jl
@@ -0,0 +1,89 @@
+
+"""
+ compute(t::FusedComputeTask, data)
+
+Compute a [`FusedComputeTask`](@ref). This simply asserts false and should never be called; fused compute tasks generate their expressions directly from their constituent tasks instead.
+"""
+function compute(t::FusedComputeTask, data)
+ @assert false "This is not implemented and should never be called"
+end
+
+"""
+    get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
+
+Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing the output on `outExpr`.
+The elements of `inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the output type of `t`'s second task.
+"""
+function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
+ inExprs1 = Vector()
+ for sym in t.t1_inputs
+ push!(inExprs1, gen_access_expr(device, sym))
+ end
+
+ outExpr1 = gen_access_expr(device, t.t1_output)
+
+ inExprs2 = Vector()
+ for sym in t.t2_inputs
+ push!(inExprs2, gen_access_expr(device, sym))
+ end
+
+ expr1 = get_expression(t.first_task, device, inExprs1, outExpr1)
+ expr2 = get_expression(t.second_task, device, [inExprs2..., outExpr1], outExpr)
+
+ full_expr = Expr(:block, expr1, expr2)
+
+ return full_expr
+end
+
+"""
+ get_expression(node::ComputeTaskNode)
+
+Generate and return code for a given [`ComputeTaskNode`](@ref).
+"""
+function get_expression(node::ComputeTaskNode)
+ @assert length(node.children) <= children(node.task) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
+ @assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
+
+ inExprs = Vector()
+ for id in getfield.(node.children, :id)
+ push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
+ end
+ outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
+
+ return get_expression(node.task, node.device, inExprs, outExpr)
+end
+
+"""
+ get_expression(node::DataTaskNode)
+
+Generate and return code for a given [`DataTaskNode`](@ref).
+"""
+function get_expression(node::DataTaskNode)
+ @assert length(node.children) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
+
+ # TODO: dispatch to device implementations generating the copy commands
+
+ child = node.children[1]
+ inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
+ outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
+ dataTransportExp = Meta.parse("$outExpr = $inExpr")
+
+ return dataTransportExp
+end
+
+"""
+ get_init_expression(node::DataTaskNode, device::AbstractDevice)
+
+Generate and return code for the initial input reading expression for [`DataTaskNode`](@ref)s with 0 children, i.e., entry nodes.
+
+See also: [`get_entry_nodes`](@ref)
+"""
+function get_init_expression(node::DataTaskNode, device::AbstractDevice)
+ @assert isempty(node.children) "Trying to call get_init_expression on a data task node that is not an entry node."
+
+ inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
+ outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
+ dataTransportExp = Meta.parse("$outExpr = $inExpr")
+
+ return dataTransportExp
+end
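
To illustrate how `get_expression(::FusedComputeTask, ...)` chains its two tasks: the first task writes to the intermediate output symbol, which the second task then reads as its last input. The symbols and calls below are hypothetical placeholders, not actual generated code:

```julia
# t1: inputs [:data_a], output :data_c
expr1 = :(data_c = compute(ComputeTaskU(), data_a))
# t2: inputs [:data_b] plus, implicitly, t1's output :data_c
expr2 = :(data_out = compute(ComputeTaskV(), data_b, data_c))

full_expr = Expr(:block, expr1, expr2)
```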
diff --git a/src/task/create.jl b/src/task/create.jl
index 9124984..81dc564 100644
--- a/src/task/create.jl
+++ b/src/task/create.jl
@@ -3,8 +3,7 @@
Fallback implementation of the copy of an abstract data task, throwing an error.
"""
-copy(t::AbstractDataTask) =
- error("Need to implement copying for your data tasks!")
+copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks!")
"""
copy(t::AbstractComputeTask)
@@ -12,3 +11,21 @@ copy(t::AbstractDataTask) =
Return a copy of the given compute task.
"""
copy(t::AbstractComputeTask) = typeof(t)()
+
+"""
+ copy(t::FusedComputeTask)
+
+Return a copy of the given [`FusedComputeTask`](@ref).
+"""
+function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
+ return FusedComputeTask{T1, T2}(
+ copy(t.first_task),
+ copy(t.second_task),
+ copy(t.t1_inputs),
+ t.t1_output,
+ copy(t.t2_inputs),
+ )
+end
+
+FusedComputeTask{T1, T2}(t1_inputs::Vector{Symbol}, t1_output::Symbol, t2_inputs::Vector{Symbol}) where {T1, T2} =
+ FusedComputeTask{T1, T2}(T1(), T2(), t1_inputs, t1_output, t2_inputs)
diff --git a/src/task/print.jl b/src/task/print.jl
index 5909b5f..5dcd9ac 100644
--- a/src/task/print.jl
+++ b/src/task/print.jl
@@ -4,6 +4,5 @@
Print a string representation of the fused compute task to io.
"""
function show(io::IO, t::FusedComputeTask)
- (T1, T2) = get_types(t)
- return print(io, "ComputeFuse(", T1(), ", ", T2(), ")")
+ return print(io, "ComputeFuse($(t.first_task), $(t.second_task))")
end
diff --git a/src/task/properties.jl b/src/task/properties.jl
index 4b1d889..9affe0a 100644
--- a/src/task/properties.jl
+++ b/src/task/properties.jl
@@ -71,8 +71,7 @@ data(t::AbstractComputeTask) = 0
Return the compute effort of a fused compute task.
"""
function compute_effort(t::FusedComputeTask)
- (T1, T2) = collect(typeof(t).parameters)
- return compute_effort(T1()) + compute_effort(T2())
+ return compute_effort(t.first_task) + compute_effort(t.second_task)
end
"""
@@ -81,30 +80,3 @@ end
Return a tuple of the fused compute task's components' types.
"""
get_types(::FusedComputeTask{T1, T2}) where {T1, T2} = (T1, T2)
-
-"""
- get_expression(t::AbstractTask)
-
-Return an expression evaluating the given task on the :dataIn symbol
-"""
-function get_expression(t::AbstractTask)
- return quote
- dataOut = compute($t, dataIn)
- end
-end
-
-"""
- get_expression()
-"""
-function get_expression(
- t::FusedComputeTask,
- inSymbol::Symbol,
- outSymbol::Symbol,
-)
- #TODO
- computeExp = quote
- $outSymbol = compute($t, $inSymbol)
- end
-
- return computeExp
-end
diff --git a/src/task/type.jl b/src/task/type.jl
index 5d5e329..8f9dfe1 100644
--- a/src/task/type.jl
+++ b/src/task/type.jl
@@ -26,5 +26,13 @@ A fused compute task made up of the computation of first `T1` and then `T2`.
Also see: [`get_types`](@ref).
"""
-struct FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <:
- AbstractComputeTask end
+struct FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
+ first_task::T1
+ second_task::T2
+ # the names of the inputs for T1
+ t1_inputs::Vector{Symbol}
+ # output name of T1
+ t1_output::Symbol
+ # t2_inputs doesn't include the output of t1, that's implicit
+ t2_inputs::Vector{Symbol}
+end
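
A hedged construction example using the default outer constructor, assuming the ABC model tasks shipped with the package; the symbol names are placeholders:

```julia
# fuse U (reading :data_a, writing :data_c) with V (reading :data_b and, implicitly, :data_c)
fused = FusedComputeTask(ComputeTaskU(), ComputeTaskV(), [:data_a], :data_c, [:data_b])

compute_effort(fused) == compute_effort(ComputeTaskU()) + compute_effort(ComputeTaskV())   # true
```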
diff --git a/src/utility.jl b/src/utility.jl
index 2d4b39e..3760690 100644
--- a/src/utility.jl
+++ b/src/utility.jl
@@ -87,3 +87,19 @@ Return the memory footprint of the node in Byte. Used in [`mem(graph::DAG)`](@re
function mem(node::Node)
return Base.summarysize(node, exclude = Union{Node, Operation})
end
+
+"""
+ unroll_symbol_vector(vec::Vector{Symbol})
+
+Return the given vector as a single String without quotation marks or brackets.
+"""
+function unroll_symbol_vector(vec::Vector)
+ result = ""
+ for s in vec
+ if (result != "")
+ result *= ", "
+ end
+ result *= "$s"
+ end
+ return result
+end
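
Expected behaviour, as a small example; functionally this matches `join(vec, ", ")`, which could serve as a one-line replacement:

```julia
unroll_symbol_vector([:a, :b, :c])   # "a, b, c"
unroll_symbol_vector(Symbol[])       # ""
```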
diff --git a/test/Project.toml b/test/Project.toml
index 7a21f89..fbcc5de 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,3 +1,4 @@
[deps]
+QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/known_graphs.jl b/test/known_graphs.jl
index de81c12..6afbf9f 100644
--- a/test/known_graphs.jl
+++ b/test/known_graphs.jl
@@ -2,7 +2,7 @@ using Random
function test_known_graph(name::String, n, fusion_test = true)
@testset "Test $name Graph ($n)" begin
- graph = parse_abc(joinpath(@__DIR__, "..", "input", "$name.txt"))
+ graph = parse_dag(joinpath(@__DIR__, "..", "input", "$name.txt"), ABCModel())
props = get_properties(graph)
if (fusion_test)
diff --git a/test/node_reduction.jl b/test/node_reduction.jl
index 47c6255..5063e8e 100644
--- a/test/node_reduction.jl
+++ b/test/node_reduction.jl
@@ -5,51 +5,51 @@ import MetagraphOptimization.make_node
@testset "Unit Tests Node Reduction" begin
graph = MetagraphOptimization.DAG()
- d_exit = insert_node!(graph, make_node(DataTask(10)), false)
+ d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
- s0 = insert_node!(graph, make_node(ComputeTaskS2()), false)
+ s0 = insert_node!(graph, make_node(ComputeTaskS2()), track = false)
- ED = insert_node!(graph, make_node(DataTask(3)), false)
- FD = insert_node!(graph, make_node(DataTask(3)), false)
+ ED = insert_node!(graph, make_node(DataTask(3)), track = false)
+ FD = insert_node!(graph, make_node(DataTask(3)), track = false)
- EC = insert_node!(graph, make_node(ComputeTaskV()), false)
- FC = insert_node!(graph, make_node(ComputeTaskV()), false)
+ EC = insert_node!(graph, make_node(ComputeTaskV()), track = false)
+ FC = insert_node!(graph, make_node(ComputeTaskV()), track = false)
- A1D = insert_node!(graph, make_node(DataTask(4)), false)
- B1D_1 = insert_node!(graph, make_node(DataTask(4)), false)
- B1D_2 = insert_node!(graph, make_node(DataTask(4)), false)
- C1D = insert_node!(graph, make_node(DataTask(4)), false)
+ A1D = insert_node!(graph, make_node(DataTask(4)), track = false)
+ B1D_1 = insert_node!(graph, make_node(DataTask(4)), track = false)
+ B1D_2 = insert_node!(graph, make_node(DataTask(4)), track = false)
+ C1D = insert_node!(graph, make_node(DataTask(4)), track = false)
- A1C = insert_node!(graph, make_node(ComputeTaskU()), false)
- B1C_1 = insert_node!(graph, make_node(ComputeTaskU()), false)
- B1C_2 = insert_node!(graph, make_node(ComputeTaskU()), false)
- C1C = insert_node!(graph, make_node(ComputeTaskU()), false)
+ A1C = insert_node!(graph, make_node(ComputeTaskU()), track = false)
+ B1C_1 = insert_node!(graph, make_node(ComputeTaskU()), track = false)
+ B1C_2 = insert_node!(graph, make_node(ComputeTaskU()), track = false)
+ C1C = insert_node!(graph, make_node(ComputeTaskU()), track = false)
- AD = insert_node!(graph, make_node(DataTask(5)), false)
- BD = insert_node!(graph, make_node(DataTask(5)), false)
- CD = insert_node!(graph, make_node(DataTask(5)), false)
+ AD = insert_node!(graph, make_node(DataTask(5)), track = false)
+ BD = insert_node!(graph, make_node(DataTask(5)), track = false)
+ CD = insert_node!(graph, make_node(DataTask(5)), track = false)
- insert_edge!(graph, s0, d_exit, false)
- insert_edge!(graph, ED, s0, false)
- insert_edge!(graph, FD, s0, false)
- insert_edge!(graph, EC, ED, false)
- insert_edge!(graph, FC, FD, false)
+ insert_edge!(graph, s0, d_exit, track = false)
+ insert_edge!(graph, ED, s0, track = false)
+ insert_edge!(graph, FD, s0, track = false)
+ insert_edge!(graph, EC, ED, track = false)
+ insert_edge!(graph, FC, FD, track = false)
- insert_edge!(graph, A1D, EC, false)
- insert_edge!(graph, B1D_1, EC, false)
+ insert_edge!(graph, A1D, EC, track = false)
+ insert_edge!(graph, B1D_1, EC, track = false)
- insert_edge!(graph, B1D_2, FC, false)
- insert_edge!(graph, C1D, FC, false)
+ insert_edge!(graph, B1D_2, FC, track = false)
+ insert_edge!(graph, C1D, FC, track = false)
- insert_edge!(graph, A1C, A1D, false)
- insert_edge!(graph, B1C_1, B1D_1, false)
- insert_edge!(graph, B1C_2, B1D_2, false)
- insert_edge!(graph, C1C, C1D, false)
+ insert_edge!(graph, A1C, A1D, track = false)
+ insert_edge!(graph, B1C_1, B1D_1, track = false)
+ insert_edge!(graph, B1C_2, B1D_2, track = false)
+ insert_edge!(graph, C1C, C1D, track = false)
- insert_edge!(graph, AD, A1C, false)
- insert_edge!(graph, BD, B1C_1, false)
- insert_edge!(graph, BD, B1C_2, false)
- insert_edge!(graph, CD, C1C, false)
+ insert_edge!(graph, AD, A1C, track = false)
+ insert_edge!(graph, BD, B1C_1, track = false)
+ insert_edge!(graph, BD, B1C_2, track = false)
+ insert_edge!(graph, CD, C1C, track = false)
@test is_valid(graph)
diff --git a/test/unit_tests_execution.jl b/test/unit_tests_execution.jl
index d0c0f68..4877076 100644
--- a/test/unit_tests_execution.jl
+++ b/test/unit_tests_execution.jl
@@ -1,31 +1,177 @@
-import MetagraphOptimization.A
-import MetagraphOptimization.B
-import MetagraphOptimization.ParticleType
+import MetagraphOptimization.ABCParticle
-@testset "Unit Tests Graph" begin
- particles = Dict{ParticleType, Vector{Particle}}(
- (
- A => [
- Particle(0.823648, 0.0, 0.0, 0.823648, A),
- Particle(0.823648, -0.835061, -0.474802, 0.277915, A),
- ]
- ),
- (
- B => [
- Particle(0.823648, 0.0, 0.0, -0.823648, B),
- Particle(0.823648, 0.835061, 0.474802, -0.277915, B),
- ]
- ),
+using QEDbase
+
+include("../examples/profiling_utilities.jl")
+
+@testset "Unit Tests Execution" begin
+ machine = get_machine_info()
+
+ process_2_2 = ABCProcessDescription(
+ Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
+ Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
)
- expected_result = 5.5320567694746876e-5
+ particles_2_2 = ABCProcessInput(
+ process_2_2,
+ ABCParticle[
+ ParticleA(SFourMomentum(0.823648, 0.0, 0.0, 0.823648)),
+ ParticleB(SFourMomentum(0.823648, 0.0, 0.0, -0.823648)),
+ ],
+ ABCParticle[
+ ParticleA(SFourMomentum(0.823648, -0.835061, -0.474802, 0.277915)),
+ ParticleB(SFourMomentum(0.823648, 0.835061, 0.474802, -0.277915)),
+ ],
+ )
+ expected_result = 0.00013916495566048735
- for _ in 1:10 # test in a loop because graph layout should not change the result
- graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
- @test isapprox(execute(graph, particles), expected_result; rtol = 0.001)
+ @testset "AB->AB no optimization" begin
+ for _ in 1:10 # test in a loop because graph layout should not change the result
+ graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
+ @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
- code = MetagraphOptimization.gen_code(graph)
- @test isapprox(execute(code, particles), expected_result; rtol = 0.001)
+ # graph should be fully scheduled after being executed
+ @test is_scheduled(graph)
+
+ func = get_compute_function(graph, process_2_2, machine)
+ @test isapprox(func(particles_2_2), expected_result; rtol = 0.001)
+ end
end
+
+ @testset "AB->AB after random walk" begin
+ for i in 1:1000
+ graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
+ random_walk!(graph, 50)
+
+ @test is_valid(graph)
+
+ @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
+
+ # graph should be fully scheduled after being executed
+ @test is_scheduled(graph)
+ end
+ end
+
+ process_2_4 = ABCProcessDescription(
+ Dict{Type, Int64}(ParticleA => 1, ParticleB => 1),
+ Dict{Type, Int64}(ParticleA => 1, ParticleB => 3),
+ )
+ particles_2_4 = gen_process_input(process_2_4)
+ graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
+ expected_result = execute(graph, process_2_4, machine, particles_2_4)
+
+ @testset "AB->ABBB no optimization" begin
+ for _ in 1:5 # test in a loop because graph layout should not change the result
+ graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
+ @test isapprox(execute(graph, process_2_4, machine, particles_2_4), expected_result; rtol = 0.001)
+
+ func = get_compute_function(graph, process_2_4, machine)
+ @test isapprox(func(particles_2_4), expected_result; rtol = 0.001)
+ end
+ end
+
+ @testset "AB->ABBB after random walk" begin
+ for i in 1:200
+ graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
+ random_walk!(graph, 100)
+ @test is_valid(graph)
+
+ @test isapprox(execute(graph, process_2_4, machine, particles_2_4), expected_result; rtol = 0.001)
+ end
+ end
+
+ @testset "AB->AB large sum fusion" for _ in 1:20
+ graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
+
+ # push a fusion with the sum node
+ ops = get_operations(graph)
+ for fusion in ops.nodeFusions
+ if isa(fusion.input[3].task, ComputeTaskSum)
+ push_operation!(graph, fusion)
+ break
+ end
+ end
+
+        # push more fusions with the fused node
+ for _ in 1:15
+ ops = get_operations(graph)
+ for fusion in ops.nodeFusions
+ if isa(fusion.input[3].task, FusedComputeTask)
+ push_operation!(graph, fusion)
+ break
+ end
+ end
+ end
+
+ # try execute
+ @test is_valid(graph)
+ expected_result = 0.00013916495566048735
+ @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
+ end
+
+ @testset "AB->AB fusion edge case" for _ in 1:20
+ graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
+
+ # push two fusions with ComputeTaskV
+ for _ in 1:2
+ ops = get_operations(graph)
+ for fusion in ops.nodeFusions
+ if isa(fusion.input[1].task, ComputeTaskV)
+ push_operation!(graph, fusion)
+ break
+ end
+ end
+ end
+
+ # push fusions until the end
+ cont = true
+ while cont
+ cont = false
+ ops = get_operations(graph)
+ for fusion in ops.nodeFusions
+ if isa(fusion.input[1].task, FusedComputeTask)
+ push_operation!(graph, fusion)
+ cont = true
+ break
+ end
+ end
+ end
+
+ # try execute
+ @test is_valid(graph)
+ expected_result = 0.00013916495566048735
+ @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
+ end
+
end
println("Execution Unit Tests Complete!")
diff --git a/test/unit_tests_graph.jl b/test/unit_tests_graph.jl
index 9c835e6..bab3155 100644
--- a/test/unit_tests_graph.jl
+++ b/test/unit_tests_graph.jl
@@ -11,104 +11,101 @@ import MetagraphOptimization.partners
@test length(graph.appliedOperations) == 0
@test length(graph.operationsToApply) == 0
@test length(graph.dirtyNodes) == 0
- @test length(graph.diff) ==
- (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
- @test length(get_operations(graph)) ==
- (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
+ @test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
+ @test length(get_operations(graph)) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
# s to output (exit node)
- d_exit = insert_node!(graph, make_node(DataTask(10)), false)
+ d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
@test length(graph.nodes) == 1
@test length(graph.dirtyNodes) == 1
# final s compute
- s0 = insert_node!(graph, make_node(ComputeTaskS2()), false)
+ s0 = insert_node!(graph, make_node(ComputeTaskS2()), track = false)
@test length(graph.nodes) == 2
@test length(graph.dirtyNodes) == 2
# data from v0 and v1 to s0
- d_v0_s0 = insert_node!(graph, make_node(DataTask(5)), false)
- d_v1_s0 = insert_node!(graph, make_node(DataTask(5)), false)
+ d_v0_s0 = insert_node!(graph, make_node(DataTask(5)), track = false)
+ d_v1_s0 = insert_node!(graph, make_node(DataTask(5)), track = false)
# v0 and v1 compute
- v0 = insert_node!(graph, make_node(ComputeTaskV()), false)
- v1 = insert_node!(graph, make_node(ComputeTaskV()), false)
+ v0 = insert_node!(graph, make_node(ComputeTaskV()), track = false)
+ v1 = insert_node!(graph, make_node(ComputeTaskV()), track = false)
# data from uB, uA, uBp and uAp to v0 and v1
- d_uB_v0 = insert_node!(graph, make_node(DataTask(3)), false)
- d_uA_v0 = insert_node!(graph, make_node(DataTask(3)), false)
- d_uBp_v1 = insert_node!(graph, make_node(DataTask(3)), false)
- d_uAp_v1 = insert_node!(graph, make_node(DataTask(3)), false)
+ d_uB_v0 = insert_node!(graph, make_node(DataTask(3)), track = false)
+ d_uA_v0 = insert_node!(graph, make_node(DataTask(3)), track = false)
+ d_uBp_v1 = insert_node!(graph, make_node(DataTask(3)), track = false)
+ d_uAp_v1 = insert_node!(graph, make_node(DataTask(3)), track = false)
# uB, uA, uBp and uAp computes
- uB = insert_node!(graph, make_node(ComputeTaskU()), false)
- uA = insert_node!(graph, make_node(ComputeTaskU()), false)
- uBp = insert_node!(graph, make_node(ComputeTaskU()), false)
- uAp = insert_node!(graph, make_node(ComputeTaskU()), false)
+ uB = insert_node!(graph, make_node(ComputeTaskU()), track = false)
+ uA = insert_node!(graph, make_node(ComputeTaskU()), track = false)
+ uBp = insert_node!(graph, make_node(ComputeTaskU()), track = false)
+ uAp = insert_node!(graph, make_node(ComputeTaskU()), track = false)
# data from PB, PA, PBp and PAp to uB, uA, uBp and uAp
- d_PB_uB = insert_node!(graph, make_node(DataTask(6)), false)
- d_PA_uA = insert_node!(graph, make_node(DataTask(6)), false)
- d_PBp_uBp = insert_node!(graph, make_node(DataTask(6)), false)
- d_PAp_uAp = insert_node!(graph, make_node(DataTask(6)), false)
+ d_PB_uB = insert_node!(graph, make_node(DataTask(6)), track = false)
+ d_PA_uA = insert_node!(graph, make_node(DataTask(6)), track = false)
+ d_PBp_uBp = insert_node!(graph, make_node(DataTask(6)), track = false)
+ d_PAp_uAp = insert_node!(graph, make_node(DataTask(6)), track = false)
# P computes PB, PA, PBp and PAp
- PB = insert_node!(graph, make_node(ComputeTaskP()), false)
- PA = insert_node!(graph, make_node(ComputeTaskP()), false)
- PBp = insert_node!(graph, make_node(ComputeTaskP()), false)
- PAp = insert_node!(graph, make_node(ComputeTaskP()), false)
+ PB = insert_node!(graph, make_node(ComputeTaskP()), track = false)
+ PA = insert_node!(graph, make_node(ComputeTaskP()), track = false)
+ PBp = insert_node!(graph, make_node(ComputeTaskP()), track = false)
+ PAp = insert_node!(graph, make_node(ComputeTaskP()), track = false)
# entry nodes getting data for P computes
- d_PB = insert_node!(graph, make_node(DataTask(4)), false)
- d_PA = insert_node!(graph, make_node(DataTask(4)), false)
- d_PBp = insert_node!(graph, make_node(DataTask(4)), false)
- d_PAp = insert_node!(graph, make_node(DataTask(4)), false)
+ d_PB = insert_node!(graph, make_node(DataTask(4)), track = false)
+ d_PA = insert_node!(graph, make_node(DataTask(4)), track = false)
+ d_PBp = insert_node!(graph, make_node(DataTask(4)), track = false)
+ d_PAp = insert_node!(graph, make_node(DataTask(4)), track = false)
@test length(graph.nodes) == 26
@test length(graph.dirtyNodes) == 26
# now for all the edges
- insert_edge!(graph, d_PB, PB, false)
- insert_edge!(graph, d_PA, PA, false)
- insert_edge!(graph, d_PBp, PBp, false)
- insert_edge!(graph, d_PAp, PAp, false)
+ insert_edge!(graph, d_PB, PB, track = false)
+ insert_edge!(graph, d_PA, PA, track = false)
+ insert_edge!(graph, d_PBp, PBp, track = false)
+ insert_edge!(graph, d_PAp, PAp, track = false)
- insert_edge!(graph, PB, d_PB_uB, false)
- insert_edge!(graph, PA, d_PA_uA, false)
- insert_edge!(graph, PBp, d_PBp_uBp, false)
- insert_edge!(graph, PAp, d_PAp_uAp, false)
+ insert_edge!(graph, PB, d_PB_uB, track = false)
+ insert_edge!(graph, PA, d_PA_uA, track = false)
+ insert_edge!(graph, PBp, d_PBp_uBp, track = false)
+ insert_edge!(graph, PAp, d_PAp_uAp, track = false)
- insert_edge!(graph, d_PB_uB, uB, false)
- insert_edge!(graph, d_PA_uA, uA, false)
- insert_edge!(graph, d_PBp_uBp, uBp, false)
- insert_edge!(graph, d_PAp_uAp, uAp, false)
+ insert_edge!(graph, d_PB_uB, uB, track = false)
+ insert_edge!(graph, d_PA_uA, uA, track = false)
+ insert_edge!(graph, d_PBp_uBp, uBp, track = false)
+ insert_edge!(graph, d_PAp_uAp, uAp, track = false)
- insert_edge!(graph, uB, d_uB_v0, false)
- insert_edge!(graph, uA, d_uA_v0, false)
- insert_edge!(graph, uBp, d_uBp_v1, false)
- insert_edge!(graph, uAp, d_uAp_v1, false)
+ insert_edge!(graph, uB, d_uB_v0, track = false)
+ insert_edge!(graph, uA, d_uA_v0, track = false)
+ insert_edge!(graph, uBp, d_uBp_v1, track = false)
+ insert_edge!(graph, uAp, d_uAp_v1, track = false)
- insert_edge!(graph, d_uB_v0, v0, false)
- insert_edge!(graph, d_uA_v0, v0, false)
- insert_edge!(graph, d_uBp_v1, v1, false)
- insert_edge!(graph, d_uAp_v1, v1, false)
+ insert_edge!(graph, d_uB_v0, v0, track = false)
+ insert_edge!(graph, d_uA_v0, v0, track = false)
+ insert_edge!(graph, d_uBp_v1, v1, track = false)
+ insert_edge!(graph, d_uAp_v1, v1, track = false)
- insert_edge!(graph, v0, d_v0_s0, false)
- insert_edge!(graph, v1, d_v1_s0, false)
+ insert_edge!(graph, v0, d_v0_s0, track = false)
+ insert_edge!(graph, v1, d_v1_s0, track = false)
- insert_edge!(graph, d_v0_s0, s0, false)
- insert_edge!(graph, d_v1_s0, s0, false)
+ insert_edge!(graph, d_v0_s0, s0, track = false)
+ insert_edge!(graph, d_v1_s0, s0, track = false)
- insert_edge!(graph, s0, d_exit, false)
+ insert_edge!(graph, s0, d_exit, track = false)
@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0
@test length(graph.operationsToApply) == 0
@test length(graph.dirtyNodes) == 26
- @test length(graph.diff) ==
- (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
+ @test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
@test is_valid(graph)
@@ -135,8 +132,7 @@ import MetagraphOptimization.partners
@test length(siblings(s0)) == 1
operations = get_operations(graph)
- @test length(operations) ==
- (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
+ @test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
@test length(graph.dirtyNodes) == 0
@test operations == get_operations(graph)
@@ -157,8 +153,7 @@ import MetagraphOptimization.partners
@test length(graph.operationsToApply) == 1
@test first(graph.operationsToApply) == nf
@test length(graph.dirtyNodes) == 0
- @test length(graph.diff) ==
- (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
+ @test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
# this applies pending operations
properties = get_properties(graph)
@@ -176,8 +171,7 @@ import MetagraphOptimization.partners
operations = get_operations(graph)
@test length(graph.dirtyNodes) == 0
- @test length(operations) ==
- (nodeFusions = 9, nodeReductions = 0, nodeSplits = 0)
+ @test length(operations) == (nodeFusions = 9, nodeReductions = 0, nodeSplits = 0)
@test !isempty(operations)
possibleNF = 9
@@ -185,14 +179,12 @@ import MetagraphOptimization.partners
push_operation!(graph, first(operations.nodeFusions))
operations = get_operations(graph)
possibleNF = possibleNF - 1
- @test length(operations) ==
- (nodeFusions = possibleNF, nodeReductions = 0, nodeSplits = 0)
+ @test length(operations) == (nodeFusions = possibleNF, nodeReductions = 0, nodeSplits = 0)
end
@test isempty(operations)
- @test length(operations) ==
- (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
+ @test length(operations) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
@test length(graph.dirtyNodes) == 0
@test length(graph.nodes) == 6
@test length(graph.appliedOperations) == 10
@@ -213,8 +205,7 @@ import MetagraphOptimization.partners
@test properties.computeIntensity ≈ 28 / 62
operations = get_operations(graph)
- @test length(operations) ==
- (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
+ @test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
@test is_valid(graph)
end
diff --git a/test/unit_tests_nodes.jl b/test/unit_tests_nodes.jl
index 74be0e8..7a274d0 100644
--- a/test/unit_tests_nodes.jl
+++ b/test/unit_tests_nodes.jl
@@ -3,8 +3,7 @@
nC1 = MetagraphOptimization.make_node(MetagraphOptimization.ComputeTaskU())
nC2 = MetagraphOptimization.make_node(MetagraphOptimization.ComputeTaskV())
nC3 = MetagraphOptimization.make_node(MetagraphOptimization.ComputeTaskP())
- nC4 =
- MetagraphOptimization.make_node(MetagraphOptimization.ComputeTaskSum())
+ nC4 = MetagraphOptimization.make_node(MetagraphOptimization.ComputeTaskSum())
nD1 = MetagraphOptimization.make_node(MetagraphOptimization.DataTask(10))
nD2 = MetagraphOptimization.make_node(MetagraphOptimization.DataTask(20))
diff --git a/test/unit_tests_utility.jl b/test/unit_tests_utility.jl
index 169023e..db04d80 100644
--- a/test/unit_tests_utility.jl
+++ b/test/unit_tests_utility.jl
@@ -5,9 +5,7 @@
@test MetagraphOptimization.bytes_to_human_readable(1025) == "1.001 KiB"
@test MetagraphOptimization.bytes_to_human_readable(684235) == "668.2 KiB"
@test MetagraphOptimization.bytes_to_human_readable(86214576) == "82.22 MiB"
- @test MetagraphOptimization.bytes_to_human_readable(9241457698) ==
- "8.607 GiB"
- @test MetagraphOptimization.bytes_to_human_readable(3218598654367) ==
- "2.927 TiB"
+ @test MetagraphOptimization.bytes_to_human_readable(9241457698) == "8.607 GiB"
+ @test MetagraphOptimization.bytes_to_human_readable(3218598654367) == "2.927 TiB"
end
println("Utility Unit Tests Complete!")