Compare commits
43 Commits
performanc
...
scheduling
Author | SHA1 | Date | |
---|---|---|---|
6a09ecf33d | |||
4dcb616606 | |||
9b28601f18 | |||
3267daadfd | |||
140a954d01 | |||
a86901e425 | |||
0f50b59933 | |||
cbfed20b82 | |||
f9e60a7b5e | |||
314330f00f | |||
dd01a5e691 | |||
37d645cb4e | |||
afb6af44ca | |||
bef017130b | |||
7dd9fedf2e | |||
a69dd6018e | |||
4b44eb5286 | |||
24ade323f0 | |||
95f92f080c | |||
cc05cae1cd | |||
c88898a502 | |||
0d8d824540 | |||
c428613c80 | |||
f8a591991c | |||
bd6c54c1ae | |||
62791ab422 | |||
4c452dce98 | |||
27c4b8ba34 | |||
e59d24ebe5 | |||
d1666de432 | |||
0f78053ccf | |||
7a1a97dac8 | |||
f1edce258a | |||
32fcd069d7 | |||
e09ab7c77b | |||
7387fa86b1 | |||
065236be22 | |||
8014bbffcd | |||
ae1345d547 | |||
dbcd569967 | |||
0f5f475cb4 | |||
1b4030d633 | |||
383c92ec47 |
13
.JuliaFormatter.toml
Normal file
13
.JuliaFormatter.toml
Normal file
@ -0,0 +1,13 @@
|
||||
indent = 4
|
||||
margin = 120
|
||||
always_for_in = true
|
||||
for_in_replacement = "in"
|
||||
whitespace_typedefs = true
|
||||
whitespace_ops_in_indices = true
|
||||
long_to_short_function_def = false
|
||||
always_use_return = true
|
||||
whitespace_in_kwargs = true
|
||||
conditional_to_if = true
|
||||
normalize_line_endings = "unix"
|
||||
|
||||
overwrite = true
|
4
.gitattributes
vendored
4
.gitattributes
vendored
@ -1,2 +1,2 @@
|
||||
examples/AB->ABBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
|
||||
examples/AB->ABBBBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
|
||||
input/AB->ABBBBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
|
||||
input/AB->ABBBBBBB.txt filter=lfs diff=lfs merge=lfs -text
|
||||
|
@ -1,33 +1,185 @@
|
||||
name: Test
|
||||
name: MetagraphOptimization_CI
|
||||
|
||||
on: [push]
|
||||
|
||||
env:
|
||||
# keep the depot directly in the repository for the cache
|
||||
JULIA_DEPOT_PATH: './.julia'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
prepare:
|
||||
runs-on: arch-latest
|
||||
|
||||
steps:
|
||||
#- name: Get git-lfs
|
||||
# run: apt-get update && apt-get install git-lfs
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
#- name: Checkout LFS objects
|
||||
# run: git lfs checkout
|
||||
|
||||
- name: Setup Julia environment
|
||||
uses: https://github.com/julia-actions/setup-julia@v1.9.1
|
||||
uses: https://github.com/julia-actions/setup-julia@v1.9.2
|
||||
with:
|
||||
version: '1.9.1'
|
||||
version: '1.9.2'
|
||||
|
||||
# needed for the file hashing, should be removed when ${{ hashFiles('**/Project.toml') }} is supported in gitea
|
||||
- name: Setup go environment
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '1.20'
|
||||
|
||||
- name: Hash files
|
||||
uses: https://gitea.com/actions/go-hashfiles@v0.0.1
|
||||
id: get-hash
|
||||
with:
|
||||
patterns: |-
|
||||
**/Project.toml
|
||||
|
||||
- name: Restore Cache
|
||||
uses: actions/cache/restore@v3
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
.julia/artifacts
|
||||
.julia/packages
|
||||
.julia/registries
|
||||
key: julia-${{ steps.get-hash.outputs.hash }}
|
||||
|
||||
- name: Check cache hit
|
||||
if: steps.cache-restore.outputs.cache-hit == 'true'
|
||||
run: exit 0
|
||||
|
||||
- name: Install dependencies
|
||||
run: julia --project=./ -e 'import Pkg; Pkg.instantiate()'
|
||||
run: |
|
||||
julia --project=./ -e 'import Pkg; Pkg.instantiate(); Pkg.precompile()'
|
||||
julia --project=examples/ -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.precompile()'
|
||||
julia --project=docs/ -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.precompile()'
|
||||
|
||||
- name: Cache Julia packages
|
||||
uses: actions/cache/save@v3
|
||||
with:
|
||||
path: |
|
||||
.julia/artifacts
|
||||
.julia/packages
|
||||
.julia/registries
|
||||
key: julia-${{ steps.get-hash.outputs.hash }}
|
||||
|
||||
test:
|
||||
needs: prepare
|
||||
runs-on: arch-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Julia environment
|
||||
uses: https://github.com/julia-actions/setup-julia@v1.9.2
|
||||
with:
|
||||
version: '1.9.2'
|
||||
|
||||
# needed for the file hashing, should be removed when ${{ hashFiles('**/Project.toml') }} is supported in gitea
|
||||
- name: Setup go environment
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '1.20'
|
||||
|
||||
- name: Hash files
|
||||
uses: https://gitea.com/actions/go-hashfiles@v0.0.1
|
||||
id: get-hash
|
||||
with:
|
||||
patterns: |-
|
||||
**/Project.toml
|
||||
|
||||
- name: Restore cached Julia packages
|
||||
uses: actions/cache/restore@v3
|
||||
with:
|
||||
path: |
|
||||
.julia/artifacts
|
||||
.julia/packages
|
||||
.julia/registries
|
||||
key: julia-${{ steps.get-hash.outputs.hash }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
julia --project=./ -e 'import Pkg; Pkg.instantiate(); Pkg.precompile()'
|
||||
julia --project=examples/ -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.precompile()'
|
||||
julia --project=docs/ -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.precompile()'
|
||||
|
||||
- name: Format check
|
||||
run: |
|
||||
julia --project=./ -e 'using JuliaFormatter; format(".", verbose=true, ignore=[".julia/*"])'
|
||||
julia --project=./ -e '
|
||||
out = Cmd(`git diff --name-only`) |> read |> String
|
||||
if out == ""
|
||||
exit(0)
|
||||
else
|
||||
@error "Some files have not been formatted !!!"
|
||||
write(stdout, out)
|
||||
exit(1)
|
||||
end'
|
||||
|
||||
- name: Run tests
|
||||
run: julia --project=./ -t 4 -e 'import Pkg; Pkg.test()' -O0
|
||||
|
||||
- name: Run examples
|
||||
run: julia --project=examples/ -t 4 -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); include("examples/import_bench.jl")' -O3
|
||||
run: julia --project=examples/ -t 4 -e 'include("examples/import_bench.jl")' -O3
|
||||
|
||||
docs:
|
||||
needs: prepare
|
||||
runs-on: arch-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Julia environment
|
||||
uses: https://github.com/julia-actions/setup-julia@v1.9.2
|
||||
with:
|
||||
version: '1.9.2'
|
||||
|
||||
# needed for the file hashing, should be removed when ${{ hashFiles('**/Project.toml') }} is supported in gitea
|
||||
- name: Setup go environment
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '1.20'
|
||||
|
||||
- name: Hash files
|
||||
uses: https://gitea.com/actions/go-hashfiles@v0.0.1
|
||||
id: get-hash
|
||||
with:
|
||||
patterns: |-
|
||||
**/Project.toml
|
||||
|
||||
- name: Restore cached Julia packages
|
||||
uses: actions/cache/restore@v3
|
||||
with:
|
||||
path: |
|
||||
.julia/artifacts
|
||||
.julia/packages
|
||||
.julia/registries
|
||||
key: julia-${{ steps.get-hash.outputs.hash }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
julia --project=./ -e 'import Pkg; Pkg.instantiate(); Pkg.precompile()'
|
||||
julia --project=examples/ -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.precompile()'
|
||||
julia --project=docs/ -e 'import Pkg; Pkg.develop(Pkg.PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.precompile()'
|
||||
|
||||
- name: Build docs
|
||||
run: julia --project=docs/ docs/make.jl
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: web-doc
|
||||
path: docs/build/
|
||||
|
||||
#- name: Webhook Trigger
|
||||
# uses: https://github.com/zzzze/webhook-trigger@master
|
||||
# continue-on-error: true
|
||||
# with:
|
||||
# data: "{\"event\":\"action_completed\", \"download_url\":\"deckardcain.local:8099/something\"}"
|
||||
# webhook_url: ${{ secrets.WEBHOOK_URL }}
|
||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -26,3 +26,5 @@ Manifest.toml
|
||||
|
||||
# vscode workspace directory
|
||||
.vscode
|
||||
.julia
|
||||
**/.ipynb_checkpoints/
|
||||
|
@ -4,9 +4,16 @@ authors = ["Anton Reinhard <anton.reinhard@proton.me>"]
|
||||
version = "0.1.0"
|
||||
|
||||
[deps]
|
||||
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
|
||||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
|
||||
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
|
||||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
|
||||
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
|
||||
NumaAllocators = "21436f30-1b4a-4f08-87af-e26101bb5379"
|
||||
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
|
||||
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
||||
[extras]
|
||||
|
@ -42,7 +42,7 @@ Problems:
|
||||
- Lots of testing required because mistakes will propagate and multiply.
|
||||
|
||||
## Other TODOs
|
||||
- Reduce memory footprint of the graph, are the UUIDs too large?
|
||||
- Reduce memory footprint of the graph
|
||||
- Memory layout of Nodes? They should lie linearly in memory, right now probably on heap?
|
||||
- Add scaling functions
|
||||
|
||||
@ -53,7 +53,7 @@ For graphs AB->AB^n:
|
||||
- Number of ComputeTaskS2 should always be (n+1)!
|
||||
- Number of ComputeTaskU should always be (n+3)
|
||||
|
||||
Times are from my home machine: AMD Ryzen 7900X3D, 64GB DDR5 RAM @ 6000MHz
|
||||
Times are from my home machine: AMD Ryzen 7900X3D, 64GB DDR5 RAM @ 6000MHz (not necessarily up to date, check Jupyter Notebooks in `notebooks/` instead)
|
||||
|
||||
```
|
||||
$ julia --project examples/import_bench.jl
|
||||
|
4
docs/Project.toml
Normal file
4
docs/Project.toml
Normal file
@ -0,0 +1,4 @@
|
||||
[deps]
|
||||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
||||
DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8"
|
||||
MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4"
|
34
docs/make.jl
Normal file
34
docs/make.jl
Normal file
@ -0,0 +1,34 @@
|
||||
using Documenter
|
||||
using MetagraphOptimization
|
||||
|
||||
makedocs(
|
||||
#format = Documenter.LaTeX(platform=""),
|
||||
|
||||
root = "docs",
|
||||
source = "src",
|
||||
build = "build",
|
||||
warnonly = true,
|
||||
clean = true,
|
||||
doctest = true,
|
||||
modules = Module[MetagraphOptimization],
|
||||
#repo = "https://code.woubery.com/Rubydragon/MetagraphOptimization.jl/src/branch/{commit}{path}#L{line}",
|
||||
remotes = nothing,
|
||||
sitename = "MetagraphOptimization.jl",
|
||||
pages = [
|
||||
"index.md",
|
||||
"Manual" => "manual.md",
|
||||
"Library" => [
|
||||
"Public" => "lib/public.md",
|
||||
"Graph" => "lib/internals/graph.md",
|
||||
"Node" => "lib/internals/node.md",
|
||||
"Task" => "lib/internals/task.md",
|
||||
"Operation" => "lib/internals/operation.md",
|
||||
"Models" => "lib/internals/models.md",
|
||||
"Diff" => "lib/internals/diff.md",
|
||||
"Utility" => "lib/internals/utility.md",
|
||||
"Code Generation" => "lib/internals/code_gen.md",
|
||||
"Devices" => "lib/internals/devices.md",
|
||||
],
|
||||
"Contribution" => "contribution.md",
|
||||
],
|
||||
)
|
3
docs/src/contribution.md
Normal file
3
docs/src/contribution.md
Normal file
@ -0,0 +1,3 @@
|
||||
# Contribution
|
||||
|
||||
This package is currently being developed as part of a diploma thesis; the repository is therefore private and cannot accept outside contributions.
|
75
docs/src/flowchart.drawio
Normal file
75
docs/src/flowchart.drawio
Normal file
@ -0,0 +1,75 @@
|
||||
<mxfile host="Electron" modified="2023-09-17T13:34:45.840Z" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/21.6.1 Chrome/114.0.5735.134 Electron/25.6.0 Safari/537.36" etag="e0c8qLevhaP_q_R2fyC9" version="21.6.1" type="device">
|
||||
<diagram name="Page-1" id="Vy0cA1nkMPfy-3cC5ahA">
|
||||
<mxGraphModel dx="1185" dy="707" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="500" pageHeight="900" math="0" shadow="0">
|
||||
<root>
|
||||
<mxCell id="0" />
|
||||
<mxCell id="1" parent="0" />
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-5" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;shadow=1;" edge="1" parent="1" source="yG8qeggDCLqQ8GwY7ugi-1" target="yG8qeggDCLqQ8GwY7ugi-2">
|
||||
<mxGeometry relative="1" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-1" value="Process Generator Script" style="rounded=1;whiteSpace=wrap;html=1;shadow=1;" vertex="1" parent="1">
|
||||
<mxGeometry x="180" y="120" width="120" height="60" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-11" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;shadow=1;" edge="1" parent="1" source="yG8qeggDCLqQ8GwY7ugi-2" target="yG8qeggDCLqQ8GwY7ugi-3">
|
||||
<mxGeometry relative="1" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-2" value="Process Parser" style="rounded=1;whiteSpace=wrap;html=1;shadow=1;" vertex="1" parent="1">
|
||||
<mxGeometry x="180" y="220" width="120" height="60" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-8" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=1;exitY=0.25;exitDx=0;exitDy=0;entryX=0;entryY=0.25;entryDx=0;entryDy=0;shadow=1;" edge="1" parent="1" source="yG8qeggDCLqQ8GwY7ugi-3" target="yG8qeggDCLqQ8GwY7ugi-6">
|
||||
<mxGeometry relative="1" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-14" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;shadow=1;" edge="1" parent="1" source="yG8qeggDCLqQ8GwY7ugi-3" target="yG8qeggDCLqQ8GwY7ugi-12">
|
||||
<mxGeometry relative="1" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-3" value="Optimizer" style="rounded=1;whiteSpace=wrap;html=1;shadow=1;" vertex="1" parent="1">
|
||||
<mxGeometry x="180" y="320" width="120" height="60" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-9" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=1;entryY=0.75;entryDx=0;entryDy=0;exitX=0;exitY=0.75;exitDx=0;exitDy=0;shadow=1;" edge="1" parent="1" source="yG8qeggDCLqQ8GwY7ugi-6" target="yG8qeggDCLqQ8GwY7ugi-3">
|
||||
<mxGeometry relative="1" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-6" value="Fast Cost Estimator<br>(Global Metrics)" style="rounded=1;whiteSpace=wrap;html=1;shadow=1;" vertex="1" parent="1">
|
||||
<mxGeometry x="340" y="320" width="120" height="60" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-15" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;shadow=1;" edge="1" parent="1" source="yG8qeggDCLqQ8GwY7ugi-12" target="yG8qeggDCLqQ8GwY7ugi-13">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<Array as="points">
|
||||
<mxPoint x="80" y="450" />
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-12" value="Scheduler" style="rounded=1;whiteSpace=wrap;html=1;shadow=1;" vertex="1" parent="1">
|
||||
<mxGeometry x="180" y="420" width="120" height="60" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-16" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;shadow=1;" edge="1" parent="1" source="yG8qeggDCLqQ8GwY7ugi-13" target="yG8qeggDCLqQ8GwY7ugi-3">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="120" y="380" as="targetPoint" />
|
||||
<Array as="points">
|
||||
<mxPoint x="80" y="350" />
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-19" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;shadow=1;" edge="1" parent="1" source="yG8qeggDCLqQ8GwY7ugi-12" target="yG8qeggDCLqQ8GwY7ugi-18">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<Array as="points">
|
||||
<mxPoint x="240" y="500" />
|
||||
<mxPoint x="240" y="500" />
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-13" value="Accurate Cost Estimator<br>(Machine Specific)" style="rounded=1;whiteSpace=wrap;html=1;shadow=1;" vertex="1" parent="1">
|
||||
<mxGeometry x="20" y="370" width="120" height="60" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-21" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;shadow=1;" edge="1" parent="1" source="yG8qeggDCLqQ8GwY7ugi-18" target="yG8qeggDCLqQ8GwY7ugi-20">
|
||||
<mxGeometry relative="1" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-18" value="Code Generator" style="rounded=1;whiteSpace=wrap;html=1;shadow=1;" vertex="1" parent="1">
|
||||
<mxGeometry x="180" y="520" width="120" height="60" as="geometry" />
|
||||
</mxCell>
|
||||
<mxCell id="yG8qeggDCLqQ8GwY7ugi-20" value="Executor" style="rounded=1;whiteSpace=wrap;html=1;shadow=1;" vertex="1" parent="1">
|
||||
<mxGeometry x="180" y="620" width="120" height="60" as="geometry" />
|
||||
</mxCell>
|
||||
</root>
|
||||
</mxGraphModel>
|
||||
</diagram>
|
||||
</mxfile>
|
26
docs/src/index.md
Normal file
26
docs/src/index.md
Normal file
@ -0,0 +1,26 @@
|
||||
# MetagraphOptimization.jl
|
||||
|
||||
*A domain-specific DAG-optimizer*
|
||||
|
||||
## Package Features
|
||||
- Read a DAG from a file
|
||||
- Analyze its properties
|
||||
- Mutate the graph using the operations NodeFusion, NodeReduction and NodeSplit
|
||||
|
||||
## Coming Soon:
|
||||
- Add Code Generation from finished DAG
|
||||
- Add optimization algorithms and strategies
|
||||
|
||||
## Library Outline
|
||||
|
||||
```@contents
|
||||
Pages = [
|
||||
"lib/public.md",
|
||||
"lib/internals.md"
|
||||
]
|
||||
```
|
||||
|
||||
### [Index](@id main-index)
|
||||
```@index
|
||||
Pages = ["lib/public.md"]
|
||||
```
|
8
docs/src/lib/internals/code_gen.md
Normal file
8
docs/src/lib/internals/code_gen.md
Normal file
@ -0,0 +1,8 @@
|
||||
# Code Generation
|
||||
|
||||
## Main
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["code_gen/main.jl"]
|
||||
Order = [:function]
|
||||
```
|
59
docs/src/lib/internals/devices.md
Normal file
59
docs/src/lib/internals/devices.md
Normal file
@ -0,0 +1,59 @@
|
||||
# Devices
|
||||
|
||||
## Interface
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["devices/interface.jl"]
|
||||
Order = [:type, :constant, :function]
|
||||
```
|
||||
|
||||
## Detect
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["devices/detect.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Measure
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["devices/measure.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Implementations
|
||||
|
||||
### General
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["devices/impl.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
||||
|
||||
### NUMA
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["devices/numa/impl.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
||||
|
||||
### CUDA
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["devices/cuda/impl.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
||||
|
||||
### ROCm
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["devices/rocm/impl.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
||||
|
||||
### oneAPI
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["devices/oneapi/impl.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
22
docs/src/lib/internals/diff.md
Normal file
22
docs/src/lib/internals/diff.md
Normal file
@ -0,0 +1,22 @@
|
||||
# Diff
|
||||
|
||||
## Type
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["diff/type.jl"]
|
||||
Order = [:type]
|
||||
```
|
||||
|
||||
## Properties
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["diff/properties.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Printing
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["diff/print.jl"]
|
||||
Order = [:function]
|
||||
```
|
50
docs/src/lib/internals/graph.md
Normal file
50
docs/src/lib/internals/graph.md
Normal file
@ -0,0 +1,50 @@
|
||||
# Graph
|
||||
|
||||
## Type
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["graph/type.jl"]
|
||||
Order = [:type]
|
||||
```
|
||||
|
||||
## Interface
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["graph/interface.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Compare
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["graph/compare.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Mute
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["graph/mute.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Print
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["graph/print.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Properties
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["graph/properties.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Validate
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["graph/validate.jl"]
|
||||
Order = [:function]
|
||||
```
|
72
docs/src/lib/internals/models.md
Normal file
72
docs/src/lib/internals/models.md
Normal file
@ -0,0 +1,72 @@
|
||||
# Models
|
||||
|
||||
## Interface
|
||||
|
||||
The interface that has to be implemented for a model to be usable is defined in `src/models/interface.jl`.
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/interface.jl"]
|
||||
Order = [:type, :constant, :function]
|
||||
```
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/print.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## ABC-Model
|
||||
|
||||
### Types
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/types.jl"]
|
||||
Order = [:type, :constant]
|
||||
```
|
||||
|
||||
### Particle
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/particle.jl"]
|
||||
Order = [:type, :constant, :function]
|
||||
```
|
||||
|
||||
### Parse
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/parse.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
### Properties
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/properties.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
### Create
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/create.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
### Compute
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/compute.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
### Print
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["models/abc/print.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## QED-Model
|
||||
|
||||
*To be added*
|
43
docs/src/lib/internals/node.md
Normal file
43
docs/src/lib/internals/node.md
Normal file
@ -0,0 +1,43 @@
|
||||
# Node
|
||||
|
||||
## Type
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["node/type.jl"]
|
||||
Order = [:type]
|
||||
```
|
||||
|
||||
## Create
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["node/create.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Compare
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["node/compare.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Properties
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["node/properties.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Print
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["node/print.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Validate
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["node/validate.jl"]
|
||||
Order = [:function]
|
||||
```
|
57
docs/src/lib/internals/operation.md
Normal file
57
docs/src/lib/internals/operation.md
Normal file
@ -0,0 +1,57 @@
|
||||
# Operation
|
||||
|
||||
## Types
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["operation/type.jl"]
|
||||
Order = [:type]
|
||||
```
|
||||
|
||||
## Find
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["operation/find.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Apply
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["operation/apply.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Get
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["operation/get.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Clean
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["operation/clean.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Utility
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["operation/utility.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Print
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["operation/print.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Validate
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["operation/validate.jl"]
|
||||
Order = [:function]
|
||||
```
|
22
docs/src/lib/internals/properties.md
Normal file
22
docs/src/lib/internals/properties.md
Normal file
@ -0,0 +1,22 @@
|
||||
# Properties
|
||||
|
||||
## Type
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["properties/type.jl"]
|
||||
Order = [:type]
|
||||
```
|
||||
|
||||
## Create
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["properties/create.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Utility
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["properties/utility.jl"]
|
||||
Order = [:function]
|
||||
```
|
15
docs/src/lib/internals/scheduler.md
Normal file
15
docs/src/lib/internals/scheduler.md
Normal file
@ -0,0 +1,15 @@
|
||||
# Scheduler
|
||||
|
||||
## Interface
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["scheduler/interface.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
||||
|
||||
## Greedy
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["scheduler/greedy.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
43
docs/src/lib/internals/task.md
Normal file
43
docs/src/lib/internals/task.md
Normal file
@ -0,0 +1,43 @@
|
||||
# Task
|
||||
|
||||
## Type
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["task/type.jl"]
|
||||
Order = [:type]
|
||||
```
|
||||
|
||||
## Create
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["task/create.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Compare
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["task/compare.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Compute
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["task/compute.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Properties
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["task/properties.jl"]
|
||||
Order = [:function]
|
||||
```
|
||||
|
||||
## Print
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["task/print.jl"]
|
||||
Order = [:function]
|
||||
```
|
17
docs/src/lib/internals/utility.md
Normal file
17
docs/src/lib/internals/utility.md
Normal file
@ -0,0 +1,17 @@
|
||||
# Utility
|
||||
|
||||
## Helper Functions
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["./utility.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
||||
|
||||
## Trie Helper
|
||||
This is a simple implementation of a [Trie Data Structure](https://en.wikipedia.org/wiki/Trie) to greatly improve the performance of the Node Reduction search.
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["trie.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
24
docs/src/lib/public.md
Normal file
24
docs/src/lib/public.md
Normal file
@ -0,0 +1,24 @@
|
||||
# Public Documentation
|
||||
|
||||
Documentation for `MetagraphOptimization.jl`'s public interface.
|
||||
|
||||
See the internals pages of the Library section for documentation of everything else.
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["MetagraphOptimization.jl"]
|
||||
Order = [:module]
|
||||
```
|
||||
|
||||
## Contents
|
||||
|
||||
```@contents
|
||||
Pages = ["public.md"]
|
||||
Depth = 2
|
||||
```
|
||||
|
||||
## Index
|
||||
|
||||
```@index
|
||||
Pages = ["public.md"]
|
||||
```
|
7
docs/src/manual.md
Normal file
7
docs/src/manual.md
Normal file
@ -0,0 +1,7 @@
|
||||
# Manual
|
||||
|
||||
## Jupyter Notebooks
|
||||
|
||||
The `notebooks` directory contains Jupyter notebooks with examples of how to use this package.
|
||||
|
||||
- `abc_model_showcase`: A simple showcase of the intended usage of the ABC Model implementation.
|
@ -1,7 +1,3 @@
|
||||
[deps]
|
||||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
|
||||
MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4"
|
||||
PProf = "e4faabce-9ead-11e9-39d9-4379958e3056"
|
||||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
||||
ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"
|
||||
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
|
||||
|
@ -13,18 +13,20 @@ function bench_txt(filepath::String, bench::Bool = true)
|
||||
return
|
||||
end
|
||||
|
||||
model = ABCModel()
|
||||
|
||||
println(name, ":")
|
||||
g = parse_abc(filepath)
|
||||
g = parse_dag(filepath, model)
|
||||
print(g)
|
||||
#println(" Graph size in memory: ", bytes_to_human_readable(Base.summarysize(g)))
|
||||
println(" Graph size in memory: ", bytes_to_human_readable(MetagraphOptimization.mem(g)))
|
||||
|
||||
if (bench)
|
||||
@btime parse_abc($filepath)
|
||||
@btime parse_dag($filepath, $model)
|
||||
end
|
||||
|
||||
println(" Get Operations: ")
|
||||
@time get_operations(g)
|
||||
println()
|
||||
return println()
|
||||
end
|
||||
|
||||
function import_bench()
|
||||
@ -34,7 +36,7 @@ function import_bench()
|
||||
bench_txt("AB->ABBBBBBB.txt")
|
||||
#bench_txt("AB->ABBBBBBBBB.txt")
|
||||
bench_txt("ABAB->ABAB.txt")
|
||||
bench_txt("ABAB->ABC.txt")
|
||||
return bench_txt("ABAB->ABC.txt")
|
||||
end
|
||||
|
||||
import_bench()
|
||||
|
@ -12,7 +12,7 @@ function gen_plot(filepath)
|
||||
return
|
||||
end
|
||||
|
||||
g = parse_abc(filepath)
|
||||
g = parse_dag(filepath, ABCModel())
|
||||
|
||||
Random.seed!(1)
|
||||
|
||||
@ -21,7 +21,7 @@ function gen_plot(filepath)
|
||||
x = Vector{Float64}()
|
||||
y = Vector{Float64}()
|
||||
|
||||
for i = 1:30
|
||||
for i in 1:30
|
||||
print("\r", i)
|
||||
# push
|
||||
opt = get_operations(g)
|
||||
@ -38,23 +38,23 @@ function gen_plot(filepath)
|
||||
push_operation!(g, rand(collect(opt.nodeSplits)))
|
||||
println("NS")
|
||||
else
|
||||
i = i-1
|
||||
i = i - 1
|
||||
end
|
||||
|
||||
props = graph_properties(g)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.compute_effort)
|
||||
push!(y, props.computeEffort)
|
||||
end
|
||||
|
||||
println("\rDone.")
|
||||
|
||||
plot([x[1], x[2]], [y[1], y[2]], linestyle = :solid, linewidth = 1, color = :red, legend=false)
|
||||
plot([x[1], x[2]], [y[1], y[2]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
|
||||
# Create lines connecting the reference point to each data point
|
||||
for i in 3:length(x)
|
||||
plot!([x[i-1], x[i]], [y[i-1], y[i]], linestyle = :solid, linewidth = 1, color = :red)
|
||||
plot!([x[i - 1], x[i]], [y[i - 1], y[i]], linestyle = :solid, linewidth = 1, color = :red)
|
||||
end
|
||||
|
||||
gui()
|
||||
return gui()
|
||||
end
|
||||
|
||||
gen_plot("AB->ABBB.txt")
|
||||
|
@ -12,13 +12,13 @@ function gen_plot(filepath)
|
||||
return
|
||||
end
|
||||
|
||||
g = parse_abc(filepath)
|
||||
g = parse_dag(filepath, ABCModel())
|
||||
|
||||
Random.seed!(1)
|
||||
|
||||
println("Random Walking... ")
|
||||
|
||||
for i = 1:30
|
||||
for i in 1:30
|
||||
print("\r", i)
|
||||
# push
|
||||
opt = get_operations(g)
|
||||
@ -35,7 +35,7 @@ function gen_plot(filepath)
|
||||
push_operation!(g, rand(collect(opt.nodeSplits)))
|
||||
println("NS")
|
||||
else
|
||||
i = i-1
|
||||
i = i - 1
|
||||
end
|
||||
end
|
||||
|
||||
@ -44,9 +44,9 @@ function gen_plot(filepath)
|
||||
|
||||
|
||||
|
||||
props = graph_properties(g)
|
||||
props = get_properties(g)
|
||||
x0 = props.data
|
||||
y0 = props.compute_effort
|
||||
y0 = props.computeEffort
|
||||
|
||||
x = Vector{Float64}()
|
||||
y = Vector{Float64}()
|
||||
@ -55,33 +55,33 @@ function gen_plot(filepath)
|
||||
opt = get_operations(g)
|
||||
for op in opt.nodeFusions
|
||||
push_operation!(g, op)
|
||||
props = graph_properties(g)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.compute_effort)
|
||||
push!(y, props.computeEffort)
|
||||
pop_operation!(g)
|
||||
|
||||
push!(names, "NF: (" * string(props.data) * ", " * string(props.compute_effort) * ")")
|
||||
push!(names, "NF: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
|
||||
end
|
||||
for op in opt.nodeReductions
|
||||
push_operation!(g, op)
|
||||
props = graph_properties(g)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.compute_effort)
|
||||
push!(y, props.computeEffort)
|
||||
pop_operation!(g)
|
||||
|
||||
push!(names, "NR: (" * string(props.data) * ", " * string(props.compute_effort) * ")")
|
||||
push!(names, "NR: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
|
||||
end
|
||||
for op in opt.nodeSplits
|
||||
push_operation!(g, op)
|
||||
props = graph_properties(g)
|
||||
props = get_properties(g)
|
||||
push!(x, props.data)
|
||||
push!(y, props.compute_effort)
|
||||
push!(y, props.computeEffort)
|
||||
pop_operation!(g)
|
||||
|
||||
push!(names, "NS: (" * string(props.data) * ", " * string(props.compute_effort) * ")")
|
||||
push!(names, "NS: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
|
||||
end
|
||||
|
||||
plot([x0, x[1]], [y0, y[1]], linestyle = :solid, linewidth = 1, color = :red, legend=false)
|
||||
|
||||
plot([x0, x[1]], [y0, y[1]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
|
||||
# Create lines connecting the reference point to each data point
|
||||
for i in 2:length(x)
|
||||
plot!([x0, x[i]], [y0, y[i]], linestyle = :solid, linewidth = 1, color = :red)
|
||||
@ -90,7 +90,7 @@ function gen_plot(filepath)
|
||||
|
||||
print(names)
|
||||
|
||||
gui()
|
||||
return gui()
|
||||
end
|
||||
|
||||
gen_plot("AB->ABBB.txt")
|
||||
|
@ -1,11 +1,11 @@
|
||||
|
||||
function test_random_walk(g::DAG, n::Int64)
|
||||
# the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
|
||||
function random_walk!(g::DAG, n::Int64)
|
||||
# the purpose here is to do "random" operations on the graph to simulate an optimizer
|
||||
reset_graph!(g)
|
||||
|
||||
properties = graph_properties(g)
|
||||
properties = get_properties(g)
|
||||
|
||||
for i = 1:n
|
||||
for i in 1:n
|
||||
# choose push or pop
|
||||
if rand(Bool)
|
||||
# push
|
||||
@ -32,5 +32,28 @@ function test_random_walk(g::DAG, n::Int64)
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function reduce_all!(g::DAG)
|
||||
reset_graph!(g)
|
||||
end
|
||||
|
||||
opt = get_operations(g)
|
||||
while (!isempty(opt.nodeReductions))
|
||||
push_operation!(g, pop!(opt.nodeReductions))
|
||||
|
||||
if (isempty(opt.nodeReductions))
|
||||
opt = get_operations(g)
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function reduce_one!(g::DAG)
|
||||
opt = get_operations(g)
|
||||
if !isempty(opt.nodeReductions)
|
||||
push_operation!(g, pop!(opt.nodeReductions))
|
||||
end
|
||||
opt = get_operations(g)
|
||||
return nothing
|
||||
end
|
||||
|
Binary file not shown.
678
notebooks/abc_model_large.ipynb
Normal file
678
notebooks/abc_model_large.ipynb
Normal file
@ -0,0 +1,678 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"using MetagraphOptimization"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Graph:\n",
|
||||
" Nodes: Total: 438436, ComputeTaskP: 10, ComputeTaskU: 10, \n",
|
||||
" ComputeTaskV: 109600, ComputeTaskSum: 1, ComputeTaskS2: 40320, \n",
|
||||
" ComputeTaskS1: 69272, DataTask: 219223\n",
|
||||
" Edges: 628665\n",
|
||||
" Total Compute Effort: 1.903443e6\n",
|
||||
" Total Data Transfer: 1.8040896e7\n",
|
||||
" Total Compute Intensity: 0.10550712115407128\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = ABCModel()\n",
|
||||
"process_str = \"AB->ABBBBBBB\"\n",
|
||||
"process = parse_process(process_str, model)\n",
|
||||
"graph = parse_dag(\"../input/$process_str.txt\", model)\n",
|
||||
"print(graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"351.606942 seconds (1.13 G allocations: 25.949 GiB, 1.33% gc time, 0.72% compilation time)\n",
|
||||
"Graph:\n",
|
||||
" Nodes: Total: 277188, ComputeTaskP: 10, ComputeTaskU: 10, \n",
|
||||
" ComputeTaskV: 69288, ComputeTaskSum: 1, ComputeTaskS2: 40320, \n",
|
||||
" ComputeTaskS1: 28960, DataTask: 138599\n",
|
||||
" Edges: 427105\n",
|
||||
" Total Compute Effort: 1.218139e6\n",
|
||||
" Total Data Transfer: 1.2235968e7\n",
|
||||
" Total Compute Intensity: 0.0995539543745129\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"include(\"../examples/profiling_utilities.jl\")\n",
|
||||
"@time reduce_all!(graph)\n",
|
||||
"print(graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 1 NUMA nodes\n",
|
||||
"CUDA is non-functional\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Get machine and set dictionary caching strategy\n",
|
||||
"machine = get_machine_info()\n",
|
||||
"MetagraphOptimization.set_cache_strategy(machine.devices[1], MetagraphOptimization.Dictionary())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2315.896312 seconds (87.18 M allocations: 132.726 GiB, 0.11% gc time, 0.04% compilation time)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"compute__8fd7c454_6214_11ee_3616_0f2435e477fe (generic function with 1 method)"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@time compute_AB_AB7 = get_compute_function(graph, process, machine)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 1.910169 seconds (4.34 M allocations: 278.284 MiB, 6.25% gc time, 99.23% compilation time)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1000-element Vector{ABCProcessInput}:\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [8.411745173347825, 0.0, 0.0, 8.352092962924948]\n",
|
||||
" B: [8.411745173347825, 0.0, 0.0, -8.352092962924948]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-2.003428483168789, 1.2386385417950023, -0.8321671195319228, 0.8871291535745444]\n",
|
||||
" B: [-2.444326994820653, 1.1775023368116424, -0.9536682034633904, 1.6366855721594777]\n",
|
||||
" B: [-4.289211829680359, -3.7216649121036443, 1.128125248220305, 1.50793959634144]\n",
|
||||
" B: [-1.2727607454602508, 0.07512513775641204, 0.6370236198332677, -0.45659285653208986]\n",
|
||||
" B: [-1.8777156401619268, -1.042329795325101, -0.5508846238377632, -1.0657817573524957]\n",
|
||||
" B: [-1.1322368113474306, 0.0498922458527246, -0.2963537951915457, -0.4377732162313449]\n",
|
||||
" B: [-1.4340705015357569, 0.7798902829682378, 0.144450581630926, -0.6538068364381232]\n",
|
||||
" B: [-2.369739340520482, 1.4429461622447262, 0.7234742923401235, -1.4177996555214083]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [8.262146117199348, 0.0, 0.0, 8.201405883258813]\n",
|
||||
" B: [8.262146117199348, 0.0, 0.0, -8.201405883258813]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-2.022253637967156, 0.040616190652067494, 1.5789161216660899, -0.7712872241073523]\n",
|
||||
" B: [-1.085155894223277, -0.4013306445746292, 0.044561160964560184, -0.12046298778597243]\n",
|
||||
" B: [-2.3099664718736963, -0.6028883246226666, 0.7721426580907682, 1.8374619682515352]\n",
|
||||
" B: [-3.8528592267292674, -1.1057919702708323, -3.154341441424319, -1.6345881470237529]\n",
|
||||
" B: [-1.445065980497648, -0.3803292238069696, -0.9038074225417192, 0.3559459403736899]\n",
|
||||
" B: [-1.637993216461692, 0.18276067729419151, -0.6165325663294264, 1.1267244146927589]\n",
|
||||
" B: [-3.0791604558286254, 1.8666082398498536, 2.1149851082876507, -0.7237684597886623]\n",
|
||||
" B: [-1.091837350817336, 0.4003550554789843, 0.16407638128639515, -0.0700255046122441]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [9.522164300929319, 0.0, 0.0, 9.4695096480173]\n",
|
||||
" B: [9.522164300929319, 0.0, 0.0, -9.4695096480173]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-2.2614545815907876, 0.09596466269330481, -1.680314037563078, -1.1320390202111377]\n",
|
||||
" B: [-2.5164555101345942, 2.0544568173259474, 0.7608284478099104, 0.7299969816600982]\n",
|
||||
" B: [-3.527555187469315, 3.1461533872404055, -0.4998113855480195, 1.1382236350884531]\n",
|
||||
" B: [-1.5843416170605953, -0.649775322646379, 0.6368565466386346, -0.8260412390634552]\n",
|
||||
" B: [-1.0715042390215452, 0.33101538188959895, -0.19275377509309963, -0.037364868271978664]\n",
|
||||
" B: [-1.8269658913133924, -1.2104472444295427, -0.7036857693244948, 0.6143681099517287]\n",
|
||||
" B: [-1.7510547915269752, 0.35168054121444203, 0.408535633181173, -1.3325210378384098]\n",
|
||||
" B: [-4.504996783741433, -4.119048223287777, 1.270344339898973, 0.8453774386847008]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [7.225275339000687, 0.0, 0.0, 7.1557392157883655]\n",
|
||||
" B: [7.225275339000687, 0.0, 0.0, -7.1557392157883655]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.5721586195862234, -0.6346644373772993, 0.7957285133297657, -0.6600756851617959]\n",
|
||||
" B: [-1.0093393293662618, -0.11321130994303012, 0.07324286826550051, -0.024177745030521003]\n",
|
||||
" B: [-2.7355755394886443, 0.2329840388558535, -2.4939308642531, -0.4576033371958622]\n",
|
||||
" B: [-1.618399027736879, -0.47727357006920945, 1.0132042772011558, -0.6040218911217943]\n",
|
||||
" B: [-1.7201610947708947, 0.01110230391313025, 0.8839000043421623, -1.0851505486038107]\n",
|
||||
" B: [-1.792300907703241, 0.8101193095744785, -0.625916307414256, 1.0790171565463333]\n",
|
||||
" B: [-1.5563810656498285, -1.1865287585293671, 0.12019738267353275, -0.004910793671790455]\n",
|
||||
" B: [-2.4462350936994026, 1.3574724235754438, 0.2335741258552372, 1.7569228442392408]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [7.94532861335446, 0.0, 0.0, 7.882147345374172]\n",
|
||||
" B: [7.94532861335446, 0.0, 0.0, -7.882147345374172]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-2.118671714766621, -0.6322452591326608, -1.2236882164873555, -1.2615953852509143]\n",
|
||||
" B: [-2.560753710001491, -1.7412395645571277, -1.5891033163317627, 0.01717533495153369]\n",
|
||||
" B: [-1.5550581087132076, -0.639122838128628, -0.9624327134008909, 0.2888788525193626]\n",
|
||||
" B: [-2.181477133464949, 0.4918918998013713, 1.8559068969600523, -0.2692479016749415]\n",
|
||||
" B: [-1.2628370388798702, -0.4013500667990802, 0.24813196852393224, 0.6100049482124643]\n",
|
||||
" B: [-1.901139724448186, 1.3625293914322611, -0.8176066997802711, 0.2989401174693193]\n",
|
||||
" B: [-2.2302691928842697, -0.1867565668705846, 1.9609184768063308, 0.3066290670808993]\n",
|
||||
" B: [-2.0804506035503256, 1.7462930042544484, 0.5278736037099664, 0.009214966692276028]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [5.597768901835826, 0.0, 0.0, 5.507723366179557]\n",
|
||||
" B: [5.597768901835826, 0.0, 0.0, -5.507723366179557]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.0009073340208385, 0.03522831505376105, -0.010844681575969111, -0.021374049609080487]\n",
|
||||
" B: [-1.3943823799403026, -0.886019044587247, 0.21582726795187737, -0.3356948979730148]\n",
|
||||
" B: [-1.0593061926863385, 0.3261714964515558, -0.10930051701751846, -0.06160488410736567]\n",
|
||||
" B: [-1.0190344437384602, 0.02512063114228613, 0.04379726771854621, -0.18942531709556668]\n",
|
||||
" B: [-1.0919277601624486, -0.39612686480944176, 0.07078221355247243, -0.17429750036714983]\n",
|
||||
" B: [-1.8292258091360047, 1.1565638126055895, 0.329244535677723, 0.9486966026643375]\n",
|
||||
" B: [-1.7379569022732355, 0.6562121276078657, 0.7749535141539342, -0.9946491284065995]\n",
|
||||
" B: [-2.0627969817140217, -0.9171504734643696, -1.3144596004610647, 0.8283491748944392]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [6.860362769879496, 0.0, 0.0, 6.787089017712134]\n",
|
||||
" B: [6.860362769879496, 0.0, 0.0, -6.787089017712134]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-2.1483538194490985, 1.8204047500578164, 0.1342978924269131, -0.532461036694855]\n",
|
||||
" B: [-1.2136825716769264, 0.12932805245115084, -0.43609629710270903, -0.5158678699965871]\n",
|
||||
" B: [-3.3642987422516573, -1.7653207470663739, 0.533955101409256, 2.630026736893018]\n",
|
||||
" B: [-1.053677321951765, 0.11000921943972916, 0.04739423847128557, -0.30965732123337875]\n",
|
||||
" B: [-1.2932387925896982, -0.6843810329952256, 0.045636429012288295, -0.4494513240410521]\n",
|
||||
" B: [-1.1237194151971648, -0.45140047643622017, 0.19994785657222267, -0.13785422959193222]\n",
|
||||
" B: [-1.7619597212239484, 1.3299261857304887, 0.561749934748497, 0.1422512233127988]\n",
|
||||
" B: [-1.7617951554187332, -0.488565951181366, -1.0868851555377534, -0.8269861786480115]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [9.57507915889135, 0.0, 0.0, 9.522717096450755]\n",
|
||||
" B: [9.57507915889135, 0.0, 0.0, -9.522717096450755]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-3.4305207411483516, 2.6682294806816835, -1.883054168339437, -0.3211401453721668]\n",
|
||||
" B: [-2.185574270107571, 1.4558232366821502, 1.2235951792097912, 0.40016050668089054]\n",
|
||||
" B: [-3.0259648593433583, -0.9184166853584697, -0.10930222461665634, -2.7020412923806107]\n",
|
||||
" B: [-3.246659025038245, -2.493839704051011, -1.0189869044243565, 1.5110340975546257]\n",
|
||||
" B: [-1.4247322676315595, 0.05954103854817788, 0.9940897925990366, -0.19519831815252583]\n",
|
||||
" B: [-1.4889906300188005, 0.5912092032645169, -0.19371449043911573, -0.9110650198822441]\n",
|
||||
" B: [-1.1268952499657272, 0.36236812621338876, -0.3636229828302436, 0.07975319340034331]\n",
|
||||
" B: [-3.220821274529085, -1.7249146959804351, 1.350995798840981, 2.1384969781516885]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [8.472852690841874, 0.0, 0.0, 8.413633740584764]\n",
|
||||
" B: [8.472852690841874, 0.0, 0.0, -8.413633740584764]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.1530011327357317, 0.34211475449117323, -0.45923141786607913, -0.03841369149190832]\n",
|
||||
" B: [-2.62915067223017, 1.042431210232047, 0.6288618003426715, -2.1048285595963105]\n",
|
||||
" B: [-1.1265473249385953, -0.4344882737979479, -0.1553035746380426, 0.2370856700921221]\n",
|
||||
" B: [-1.4826889242092416, -0.5889894099544346, -0.45026884678673923, -0.8054290077639529]\n",
|
||||
" B: [-4.118520088756618, -2.101194203160593, -3.0008966741533745, 1.5943054265577095]\n",
|
||||
" B: [-3.9992129109551517, 1.0607252636964415, 3.6847882851419875, 0.539352496783755]\n",
|
||||
" B: [-1.3172538577755006, 0.4084669000294691, -0.6351790575407871, 0.4060296568803221]\n",
|
||||
" B: [-1.1193304700827373, 0.2709337584638445, 0.3872294855003629, 0.17189800853826395]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [5.913538688235051, 0.0, 0.0, 5.828373685450576]\n",
|
||||
" B: [5.913538688235051, 0.0, 0.0, -5.828373685450576]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.6813734506828508, -1.1942921586618185, -0.384476919421686, 0.5028522833318558]\n",
|
||||
" B: [-1.412586238014363, 0.010275442474480664, 0.8780055986304257, -0.4737092609218783]\n",
|
||||
" B: [-1.5338446207986793, 1.1234162145644635, 0.1670274754582306, -0.25043392751132176]\n",
|
||||
" B: [-1.4260274101869397, 0.9023875675844153, -0.4646063309051003, -0.058239245843783906]\n",
|
||||
" B: [-1.1055189977833793, -0.3699146930280028, 0.2809292901965394, -0.08008812803177658]\n",
|
||||
" B: [-1.1926016738662872, 0.4242726765633766, 0.34415633034138016, -0.3519202590308968]\n",
|
||||
" B: [-1.4188061371181722, 0.47356120240959365, 0.33662773751584696, 0.8218469496393668]\n",
|
||||
" B: [-2.0563188480194308, -1.3697062519065082, -1.1576631818156364, -0.1103084116315648]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [6.062750568659298, 0.0, 0.0, 5.979711068085032]\n",
|
||||
" B: [6.062750568659298, 0.0, 0.0, -5.979711068085032]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.1157392140073992, -0.0424317149721654, 0.4662958482482185, -0.16013033799016252]\n",
|
||||
" B: [-2.395340693850968, -1.171776361305547, -1.746409249879336, 0.5609384374776449]\n",
|
||||
" B: [-1.0289722654275464, 0.23139962589771268, 0.07055331234631396, 0.01613586906426155]\n",
|
||||
" B: [-1.212565238145815, -0.6377842504248107, 0.04163119753237706, 0.24862129848767983]\n",
|
||||
" B: [-1.8156755638105053, -0.3987185167288875, 1.2510245302740972, 0.7567290942527487]\n",
|
||||
" B: [-2.003891077687212, 1.2159250459117166, 0.38048599808923245, -1.1799729400359336]\n",
|
||||
" B: [-1.4663599649673638, 0.593985649692284, -0.7733488095969958, -0.44645740391848543]\n",
|
||||
" B: [-1.086957119421786, 0.20940052192969777, 0.3097671729860923, 0.20413598266224653]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [7.088363151833832, 0.0, 0.0, 7.017470496715726]\n",
|
||||
" B: [7.088363151833832, 0.0, 0.0, -7.017470496715726]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-3.1474601133746627, 0.14412280671945385, 2.7364508363525357, 1.1821889028802701]\n",
|
||||
" B: [-1.256451004773104, 0.1153142495225348, -0.7455659837621855, -0.09748392231091944]\n",
|
||||
" B: [-1.4964417911663928, -0.0996845872039782, -0.8492275192498467, 0.7128910421459969]\n",
|
||||
" B: [-3.2499484244824526, -0.8927423628721523, -1.0242747556675866, -2.777775559729678]\n",
|
||||
" B: [-1.0489067674373789, -0.31603136975662793, 0.016268502528308637, -0.008057042333727152]\n",
|
||||
" B: [-1.6957667777105587, 1.0857339287179024, 0.6252297389508089, 0.5530773670555896]\n",
|
||||
" B: [-1.243679438145053, 0.06348629097723194, -0.7145975145476898, 0.17904867473682565]\n",
|
||||
" B: [-1.0380719865780628, -0.10019895610436466, -0.044283304604344965, 0.2561105375556422]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [9.842517855137334, 0.0, 0.0, 9.791586068084028]\n",
|
||||
" B: [9.842517855137334, 0.0, 0.0, -9.791586068084028]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.0081083393933719, 0.09315850477843095, -0.05390640772287413, 0.06854207575149836]\n",
|
||||
" B: [-1.2533776879399583, -0.09567218890986252, -0.022562148977002077, -0.749195175056841]\n",
|
||||
" B: [-4.199102452438099, 3.1204551726062775, 2.23725963921713, 1.3747327844190023]\n",
|
||||
" B: [-5.1018332572388285, -4.999892707918183, 0.09407944148737099, -0.14465321518774693]\n",
|
||||
" B: [-3.7582268429742243, 2.1814891293707577, -1.5410280493623207, -2.4475715991095703]\n",
|
||||
" B: [-1.1792132348986593, 0.6125282131702711, -0.12369433042852651, -0.007263198361168502]\n",
|
||||
" B: [-1.3600169327450258, -0.07835376476887727, -0.6694537001487819, 0.6287594836317273]\n",
|
||||
" B: [-1.8251569626465018, -0.8337123583288142, 0.07930555593500455, 1.2766488439130985]\n",
|
||||
"\n",
|
||||
" ⋮\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [9.861596443743153, 0.0, 0.0, 9.810763702141012]\n",
|
||||
" B: [9.861596443743153, 0.0, 0.0, -9.810763702141012]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.8179384769334697, 0.9572508915748105, -0.9794338269553214, 0.6551949443563104]\n",
|
||||
" B: [-2.1028582035167607, -0.7676665378472812, 0.6218562087985972, -1.5639678917247444]\n",
|
||||
" B: [-3.1263866679666865, 2.3808322573838474, -1.6099851834448586, 0.7168535896041835]\n",
|
||||
" B: [-5.177179415841987, -1.3605325795287053, 4.805481256903438, -0.9270855911989424]\n",
|
||||
" B: [-1.2605754590213083, -0.023284320526100116, -0.14250915308265208, 0.7537900699744495]\n",
|
||||
" B: [-2.712925004518324, -1.4343063146086636, -1.452340398698398, 1.4810249296764189]\n",
|
||||
" B: [-2.3798188172675734, 0.6412170781802653, -1.487389994435021, -1.4283029321979925]\n",
|
||||
" B: [-1.1455108424201939, -0.39351047462817185, 0.24432109091421514, 0.3124928815103169]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [5.611571819338176, 0.0, 0.0, 5.521751378284825]\n",
|
||||
" B: [5.611571819338176, 0.0, 0.0, -5.521751378284825]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.0759150984150232, -0.3903007964405737, 0.045679777762273936, -0.05632002484775736]\n",
|
||||
" B: [-1.021003529021616, -0.07269336486556076, 0.11388411952175649, 0.15554513267817288]\n",
|
||||
" B: [-1.6939705353811365, -0.1440535362616654, -0.25084793375093056, -1.3363607550219565]\n",
|
||||
" B: [-1.185801144621379, -0.31618880274591826, 0.5459120200606805, -0.09016131075324207]\n",
|
||||
" B: [-1.197431131926246, 0.16472462054297168, -0.17198607315407527, -0.6141074056988615]\n",
|
||||
" B: [-1.0089442324730478, -0.12314856400749492, -0.027052115631495212, -0.04550910308256443]\n",
|
||||
" B: [-2.703474424566498, 0.16902217864171518, -0.14049660772763695, 2.502092358533033]\n",
|
||||
" B: [-1.3366035422714058, 0.7126382651365266, -0.11509318708057305, -0.5151788918068239]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [8.775111706253933, 0.0, 0.0, 8.717946171962454]\n",
|
||||
" B: [8.775111706253933, 0.0, 0.0, -8.717946171962454]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-2.2750151423103953, 1.8467170131598, 0.8729070809034145, 0.05799482008261441]\n",
|
||||
" B: [-1.5756212156561644, 1.0377655822554295, 0.3001332912880399, 0.5617337616455574]\n",
|
||||
" B: [-1.6945981163898138, -0.5153714693329569, 0.050834292767083435, 1.2662823142365867]\n",
|
||||
" B: [-2.630307241578496, -0.5126707368632603, 1.3344949978186418, -1.9684532002212756]\n",
|
||||
" B: [-3.0848917600353407, -2.827901193400985, -0.46541663267058264, -0.5503811129833626]\n",
|
||||
" B: [-2.812675339815945, 2.346626876124383, -1.1757879806725677, 0.14834923648401968]\n",
|
||||
" B: [-1.695817659938434, -0.3817827622891304, -0.19598317768122073, 1.3006267920675472]\n",
|
||||
" B: [-1.7812969367832734, -0.9933833096532803, -0.7211818717528079, -0.8161526113116866]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [6.832501783927461, 0.0, 0.0, 6.758925996589395]\n",
|
||||
" B: [6.832501783927461, 0.0, 0.0, -6.758925996589395]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.0114752465345387, -0.11558780230223581, -0.03776248532804595, -0.09108034372406744]\n",
|
||||
" B: [-1.031154612454516, -0.04425244057817861, -0.0789748074180023, -0.23470095032271823]\n",
|
||||
" B: [-2.2555952063288855, 1.7491237654517413, -0.4233804231771479, -0.9214254203222908]\n",
|
||||
" B: [-2.089561973736715, 0.9235335217807571, 1.3477207222453012, -0.8348676128969853]\n",
|
||||
" B: [-1.3199981586264844, -0.6902187266500668, -0.06216816149242132, -0.5119847340063199]\n",
|
||||
" B: [-1.0105028642371863, -0.09317036739551621, -0.041275823376393385, -0.1035935696630954]\n",
|
||||
" B: [-1.2426376312622325, -0.48126859609618416, 0.05225488689293943, -0.5565952280036419]\n",
|
||||
" B: [-3.704077874674367, -1.2481593542103167, -0.7564139083462295, 3.254247858939119]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [8.775903429741401, 0.0, 0.0, 8.718743086485969]\n",
|
||||
" B: [8.775903429741401, 0.0, 0.0, -8.718743086485969]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.7137666526922533, 1.1358800766324049, 0.08268488211087159, 0.7999598750311686]\n",
|
||||
" B: [-1.1669696745288112, -0.04351472671445914, 0.5992401461010018, 0.028912577361687116]\n",
|
||||
" B: [-3.5481649603318184, 0.4490928742123019, 1.0371640968528058, -3.21124287656006]\n",
|
||||
" B: [-1.276578701414564, -0.08287623449031867, -0.6317118623642547, -0.47299559576203803]\n",
|
||||
" B: [-4.955351547203613, -2.6459981607514886, 0.5026315754882429, 4.037519558961317]\n",
|
||||
" B: [-2.3130557250521284, 1.4242375193555785, -1.5228161303749386, 0.05296516521446809]\n",
|
||||
" B: [-1.4353464814836179, 0.25997106791735547, -0.029309860840599063, -0.9958792586507745]\n",
|
||||
" B: [-1.1425731167759967, -0.4967924161613736, -0.03788284697312998, -0.23923944559576807]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [8.907102929629284, 0.0, 0.0, 8.850789942090511]\n",
|
||||
" B: [8.907102929629284, 0.0, 0.0, -8.850789942090511]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-2.946046511363992, -0.9439001466724447, 2.1873638734369836, 1.4155146927582347]\n",
|
||||
" B: [-3.7848309582649415, -2.22832689875391, -0.18756115269295068, -2.885190709282662]\n",
|
||||
" B: [-1.0159875652570234, 0.04172671107403079, -0.15271016054388648, 0.08467125371989566]\n",
|
||||
" B: [-2.0867601165869685, -1.8155383548303043, -0.021995043965926685, -0.24063350631004576]\n",
|
||||
" B: [-4.34790862339958, 3.6266859724946396, -1.8990793068549607, 1.0700261868843775]\n",
|
||||
" B: [-1.1578951917200673, 0.35622580432348594, 0.23734793715600985, 0.3968506117802061]\n",
|
||||
" B: [-1.4421363377447174, 1.0156020669389267, -0.20020339434090184, -0.0907097523285523]\n",
|
||||
" B: [-1.0326405549212787, -0.052475154574424254, 0.03683724780563263, 0.24947122277854633]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [6.294285658794556, 0.0, 0.0, 6.214340830249562]\n",
|
||||
" B: [6.294285658794556, 0.0, 0.0, -6.214340830249562]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.06844272609547, -0.2848922847204133, 0.15179083391454987, -0.19330232226393051]\n",
|
||||
" B: [-2.114647837734541, -1.6956804594706658, -0.38950327120442063, 0.6668511518515798]\n",
|
||||
" B: [-1.494217345848325, 0.7529614584695401, -0.5432224448027106, -0.6088053006963738]\n",
|
||||
" B: [-1.3783311635115514, 0.9215501628423943, 0.0395584401371469, -0.2213079833313275]\n",
|
||||
" B: [-1.7816982863175768, 0.5393674002906785, 0.38766524831377364, 1.316528482874748]\n",
|
||||
" B: [-1.659172767477475, 0.17135237894801714, -1.2297516401309854, -0.45956886117628726]\n",
|
||||
" B: [-1.55277617510909, -0.23319042207457166, 1.041131562383322, 0.522284545863997]\n",
|
||||
" B: [-1.5392850154950812, -0.17146823428497893, 0.5423312713893238, -1.022679713122405]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [6.965556009635571, 0.0, 0.0, 6.8934005050751415]\n",
|
||||
" B: [6.965556009635571, 0.0, 0.0, -6.8934005050751415]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.0775179795487104, -0.05690318568456522, -0.2919638065794134, 0.269377354945329]\n",
|
||||
" B: [-3.216279237662679, -2.600571207682032, 0.23217633942174215, 1.5898351096286563]\n",
|
||||
" B: [-1.9852997763312183, 1.2696870590322706, -0.6412445999499571, -0.9581833525279955]\n",
|
||||
" B: [-1.9885313318262752, 0.8019078287339996, 1.2060162608136897, 0.9255946577864792]\n",
|
||||
" B: [-1.4288503016026572, 0.2805632486843285, 0.07929023042776773, -0.9780646743628009]\n",
|
||||
" B: [-1.3652585458391595, -0.12810083240879516, 0.7809145290728301, -0.4875382774777694]\n",
|
||||
" B: [-1.8158888731893035, 0.7439741257624499, -1.2924797037897653, -0.2710186621991885]\n",
|
||||
" B: [-1.0534859732711408, -0.3105570364376559, -0.07270924941689365, -0.0900021557927108]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [6.43062328219917, 0.0, 0.0, 6.352394493225528]\n",
|
||||
" B: [6.43062328219917, 0.0, 0.0, -6.352394493225528]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-2.125364788369443, -1.214725294501684, 0.4075454777366224, 1.369497946736289]\n",
|
||||
" B: [-1.1032249572940587, -0.2977536437640783, 0.35819035202044425, 0.012155070594697458]\n",
|
||||
" B: [-2.225917349319406, 1.3039585629995813, -0.8668848261688078, 1.2259326287114942]\n",
|
||||
" B: [-2.717025897056506, -0.9721840017189309, 0.6274004665152297, -2.2457641565164295]\n",
|
||||
" B: [-1.000557419196324, 0.013685057618434337, 0.015873673340379625, 0.025997976872664537]\n",
|
||||
" B: [-1.1652637249339481, 0.20750251779397902, -0.05219673300317853, -0.5586212982154317]\n",
|
||||
" B: [-1.4667402310584912, 0.9160649085291783, -0.533306342231441, -0.16654228923208916]\n",
|
||||
" B: [-1.057152197170161, 0.043451893043520345, 0.0433779317907512, 0.3373441210488047]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [8.156196486154876, 0.0, 0.0, 8.094661272762755]\n",
|
||||
" B: [8.156196486154876, 0.0, 0.0, -8.094661272762755]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.4617318374080812, 0.10404421660552193, -0.19476289320497314, -1.0430254938944576]\n",
|
||||
" B: [-2.745518719911882, 2.0283487429720055, -0.01415841484271091, -1.556751090431481]\n",
|
||||
" B: [-1.193795120882441, -0.223211890483827, 0.20666745479885903, 0.5767250694363129]\n",
|
||||
" B: [-1.0771186742980503, 0.29121400254582763, -0.18584437613704033, 0.20209134345899718]\n",
|
||||
" B: [-2.9756813564276348, 0.7747616688600099, 0.31071107817153876, 2.6754219325851647]\n",
|
||||
" B: [-1.8605025819101852, -0.3441559100391822, 0.5570133470539003, 1.4257498722017754]\n",
|
||||
" B: [-3.3546424693401353, -1.4228183303706836, -0.7768040014609222, -2.7614832317390525]\n",
|
||||
" B: [-1.6434022121313414, -1.2081825000896715, 0.0971778056213486, 0.48127159838273986]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [9.631814348202784, 0.0, 0.0, 9.579762399884718]\n",
|
||||
" B: [9.631814348202784, 0.0, 0.0, -9.579762399884718]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-2.4271747113709625, -0.9216752449526319, 0.35006248470601437, 1.9796838331313595]\n",
|
||||
" B: [-1.926574191117535, -0.6155920425308834, -0.36855158619622796, 1.4821957628346814]\n",
|
||||
" B: [-2.809711053334662, 0.053095841327541846, -2.415282611989454, -1.0286238083410733]\n",
|
||||
" B: [-2.069340346061984, 0.0706218659128716, 1.6880494984307581, 0.6539655271821153]\n",
|
||||
" B: [-1.600891859223819, 0.522182956459051, 1.0136801062226364, -0.5124766796267364]\n",
|
||||
" B: [-2.3653602811566903, 0.7359929823506941, 2.003935313635875, 0.19361520696286152]\n",
|
||||
" B: [-4.134587420071929, 0.11270979705086029, -1.0448676862999513, -3.871738776569513]\n",
|
||||
" B: [-1.9299888340679847, 0.04266384438249662, -1.2270255185096508, 1.1033789344263045]\n",
|
||||
"\n",
|
||||
" Input for ABC Process: 'AB->ABBBBBBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [7.383091586636561, 0.0, 0.0, 7.31505580133628]\n",
|
||||
" B: [7.383091586636561, 0.0, 0.0, -7.31505580133628]\n",
|
||||
" 8 Outgoing Particles:\n",
|
||||
" A: [-1.0026822379766207, 0.02425303574920085, -0.0683120173174935, 0.010813366763733786]\n",
|
||||
" B: [-3.2851307251831745, -2.830568076855887, -0.9156122597784988, 0.9703723169846757]\n",
|
||||
" B: [-2.028220232462834, 1.6810294384373135, 0.4923274291375999, -0.21314558638988076]\n",
|
||||
" B: [-1.5191535227395792, -0.17123543395193966, -1.1293131485074372, -0.05619309939470401]\n",
|
||||
" B: [-1.1059696544762567, 0.2375361941082015, -0.40208228112542477, -0.07124094550113935]\n",
|
||||
" B: [-1.371740281577803, -0.2278482692103191, -0.6986437390927988, -0.5845113276468179]\n",
|
||||
" B: [-1.2867512190171768, 0.6015837296464805, -0.16735271525316733, -0.5155761675681034]\n",
|
||||
" B: [-3.166535299839676, 0.6852493820769491, 2.888988731937221, 0.4594814427522358]\n"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@time inputs = [gen_process_input(process) for _ in 1:1000]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Internal error: stack overflow in type inference of materialize(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(MetagraphOptimization.compute__8fd7c454_6214_11ee_3616_0f2435e477fe), Tuple{Array{MetagraphOptimization.ABCProcessInput, 1}}}).\n",
|
||||
"This might be caused by recursion over very long tuples or argument lists.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "LoadError",
|
||||
"evalue": "StackOverflowError:",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"StackOverflowError:",
|
||||
"",
|
||||
"Stacktrace:",
|
||||
" [1] get",
|
||||
" @ ./iddict.jl:102 [inlined]",
|
||||
" [2] in",
|
||||
" @ ./iddict.jl:189 [inlined]",
|
||||
" [3] haskey",
|
||||
" @ ./abstractdict.jl:17 [inlined]",
|
||||
" [4] findall(sig::Type, table::Core.Compiler.CachedMethodTable{Core.Compiler.InternalMethodTable}; limit::Int64)",
|
||||
" @ Core.Compiler ./compiler/methodtable.jl:120",
|
||||
" [5] findall",
|
||||
" @ ./compiler/methodtable.jl:114 [inlined]",
|
||||
" [6] find_matching_methods(argtypes::Vector{Any}, atype::Any, method_table::Core.Compiler.CachedMethodTable{Core.Compiler.InternalMethodTable}, union_split::Int64, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:336",
|
||||
" [7] abstract_call_gf_by_type(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:80",
|
||||
" [8] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1949",
|
||||
" [9] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
|
||||
" [10] abstract_apply(interp::Core.Compiler.NativeInterpreter, argtypes::Vector{Any}, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1566",
|
||||
" [11] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1855",
|
||||
" [12] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Nothing)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
|
||||
" [13] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1999",
|
||||
" [14] abstract_eval_statement_expr(interp::Core.Compiler.NativeInterpreter, e::Expr, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState, mi::Nothing)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2183",
|
||||
" [15] abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2396",
|
||||
" [16] abstract_eval_basic_statement(interp::Core.Compiler.NativeInterpreter, stmt::Any, pc_vartable::Vector{Core.Compiler.VarState}, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2682",
|
||||
" [17] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2867",
|
||||
" [18] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2955",
|
||||
" [19] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:246",
|
||||
" [20] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:216",
|
||||
" [21] typeinf_edge(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector, caller::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:932",
|
||||
" [22] abstract_call_method(interp::Core.Compiler.NativeInterpreter, method::Method, sig::Any, sparams::Core.SimpleVector, hardlimit::Bool, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:611",
|
||||
" [23] abstract_call_gf_by_type(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:152",
|
||||
"--- the last 16 lines are repeated 413 more times ---",
|
||||
" [6632] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1949",
|
||||
" [6633] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
|
||||
" [6634] abstract_apply(interp::Core.Compiler.NativeInterpreter, argtypes::Vector{Any}, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1566",
|
||||
" [6635] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1855",
|
||||
" [6636] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Nothing)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
|
||||
" [6637] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1999",
|
||||
" [6638] abstract_eval_statement_expr(interp::Core.Compiler.NativeInterpreter, e::Expr, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState, mi::Nothing)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2183",
|
||||
" [6639] abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2396",
|
||||
" [6640] abstract_eval_basic_statement(interp::Core.Compiler.NativeInterpreter, stmt::Any, pc_vartable::Vector{Core.Compiler.VarState}, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2658",
|
||||
" [6641] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2867",
|
||||
" [6642] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2955",
|
||||
" [6643] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:246",
|
||||
" [6644] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:216",
|
||||
" [6645] typeinf_edge(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector, caller::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:932",
|
||||
" [6646] abstract_call_method(interp::Core.Compiler.NativeInterpreter, method::Method, sig::Any, sparams::Core.SimpleVector, hardlimit::Bool, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:611",
|
||||
" [6647] abstract_call_gf_by_type(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:152",
|
||||
" [6648] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1949",
|
||||
" [6649] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Nothing)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2020",
|
||||
" [6650] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:1999",
|
||||
" [6651] abstract_eval_statement_expr(interp::Core.Compiler.NativeInterpreter, e::Expr, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState, mi::Nothing)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2183",
|
||||
" [6652] abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2396",
|
||||
" [6653] abstract_eval_basic_statement(interp::Core.Compiler.NativeInterpreter, stmt::Any, pc_vartable::Vector{Core.Compiler.VarState}, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2682",
|
||||
" [6654] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2867",
|
||||
" [6655] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/abstractinterpretation.jl:2955",
|
||||
" [6656] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:246",
|
||||
" [6657] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:216",
|
||||
" [6658] typeinf",
|
||||
" @ ./compiler/typeinfer.jl:12 [inlined]",
|
||||
" [6659] typeinf_type(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:1079",
|
||||
" [6660] return_type(interp::Core.Compiler.NativeInterpreter, t::DataType)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:1140",
|
||||
" [6661] return_type(f::Any, t::DataType)",
|
||||
" @ Core.Compiler ./compiler/typeinfer.jl:1112",
|
||||
" [6662] combine_eltypes(f::Function, args::Tuple{Vector{ABCProcessInput}})",
|
||||
" @ Base.Broadcast ./broadcast.jl:730",
|
||||
" [6663] copy(bc::Base.Broadcast.Broadcasted{Style}) where Style",
|
||||
" @ Base.Broadcast ./broadcast.jl:895",
|
||||
" [6664] materialize(bc::Base.Broadcast.Broadcasted)",
|
||||
" @ Base.Broadcast ./broadcast.jl:873",
|
||||
" [6665] var\"##core#293\"()",
|
||||
" @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:489",
|
||||
" [6666] var\"##sample#294\"(::Tuple{}, __params::BenchmarkTools.Parameters)",
|
||||
" @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:495",
|
||||
" [6667] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})",
|
||||
" @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:99",
|
||||
" [6668] #invokelatest#2",
|
||||
" @ ./essentials.jl:821 [inlined]",
|
||||
" [6669] invokelatest",
|
||||
" @ ./essentials.jl:816 [inlined]",
|
||||
" [6670] #run_result#45",
|
||||
" @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]",
|
||||
" [6671] run_result",
|
||||
" @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]",
|
||||
" [6672] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})",
|
||||
" @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117",
|
||||
" [6673] run (repeats 2 times)",
|
||||
" @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117 [inlined]",
|
||||
" [6674] #warmup#54",
|
||||
" @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:169 [inlined]",
|
||||
" [6675] warmup(item::BenchmarkTools.Benchmark)",
|
||||
" @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:168"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"using BenchmarkTools\n",
|
||||
"@benchmark compute_AB_AB7.(inputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Julia 1.9.3",
|
||||
"language": "julia",
|
||||
"name": "julia-1.9"
|
||||
},
|
||||
"language_info": {
|
||||
"file_extension": ".jl",
|
||||
"mimetype": "application/julia",
|
||||
"name": "julia",
|
||||
"version": "1.9.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
409
notebooks/abc_model_showcase.ipynb
Normal file
409
notebooks/abc_model_showcase.ipynb
Normal file
@ -0,0 +1,409 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "20768e45-df62-4638-ba33-b0ccf239f1aa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"using Revise\n",
|
||||
"using MetagraphOptimization\n",
|
||||
"using BenchmarkTools"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "ff5f4a49",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 1 NUMA nodes\n",
|
||||
"CUDA is non-functional\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Machine(MetagraphOptimization.AbstractDevice[MetagraphOptimization.NumaNode(0x0000, 0x0001, MetagraphOptimization.LocalVariables(), -1.0, UUID(\"a89974f6-6212-11ee-0866-0f591a3b69ea\"))], [-1.0;;])"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Get our machine's info\n",
|
||||
"machine = get_machine_info()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "9df482a4-ca44-44c5-9ea7-7a2977d529be",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ABCModel()"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create a model identifier\n",
|
||||
"model = ABCModel()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "30b16872-07f7-4d47-8ff8-8c3a849c9d4e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ABC Process: 'AB->ABBB'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create a process in our model\n",
|
||||
"process_str = \"AB->ABBB\"\n",
|
||||
"process = parse_process(process_str, model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "043bd9e2-f89a-4362-885a-8c89d4cdd76f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Total: 280, ComputeTaskP"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Graph:\n",
|
||||
" Nodes: \n",
|
||||
" Edges: 385\n",
|
||||
" Total Compute Effort: 1075.0\n",
|
||||
" Total Data Transfer: 10944.0\n",
|
||||
" Total Compute Intensity: 0.09822733918128655\n"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
": 6, ComputeTaskU: 6, \n",
|
||||
" ComputeTaskV: 64, ComputeTaskSum: 1, ComputeTaskS2: 24, \n",
|
||||
" ComputeTaskS1: 36, DataTask: 143"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Read the graph (of the same process) from a file\n",
|
||||
"graph = parse_dag(\"../input/$process_str.txt\", model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "02f01ad3-fd10-48d5-a0e0-c03dc83c80a4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Input for ABC Process: 'AB->ABBB':\n",
|
||||
" 2 Incoming particles:\n",
|
||||
" A: [5.77986599979293, 0.0, 0.0, 5.692701553354288]\n",
|
||||
" B: [5.77986599979293, 0.0, 0.0, -5.692701553354288]\n",
|
||||
" 4 Outgoing Particles:\n",
|
||||
" A: [-3.8835293143673746, -1.4292027910861678, 2.8576090179942106, 1.968057422378813]\n",
|
||||
" B: [-1.1554024905063585, -0.1464656500147254, -0.2082400426692148, 0.5197487980391896]\n",
|
||||
" B: [-2.849749730594798, -1.0177034035100576, -2.464951858896686, -0.09677625137882176]\n",
|
||||
" B: [-3.6710504641173287, 2.5933718446109513, -0.1844171164283155, -2.391029969039186]\n"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Generate some random input data for our process\n",
|
||||
"input_data = gen_process_input(process)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "083fb1be-ce2a-47f9-afb9-60a6fdfaed0b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"compute__af4450a2_6212_11ee_2601_cde7cf2aedc1 (generic function with 1 method)"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Get the function computing the result of the process from a ProcessInput\n",
|
||||
"AB_AB3_compute = get_compute_function(graph, process, machine)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "a40c9500-8f79-4f04-b3c5-59b72a6b7ba9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"-1.8924431710735022e-13"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Actually compute a result using the generated function and the input data\n",
|
||||
"result = AB_AB3_compute(input_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "80c70010",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"include(\"../examples/profiling_utilities.jl\")\n",
|
||||
"\n",
|
||||
"# We can also mutate the graph by applying some operations to it\n",
|
||||
"reduce_all!(graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "5b192b44",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# The result should be the same as before (we can use execute to save having to generate the function ourselves)\n",
|
||||
"@assert result ≈ execute(graph, process, machine, input_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "9b2f4a3f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1000-element Vector{Float64}:\n",
|
||||
" -2.1491995259940396e-11\n",
|
||||
" -1.04995646459455e-11\n",
|
||||
" 5.821760691187782e-15\n",
|
||||
" -6.556969485683705e-14\n",
|
||||
" -1.3588086164373753e-14\n",
|
||||
" -1.8789662441593694e-13\n",
|
||||
" -2.131973301835892e-13\n",
|
||||
" -5.3359759072004825e-12\n",
|
||||
" -9.053914191490223e-13\n",
|
||||
" -5.61107901706923e-13\n",
|
||||
" -5.063492275603428e-11\n",
|
||||
" 2.9168508985811397e-15\n",
|
||||
" -1.6420151378194157e-13\n",
|
||||
" ⋮\n",
|
||||
" 1.0931677247833436e-13\n",
|
||||
" -7.704755306462797e-16\n",
|
||||
" -1.8385907037491397e-12\n",
|
||||
" -6.036215596560059e-14\n",
|
||||
" -9.98872401400362e-12\n",
|
||||
" 3.4861755637292935e-13\n",
|
||||
" -1.1051119822969222e-10\n",
|
||||
" -2.496572513216201e-12\n",
|
||||
" -3.8682427847201926e-11\n",
|
||||
" 7.904149696653438e-15\n",
|
||||
" -7.606811743178716e-11\n",
|
||||
" -5.100594937480292e-13"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Now we can generate a function and use it on lots of inputs\n",
|
||||
"inputs = [gen_process_input(process) for _ in 1:1000]\n",
|
||||
"AB_AB3_reduced_compute = get_compute_function(graph, process, machine)\n",
|
||||
"\n",
|
||||
"results = AB_AB3_reduced_compute.(inputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "d43e4ff0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"BenchmarkTools.Trial: 879 samples with 1 evaluation.\n",
|
||||
" Range \u001b[90m(\u001b[39m\u001b[36m\u001b[1mmin\u001b[22m\u001b[39m … \u001b[35mmax\u001b[39m\u001b[90m): \u001b[39m\u001b[36m\u001b[1m4.567 ms\u001b[22m\u001b[39m … \u001b[35m14.334 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmin … max\u001b[90m): \u001b[39m0.00% … 54.51%\n",
|
||||
" Time \u001b[90m(\u001b[39m\u001b[34m\u001b[1mmedian\u001b[22m\u001b[39m\u001b[90m): \u001b[39m\u001b[34m\u001b[1m4.998 ms \u001b[22m\u001b[39m\u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmedian\u001b[90m): \u001b[39m0.00%\n",
|
||||
" Time \u001b[90m(\u001b[39m\u001b[32m\u001b[1mmean\u001b[22m\u001b[39m ± \u001b[32mσ\u001b[39m\u001b[90m): \u001b[39m\u001b[32m\u001b[1m5.686 ms\u001b[22m\u001b[39m ± \u001b[32m 1.414 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmean ± σ\u001b[90m): \u001b[39m9.09% ± 14.49%\n",
|
||||
"\n",
|
||||
" \u001b[39m \u001b[39m \u001b[39m▃\u001b[39m▇\u001b[39m█\u001b[34m▅\u001b[39m\u001b[39m▄\u001b[39m▁\u001b[39m \u001b[39m▁\u001b[39m \u001b[39m \u001b[32m \u001b[39m\u001b[39m \u001b[39m▁\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m▁\u001b[39m▁\u001b[39m \u001b[39m▁\u001b[39m▁\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \n",
|
||||
" \u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[34m█\u001b[39m\u001b[39m█\u001b[39m█\u001b[39m▇\u001b[39m█\u001b[39m▇\u001b[39m▇\u001b[32m█\u001b[39m\u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m▆\u001b[39m▆\u001b[39m▇\u001b[39m▅\u001b[39m▅\u001b[39m▄\u001b[39m▁\u001b[39m▄\u001b[39m▅\u001b[39m▅\u001b[39m▆\u001b[39m▅\u001b[39m▅\u001b[39m▄\u001b[39m▁\u001b[39m▄\u001b[39m▄\u001b[39m▁\u001b[39m▅\u001b[39m▄\u001b[39m▄\u001b[39m▆\u001b[39m▇\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m▄\u001b[39m▅\u001b[39m▆\u001b[39m▅\u001b[39m▅\u001b[39m▅\u001b[39m▁\u001b[39m▅\u001b[39m▄\u001b[39m▄\u001b[39m▅\u001b[39m▁\u001b[39m▄\u001b[39m \u001b[39m▇\n",
|
||||
" 4.57 ms\u001b[90m \u001b[39m\u001b[90mHistogram: \u001b[39m\u001b[90m\u001b[1mlog(\u001b[22m\u001b[39m\u001b[90mfrequency\u001b[39m\u001b[90m\u001b[1m)\u001b[22m\u001b[39m\u001b[90m by time\u001b[39m 10 ms \u001b[0m\u001b[1m<\u001b[22m\n",
|
||||
"\n",
|
||||
" Memory estimate\u001b[90m: \u001b[39m\u001b[33m6.17 MiB\u001b[39m, allocs estimate\u001b[90m: \u001b[39m\u001b[33m143006\u001b[39m."
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@benchmark results = AB_AB3_compute.($inputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "e18d9546",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"BenchmarkTools.Trial: 1089 samples with 1 evaluation.\n",
|
||||
" Range \u001b[90m(\u001b[39m\u001b[36m\u001b[1mmin\u001b[22m\u001b[39m … \u001b[35mmax\u001b[39m\u001b[90m): \u001b[39m\u001b[36m\u001b[1m3.637 ms\u001b[22m\u001b[39m … \u001b[35m10.921 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmin … max\u001b[90m): \u001b[39m 0.00% … 59.52%\n",
|
||||
" Time \u001b[90m(\u001b[39m\u001b[34m\u001b[1mmedian\u001b[22m\u001b[39m\u001b[90m): \u001b[39m\u001b[34m\u001b[1m4.098 ms \u001b[22m\u001b[39m\u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmedian\u001b[90m): \u001b[39m 0.00%\n",
|
||||
" Time \u001b[90m(\u001b[39m\u001b[32m\u001b[1mmean\u001b[22m\u001b[39m ± \u001b[32mσ\u001b[39m\u001b[90m): \u001b[39m\u001b[32m\u001b[1m4.587 ms\u001b[22m\u001b[39m ± \u001b[32m 1.334 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmean ± σ\u001b[90m): \u001b[39m10.21% ± 15.77%\n",
|
||||
"\n",
|
||||
" \u001b[39m \u001b[39m▂\u001b[39m▆\u001b[39m▆\u001b[39m▇\u001b[34m█\u001b[39m\u001b[39m▆\u001b[39m▂\u001b[39m \u001b[39m \u001b[39m \u001b[32m \u001b[39m\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m▁\u001b[39m▁\u001b[39m \u001b[39m▁\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \n",
|
||||
" \u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m█\u001b[34m█\u001b[39m\u001b[39m█\u001b[39m█\u001b[39m▇\u001b[39m█\u001b[39m▇\u001b[32m▆\u001b[39m\u001b[39m▅\u001b[39m▇\u001b[39m▅\u001b[39m▅\u001b[39m▅\u001b[39m▄\u001b[39m▆\u001b[39m▄\u001b[39m▅\u001b[39m▅\u001b[39m▅\u001b[39m▅\u001b[39m▆\u001b[39m▄\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▁\u001b[39m▄\u001b[39m▆\u001b[39m▆\u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m▇\u001b[39m█\u001b[39m█\u001b[39m▆\u001b[39m▆\u001b[39m▆\u001b[39m█\u001b[39m█\u001b[39m▇\u001b[39m▆\u001b[39m▄\u001b[39m▄\u001b[39m \u001b[39m█\n",
|
||||
" 3.64 ms\u001b[90m \u001b[39m\u001b[90mHistogram: \u001b[39m\u001b[90m\u001b[1mlog(\u001b[22m\u001b[39m\u001b[90mfrequency\u001b[39m\u001b[90m\u001b[1m)\u001b[22m\u001b[39m\u001b[90m by time\u001b[39m 8.78 ms \u001b[0m\u001b[1m<\u001b[22m\n",
|
||||
"\n",
|
||||
" Memory estimate\u001b[90m: \u001b[39m\u001b[33m5.26 MiB\u001b[39m, allocs estimate\u001b[90m: \u001b[39m\u001b[33m123006\u001b[39m."
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@benchmark results = AB_AB3_reduced_compute.($inputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "13efed12-3547-400b-a7a2-5dfae9a973a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set a different caching strategy\n",
|
||||
"MetagraphOptimization.set_cache_strategy(machine.devices[1], MetagraphOptimization.Dictionary())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "ef62716b-a219-4f6e-9150-f984d3734839",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"BenchmarkTools.Trial: 331 samples with 1 evaluation.\n",
|
||||
" Range \u001b[90m(\u001b[39m\u001b[36m\u001b[1mmin\u001b[22m\u001b[39m … \u001b[35mmax\u001b[39m\u001b[90m): \u001b[39m\u001b[36m\u001b[1m12.148 ms\u001b[22m\u001b[39m … \u001b[35m24.164 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmin … max\u001b[90m): \u001b[39m 0.00% … 13.35%\n",
|
||||
" Time \u001b[90m(\u001b[39m\u001b[34m\u001b[1mmedian\u001b[22m\u001b[39m\u001b[90m): \u001b[39m\u001b[34m\u001b[1m15.412 ms \u001b[22m\u001b[39m\u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmedian\u001b[90m): \u001b[39m17.47%\n",
|
||||
" Time \u001b[90m(\u001b[39m\u001b[32m\u001b[1mmean\u001b[22m\u001b[39m ± \u001b[32mσ\u001b[39m\u001b[90m): \u001b[39m\u001b[32m\u001b[1m15.117 ms\u001b[22m\u001b[39m ± \u001b[32m 2.194 ms\u001b[39m \u001b[90m┊\u001b[39m GC \u001b[90m(\u001b[39mmean ± σ\u001b[90m): \u001b[39m12.31% ± 8.95%\n",
|
||||
"\n",
|
||||
" \u001b[39m \u001b[39m▄\u001b[39m█\u001b[39m▄\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[32m▄\u001b[39m\u001b[39m▄\u001b[34m▂\u001b[39m\u001b[39m \u001b[39m▂\u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \u001b[39m \n",
|
||||
" \u001b[39m▅\u001b[39m█\u001b[39m█\u001b[39m█\u001b[39m▅\u001b[39m▃\u001b[39m▃\u001b[39m▂\u001b[39m▃\u001b[39m▂\u001b[39m▅\u001b[39m▂\u001b[39m▃\u001b[39m▁\u001b[39m▂\u001b[39m▂\u001b[39m▂\u001b[39m▃\u001b[39m▂\u001b[39m▃\u001b[32m█\u001b[39m\u001b[39m█\u001b[34m█\u001b[39m\u001b[39m▇\u001b[39m█\u001b[39m▄\u001b[39m▆\u001b[39m▄\u001b[39m▆\u001b[39m▄\u001b[39m▄\u001b[39m▆\u001b[39m▅\u001b[39m▄\u001b[39m▃\u001b[39m▄\u001b[39m▂\u001b[39m▂\u001b[39m▃\u001b[39m▃\u001b[39m▄\u001b[39m▃\u001b[39m▂\u001b[39m▂\u001b[39m▁\u001b[39m▂\u001b[39m▂\u001b[39m▃\u001b[39m▂\u001b[39m▂\u001b[39m▁\u001b[39m▂\u001b[39m▁\u001b[39m▃\u001b[39m▃\u001b[39m▂\u001b[39m▂\u001b[39m▁\u001b[39m▂\u001b[39m \u001b[39m▃\n",
|
||||
" 12.1 ms\u001b[90m Histogram: frequency by time\u001b[39m 21 ms \u001b[0m\u001b[1m<\u001b[22m\n",
|
||||
"\n",
|
||||
" Memory estimate\u001b[90m: \u001b[39m\u001b[33m27.46 MiB\u001b[39m, allocs estimate\u001b[90m: \u001b[39m\u001b[33m118013\u001b[39m."
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ... and bench again\n",
|
||||
"AB_AB3_reduced_dict_compute = get_compute_function(graph, process, machine)\n",
|
||||
"@benchmark results = AB_AB3_reduced_dict_compute.($inputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5461ffd4-6a0e-4f1f-b1f1-3a2854a8ae88",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Julia 1.9.3",
|
||||
"language": "julia",
|
||||
"name": "julia-1.9"
|
||||
},
|
||||
"language_info": {
|
||||
"file_extension": ".jl",
|
||||
"mimetype": "application/julia",
|
||||
"name": "julia",
|
||||
"version": "1.9.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
70
notebooks/profiling.ipynb
Normal file
70
notebooks/profiling.ipynb
Normal file
@ -0,0 +1,70 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"using Revise; using MetagraphOptimization; using BenchmarkTools; using ProfileView\n",
|
||||
"using Base.Threads\n",
|
||||
"nthreads()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = ABCModel()\n",
|
||||
"process_str = \"AB->ABBBBB\"\n",
|
||||
"process = parse_process(process_str, model)\n",
|
||||
"graph = parse_dag(\"../input/$process_str.txt\", model)\n",
|
||||
"print(graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"include(\"../examples/profiling_utilities.jl\")\n",
|
||||
"@ProfileView.profview reduce_all!(graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@ProfileView.profview comp_func = get_compute_function(graph, process)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Julia 1.9.3",
|
||||
"language": "julia",
|
||||
"name": "julia-1.9"
|
||||
},
|
||||
"language_info": {
|
||||
"file_extension": ".jl",
|
||||
"mimetype": "application/julia",
|
||||
"name": "julia",
|
||||
"version": "1.9.3"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -6,20 +6,20 @@ julia --project=./examples -t 4 -e 'import Pkg; Pkg.instantiate()'
|
||||
|
||||
#for i in $(seq $minthreads $maxthreads)
|
||||
# printf "(AB->AB, $i) "
|
||||
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->AB.txt"))'
|
||||
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_dag("input/AB->AB.txt"), ABCModel())'
|
||||
#end
|
||||
|
||||
#for i in $(seq $minthreads $maxthreads)
|
||||
# printf "(AB->ABBB, $i) "
|
||||
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBB.txt"))'
|
||||
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_dag("input/AB->ABBB.txt"), ABCModel())'
|
||||
#end
|
||||
|
||||
#for i in $(seq $minthreads $maxthreads)
|
||||
# printf "(AB->ABBBBB, $i) "
|
||||
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBB.txt"))'
|
||||
# julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_dag("input/AB->ABBBBB.txt"), ABCModel())'
|
||||
#end
|
||||
|
||||
for i in $(seq $minthreads $maxthreads)
|
||||
printf "(AB->ABBBBBBB, $i) "
|
||||
julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_abc("input/AB->ABBBBBBB.txt"))'
|
||||
# benchmark get_operations with $i threads; the graph is re-parsed in the setup phase of every sample
julia --project=./examples -t $i -O3 -e 'using MetagraphOptimization; using BenchmarkTools; @btime get_operations(graph) setup=(graph = parse_dag("input/AB->ABBBBBBB.txt", ABCModel()))'
|
||||
end
|
||||
|
@ -1,10 +1,64 @@
|
||||
"""
|
||||
MetagraphOptimization
|
||||
|
||||
A module containing tools to work on DAGs.
|
||||
"""
|
||||
module MetagraphOptimization
|
||||
|
||||
export Node, Edge, ComputeTaskNode, DataTaskNode, DAG
|
||||
export AbstractTask, AbstractComputeTask, AbstractDataTask, DataTask, FusedComputeTask
|
||||
export make_node, make_edge, insert_node, insert_edge, is_entry_node, is_exit_node, parents, children, compute, graph_properties, get_exit_node, is_valid
|
||||
export NodeFusion, NodeReduction, NodeSplit, push_operation!, pop_operation!, can_pop, reset_graph!, get_operations
|
||||
export parse_abc, ComputeTaskP, ComputeTaskS1, ComputeTaskS2, ComputeTaskV, ComputeTaskU, ComputeTaskSum
|
||||
export DAG
|
||||
export Node
|
||||
export Edge
|
||||
export ComputeTaskNode
|
||||
export DataTaskNode
|
||||
export AbstractTask
|
||||
export AbstractComputeTask
|
||||
export AbstractDataTask
|
||||
export DataTask
|
||||
export FusedComputeTask
|
||||
export PossibleOperations
|
||||
export GraphProperties
|
||||
|
||||
export make_node
|
||||
export make_edge
|
||||
export insert_node
|
||||
export insert_edge
|
||||
export is_entry_node
|
||||
export is_exit_node
|
||||
export parents
|
||||
export children
|
||||
export compute
|
||||
export get_properties
|
||||
export get_exit_node
|
||||
export is_valid, is_scheduled
|
||||
|
||||
export Operation
|
||||
export AppliedOperation
|
||||
export NodeFusion
|
||||
export NodeReduction
|
||||
export NodeSplit
|
||||
export push_operation!
|
||||
export pop_operation!
|
||||
export can_pop
|
||||
export reset_graph!
|
||||
export get_operations
|
||||
|
||||
export ComputeTaskP
|
||||
export ComputeTaskS1
|
||||
export ComputeTaskS2
|
||||
export ComputeTaskV
|
||||
export ComputeTaskU
|
||||
export ComputeTaskSum
|
||||
|
||||
export execute
|
||||
export parse_dag, parse_process
|
||||
export gen_process_input
|
||||
export get_compute_function
|
||||
export ParticleValue
|
||||
export ParticleA, ParticleB, ParticleC
|
||||
export ABCProcessDescription, ABCProcessInput, ABCModel
|
||||
|
||||
export Machine
|
||||
export get_machine_info
|
||||
|
||||
export ==, in, show, isempty, delete!, length
|
||||
|
||||
@ -13,6 +67,8 @@ export bytes_to_human_readable
|
||||
import Base.length
|
||||
import Base.show
|
||||
import Base.==
|
||||
import Base.+
|
||||
import Base.-
|
||||
import Base.in
|
||||
import Base.copy
|
||||
import Base.isempty
|
||||
@ -21,29 +77,75 @@ import Base.insert!
|
||||
import Base.collect
|
||||
|
||||
|
||||
include("tasks.jl")
|
||||
include("nodes.jl")
|
||||
include("graph.jl")
|
||||
include("devices/interface.jl")
|
||||
include("task/type.jl")
|
||||
include("node/type.jl")
|
||||
include("diff/type.jl")
|
||||
include("properties/type.jl")
|
||||
include("operation/type.jl")
|
||||
include("graph/type.jl")
|
||||
|
||||
include("trie.jl")
|
||||
include("utility.jl")
|
||||
|
||||
include("task_functions.jl")
|
||||
include("node_functions.jl")
|
||||
include("graph_functions.jl")
|
||||
include("diff/print.jl")
|
||||
include("diff/properties.jl")
|
||||
|
||||
include("operations/utility.jl")
|
||||
include("operations/apply.jl")
|
||||
include("operations/clean.jl")
|
||||
include("operations/find.jl")
|
||||
include("operations/get.jl")
|
||||
include("operations/print.jl")
|
||||
include("operations/validate.jl")
|
||||
include("graph/compare.jl")
|
||||
include("graph/interface.jl")
|
||||
include("graph/mute.jl")
|
||||
include("graph/print.jl")
|
||||
include("graph/properties.jl")
|
||||
include("graph/validate.jl")
|
||||
|
||||
include("graph_interface.jl")
|
||||
include("node/compare.jl")
|
||||
include("node/create.jl")
|
||||
include("node/print.jl")
|
||||
include("node/properties.jl")
|
||||
include("node/validate.jl")
|
||||
|
||||
include("abc_model/tasks.jl")
|
||||
include("abc_model/task_functions.jl")
|
||||
include("abc_model/parse.jl")
|
||||
include("operation/utility.jl")
|
||||
include("operation/apply.jl")
|
||||
include("operation/clean.jl")
|
||||
include("operation/find.jl")
|
||||
include("operation/get.jl")
|
||||
include("operation/print.jl")
|
||||
include("operation/validate.jl")
|
||||
|
||||
include("properties/create.jl")
|
||||
include("properties/utility.jl")
|
||||
|
||||
include("task/create.jl")
|
||||
include("task/compare.jl")
|
||||
include("task/compute.jl")
|
||||
include("task/print.jl")
|
||||
include("task/properties.jl")
|
||||
|
||||
include("models/interface.jl")
|
||||
include("models/print.jl")
|
||||
|
||||
include("models/abc/types.jl")
|
||||
include("models/abc/particle.jl")
|
||||
include("models/abc/compute.jl")
|
||||
include("models/abc/create.jl")
|
||||
include("models/abc/properties.jl")
|
||||
include("models/abc/parse.jl")
|
||||
include("models/abc/print.jl")
|
||||
|
||||
include("devices/measure.jl")
|
||||
include("devices/detect.jl")
|
||||
include("devices/impl.jl")
|
||||
|
||||
include("devices/numa/impl.jl")
|
||||
include("devices/cuda/impl.jl")
|
||||
# AMDGPU cannot currently be used because of an incompatibility with the newest ROCm drivers
|
||||
# include("devices/rocm/impl.jl")
|
||||
# oneapi seems also broken for now
|
||||
# include("devices/oneapi/impl.jl")
|
||||
|
||||
include("scheduler/interface.jl")
|
||||
include("scheduler/greedy.jl")
|
||||
|
||||
include("code_gen/main.jl")
|
||||
|
||||
end # module MetagraphOptimization
|
||||
|
@ -1,152 +0,0 @@
|
||||
using Printf
|
||||
|
||||
# functions for importing DAGs from a file
|
||||
regex_a = r"^[A-C]\d+$" # Regex for the initial particles
|
||||
regex_c = r"^[A-C]\(([^']*),([^']*)\)$" # Regex for the combinations of 2 particles
|
||||
regex_m = r"^M\(([^']*),([^']*),([^']*)\)$" # Regex for the combinations of 3 particles
|
||||
regex_plus = r"^\+$" # Regex for the sum
|
||||
|
||||
# Collect the contents of every single-quoted token in `input`, in order of appearance.
function parse_nodes(input::AbstractString)
    return [m.captures[1] for m in eachmatch(r"'([^']*)'", input)]
end
|
||||
|
||||
# Collect every `('from', 'to')` pair in `input` as a tuple of the two quoted names.
function parse_edges(input::AbstractString)
    edgePattern = r"\('([^']*)', '([^']*)'\)"
    return [(m.captures[1], m.captures[2]) for m in eachmatch(edgePattern, input)]
end
|
||||
|
||||
# reads an abc-model process from the given file
|
||||
function parse_abc(filename::String, verbose::Bool = false)
|
||||
file = open(filename, "r")
|
||||
|
||||
if (verbose) println("Opened file") end
|
||||
nodes_string = readline(file)
|
||||
nodes = parse_nodes(nodes_string)
|
||||
|
||||
close(file)
|
||||
if (verbose) println("Read file") end
|
||||
|
||||
graph = DAG()
|
||||
|
||||
# estimate total number of nodes
|
||||
# try to slightly overestimate so no resizing is necessary
|
||||
# data nodes are not included in length(nodes) and there are a few more than compute nodes
|
||||
estimate_no_nodes = round(Int, length(nodes) * 4)
|
||||
if (verbose) println("Estimating ", estimate_no_nodes, " Nodes") end
|
||||
sizehint!(graph.nodes, estimate_no_nodes)
|
||||
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false)
|
||||
global_data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
|
||||
insert_edge!(graph, sum_node, global_data_out, false, false)
|
||||
|
||||
# remember the data out nodes for connection
|
||||
dataOutNodes = Dict()
|
||||
|
||||
if (verbose) println("Building graph") end
|
||||
noNodes = 0
|
||||
nodesToRead = length(nodes)
|
||||
while !isempty(nodes)
|
||||
node = popfirst!(nodes)
|
||||
noNodes += 1
|
||||
if (noNodes % 100 == 0)
|
||||
if (verbose) @printf "\rReading Nodes... %.2f%%" (100. * noNodes / nodesToRead) end
|
||||
end
|
||||
if occursin(regex_a, node)
|
||||
# add nodes and edges for the state reading to u(P(Particle))
|
||||
data_in = insert_node!(graph, make_node(DataTask(4)), false, false) # read particle data node
|
||||
compute_P = insert_node!(graph, make_node(ComputeTaskP()), false, false) # compute P node
|
||||
data_Pu = insert_node!(graph, make_node(DataTask(6)), false, false) # transfer data from P to u
|
||||
compute_u = insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node
|
||||
data_out = insert_node!(graph, make_node(DataTask(3)), false, false) # transfer data out from u
|
||||
|
||||
insert_edge!(graph, data_in, compute_P, false, false)
|
||||
insert_edge!(graph, compute_P, data_Pu, false, false)
|
||||
insert_edge!(graph, data_Pu, compute_u, false, false)
|
||||
insert_edge!(graph, compute_u, data_out, false, false)
|
||||
|
||||
# remember the data_out node for future edges
|
||||
dataOutNodes[node] = data_out
|
||||
elseif occursin(regex_c, node)
|
||||
capt = match(regex_c, node)
|
||||
|
||||
in1 = capt.captures[1]
|
||||
in2 = capt.captures[2]
|
||||
|
||||
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
|
||||
data_out = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
if (occursin(regex_c, in1))
|
||||
# put an S node after this input
|
||||
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
|
||||
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
insert_edge!(graph, dataOutNodes[in1], compute_S, false, false)
|
||||
insert_edge!(graph, compute_S, data_S_v, false, false)
|
||||
|
||||
insert_edge!(graph, data_S_v, compute_v, false, false)
|
||||
else
|
||||
insert_edge!(graph, dataOutNodes[in1], compute_v, false, false)
|
||||
end
|
||||
|
||||
if (occursin(regex_c, in2))
|
||||
# i think the current generator only puts the combined particles in the first space, so this case might never be entered
|
||||
# put an S node after this input
|
||||
compute_S = insert_node!(graph, make_node(ComputeTaskS1()), false, false)
|
||||
data_S_v = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
insert_edge!(graph, dataOutNodes[in2], compute_S, false, false)
|
||||
insert_edge!(graph, compute_S, data_S_v, false, false)
|
||||
|
||||
insert_edge!(graph, data_S_v, compute_v, false, false)
|
||||
else
|
||||
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
|
||||
end
|
||||
|
||||
insert_edge!(graph, compute_v, data_out, false, false)
|
||||
dataOutNodes[node] = data_out
|
||||
|
||||
elseif occursin(regex_m, node)
|
||||
# assume for now that only the first particle of the three is combined and the other two are "original" ones
|
||||
capt = match(regex_m, node)
|
||||
in1 = capt.captures[1]
|
||||
in2 = capt.captures[2]
|
||||
in3 = capt.captures[3]
|
||||
|
||||
# in2 + in3 with a v
|
||||
compute_v = insert_node!(graph, make_node(ComputeTaskV()), false, false)
|
||||
data_v = insert_node!(graph, make_node(DataTask(5)), false, false)
|
||||
|
||||
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
|
||||
insert_edge!(graph, dataOutNodes[in3], compute_v, false, false)
|
||||
insert_edge!(graph, compute_v, data_v, false, false)
|
||||
|
||||
# combine with the v of the combined other input
|
||||
compute_S2 = insert_node!(graph, make_node(ComputeTaskS2()), false, false)
|
||||
data_out = insert_node!(graph, make_node(DataTask(10)), false, false)
|
||||
|
||||
insert_edge!(graph, data_v, compute_S2, false, false)
|
||||
insert_edge!(graph, dataOutNodes[in1], compute_S2, false, false)
|
||||
insert_edge!(graph, compute_S2, data_out, false, false)
|
||||
|
||||
insert_edge!(graph, data_out, sum_node, false, false)
|
||||
elseif occursin(regex_plus, node)
|
||||
if (verbose)
|
||||
println("\rReading Nodes Complete ")
|
||||
println("Added ", length(graph.nodes), " nodes")
|
||||
end
|
||||
else
|
||||
@assert false ("Unknown node '$node' while reading from file $filename")
|
||||
end
|
||||
end
|
||||
|
||||
#put all nodes into dirty nodes set
|
||||
graph.dirtyNodes = copy(graph.nodes)
|
||||
|
||||
# don't actually need to read the edges
|
||||
return graph
|
||||
end
|
@ -1,21 +0,0 @@
|
||||
# define compute_efforts tasks computation
|
||||
# put some "random" numbers here for now
|
||||
compute_effort(t::ComputeTaskS1) = 10
|
||||
compute_effort(t::ComputeTaskS2) = 10
|
||||
compute_effort(t::ComputeTaskU) = 6
|
||||
compute_effort(t::ComputeTaskV) = 20
|
||||
compute_effort(t::ComputeTaskP) = 15
|
||||
compute_effort(t::ComputeTaskSum) = 1
|
||||
|
||||
function show(io::IO, t::DataTask)
|
||||
print(io, "Data", t.data)
|
||||
end
|
||||
|
||||
# Short text representations of the compute tasks.
# The text is written to the given `io` so that printing into buffers, strings or files works.
show(io::IO, t::ComputeTaskS1) = print(io, "ComputeS1")
show(io::IO, t::ComputeTaskS2) = print(io, "ComputeS2")
show(io::IO, t::ComputeTaskP) = print(io, "ComputeP")
show(io::IO, t::ComputeTaskU) = print(io, "ComputeU")
show(io::IO, t::ComputeTaskV) = print(io, "ComputeV")
show(io::IO, t::ComputeTaskSum) = print(io, "ComputeSum")
|
||||
|
||||
copy(t::DataTask) = DataTask(t.data)
|
@ -1,29 +0,0 @@
|
||||
struct DataTask <: AbstractDataTask
|
||||
data::UInt64
|
||||
end
|
||||
|
||||
# S task with 1 child
|
||||
struct ComputeTaskS1 <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# S task with 2 children
|
||||
struct ComputeTaskS2 <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# P task with 0 children
|
||||
struct ComputeTaskP <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# v task with 2 children
|
||||
struct ComputeTaskV <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# u task with 1 child
|
||||
struct ComputeTaskU <: AbstractComputeTask
|
||||
end
|
||||
|
||||
# task that sums all its inputs, n children
|
||||
struct ComputeTaskSum <: AbstractComputeTask
|
||||
end
|
||||
|
||||
ABC_TASKS = [DataTask, ComputeTaskS1, ComputeTaskS2, ComputeTaskP, ComputeTaskV, ComputeTaskU, ComputeTaskSum]
|
157
src/code_gen/main.jl
Normal file
157
src/code_gen/main.jl
Normal file
@ -0,0 +1,157 @@
|
"""
    gen_code(graph::DAG, machine::Machine)

Generate the code for a given graph, using the order produced by scheduling the graph on `machine` with the [`GreedyScheduler`](@ref). The return value is a named tuple of:

- `code::Expr`: The julia expression containing the code for the whole graph.
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
- `outputSymbol::Symbol`: The symbol of the final calculated value

See also: [`execute`](@ref)
"""
function gen_code(graph::DAG, machine::Machine)
    sched = schedule_dag(GreedyScheduler(), graph, machine)

    codeAcc = Vector{Expr}()
    sizehint!(codeAcc, length(graph.nodes))

    for node in sched
        # TODO: this is kind of ugly, should init nodes be scheduled differently from the rest?
        # data nodes without children get an init expression on the entry device instead
        if (node isa DataTaskNode && isempty(node.children))
            push!(codeAcc, get_init_expression(node, entry_device(machine)))
            continue
        end
        push!(codeAcc, get_expression(node))
    end

    # get inSymbols: entry nodes sharing a name are collected under the same key
    inputSyms = Dict{String, Vector{Symbol}}()
    for node in get_entry_nodes(graph)
        symbols = get!(() -> Vector{Symbol}(), inputSyms, node.name)
        push!(symbols, Symbol("$(to_var_name(node.id))_in"))
    end

    # get outSymbol
    outSym = Symbol(to_var_name(get_exit_node(graph).id))

    return (code = Expr(:block, codeAcc...), inputSymbols = inputSyms, outputSymbol = outSym)
end
|
||||
|
||||
"""
    gen_cache_init_code(machine::Machine)

Return a block expression containing the cache initialization code of every device in `machine.devices`, in device order.
"""
function gen_cache_init_code(machine::Machine)
    initializeCaches = Expr[gen_cache_init_code(device) for device in machine.devices]
    return Expr(:block, initializeCaches...)
end
|
||||
|
||||
"""
    gen_input_assignment_code(
        inputSymbols::Dict{String, Vector{Symbol}},
        processDescription::AbstractProcessDescription,
        machine::Machine,
        processInputSymbol::Symbol = :input,
    )

Return a block expression of assignments that fill the input variables on the entry device of `machine`.

For every `(name, symbols)` pair, the particle type is looked up with `type_from_name` and the particle index is parsed from the characters of `name` after the first one. An index larger than the number of incoming particles of that type in `processDescription` is treated as an index into the outgoing particles. Every symbol is assigned `ParticleValue(<selected particle>, 1.0)`, where the particle is selected from the variable named by `processInputSymbol` in the generated code.
"""
function gen_input_assignment_code(
    inputSymbols::Dict{String, Vector{Symbol}},
    processDescription::AbstractProcessDescription,
    machine::Machine,
    processInputSymbol::Symbol = :input,
)
    @assert length(inputSymbols) >=
            sum(values(in_particles(processDescription))) + sum(values(out_particles(processDescription))) "Number of input Symbols is smaller than the number of particles in the process description"

    assignInputs = Vector{Expr}()
    for (name, symbols) in inputSymbols
        type = type_from_name(name)
        index = parse(Int, name[2:end])

        # source text selecting the particle this input name refers to
        p = if index <= in_particles(processDescription)[type]
            "filter(x -> typeof(x) <: $type, in_particles($(processInputSymbol)))[$(index)]"
        else
            # indices past the incoming particles continue into the outgoing ones
            index -= in_particles(processDescription)[type]
            @assert index <= out_particles(processDescription)[type] "Too few particles of type $type in input particles for this process"

            "filter(x -> typeof(x) <: $type, out_particles($(processInputSymbol)))[$(index)]"
        end

        for symbol in symbols
            # TODO: how to get the "default" cpu device?
            evalExpr = eval(gen_access_expr(entry_device(machine), symbol))
            push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue($p, 1.0)"))
        end
    end

    return Expr(:block, assignInputs...)
end
|
||||
|
"""
    get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)

Build and return a function `compute_<id>(input::AbstractProcessInput)` that evaluates the DAG on the given input and returns its result. The function is assembled from the cache initialization code, the input assignments and the generated graph code, and is defined via `eval`.
"""
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
    generated = gen_code(graph, machine)
    code = generated.code

    initCaches = gen_cache_init_code(machine)
    assignInputs = gen_input_assignment_code(generated.inputSymbols, process, machine, :input)

    # a fresh UUID-based id is generated for the name of each compute function
    functionId = to_var_name(UUIDs.uuid1(rng[1]))
    resSym = eval(gen_access_expr(entry_device(machine), generated.outputSymbol))

    return eval(
        Meta.parse(
            "function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
        ),
    )
end
|
||||
|
"""
    execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)

Execute the code of the given `graph` on the given input particles.

This is essentially shorthand for
```julia
compute_graph = get_compute_function(graph, process, machine)
result = compute_graph(input)
```

If evaluating the generated function throws, the error and the source of the generated function are printed and an `AssertionError` is thrown.

See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
    (code, inputSymbols, outputSymbol) = gen_code(graph, machine)

    initCaches = gen_cache_init_code(machine)
    assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)

    functionId = to_var_name(UUIDs.uuid1(rng[1]))
    resSym = eval(gen_access_expr(entry_device(machine), outputSymbol))
    expr = Meta.parse(
        "function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
    )
    func = eval(expr)

    try
        # the call goes through @eval because `func` was only just defined by `eval`
        return @eval $func($input)
    catch e
        println("Error while evaluating: $e")

        # if we find a uuid in the exception we can color it in so it's easier to spot
        uuidRegex = r"[0-9a-f]{8}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{4}_[0-9a-f]{12}"
        m = match(uuidRegex, string(e))

        functionStr = string(expr)
        if (isa(m, RegexMatch))
            functionStr = replace(functionStr, m.match => "\033[31m$(m.match)\033[0m")
        end

        println("Function:\n$functionStr")

        # thrown explicitly (same exception type as the former `@assert false`) so the failure
        # can never fall through to a bogus return value
        throw(AssertionError("evaluating the generated compute function failed"))
    end
end
|
53
src/devices/cuda/impl.jl
Normal file
53
src/devices/cuda/impl.jl
Normal file
@ -0,0 +1,53 @@
|
||||
using CUDA
|
||||
|
||||
"""
|
||||
CUDAGPU <: AbstractGPU
|
||||
|
||||
Representation of a specific CUDA GPU that code can run on. Implements the [`AbstractDevice`](@ref) interface.
|
||||
"""
|
||||
mutable struct CUDAGPU <: AbstractGPU
|
||||
device::Any # TODO: what's the cuda device type?
|
||||
cacheStrategy::CacheStrategy
|
||||
FLOPS::Float64
|
||||
end
|
||||
|
||||
push!(DEVICE_TYPES, CUDAGPU)
|
||||
|
||||
CACHE_STRATEGIES[CUDAGPU] = [LocalVariables()]
|
||||
|
||||
default_strategy(::Type{T}) where {T <: CUDAGPU} = LocalVariables()
|
||||
|
||||
"""
    measure_device!(device::CUDAGPU; verbose::Bool)

Measurement hook for a [`CUDAGPU`](@ref). Currently not implemented: it only prints a message when `verbose` is true and returns `nothing`.
"""
function measure_device!(device::CUDAGPU; verbose::Bool)
    verbose && println("Measuring CUDA GPU $(device.device)")

    # TODO implement
    return nothing
end
|
||||
|
"""
    get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: CUDAGPU}

Return a Vector of [`CUDAGPU`](@ref)s available on the current machine. The vector is empty when CUDA is not functional. If `verbose` is true, print some additional information.
"""
function get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: CUDAGPU}
    devices = Vector{AbstractDevice}()

    if !CUDA.functional()
        verbose && println("CUDA is non-functional")
        return devices
    end

    CUDADevices = CUDA.devices()
    verbose && println("Found $(length(CUDADevices)) CUDA devices")

    # FLOPS is initialized to -1 here; NOTE(review): presumably a placeholder until measured
    foreach(CUDADevices) do device
        return push!(devices, CUDAGPU(device, default_strategy(CUDAGPU), -1))
    end

    return devices
end
|
23
src/devices/detect.jl
Normal file
23
src/devices/detect.jl
Normal file
@ -0,0 +1,23 @@
|
||||
|
"""
    get_machine_info(; verbose::Bool = isinteractive())

Return the [`Machine`](@ref) currently running on, built from the devices reported by [`get_devices`](@ref) for every type in [`device_types`](@ref). The keyword argument `verbose` defaults to true when interactive.

All entries of the machine's transfer rate matrix are initialized to `-1`.
"""
function get_machine_info(; verbose::Bool = isinteractive())
    devices = Vector{AbstractDevice}()

    for deviceType in device_types()
        append!(devices, get_devices(deviceType, verbose = verbose))
    end

    noDevices = length(devices)
    @assert noDevices > 0 "No devices were found, but at least one NUMA node should always be available!"

    # NOTE(review): -1 looks like a placeholder for "not measured yet" - confirm
    transferRates = fill(-1.0, noDevices, noDevices)
    return Machine(devices, transferRates)
end
|
52
src/devices/impl.jl
Normal file
52
src/devices/impl.jl
Normal file
@ -0,0 +1,52 @@
|
||||
"""
|
||||
device_types()
|
||||
|
||||
Return a vector of available and implemented device types.
|
||||
|
||||
See also: [`DEVICE_TYPES`](@ref)
|
||||
"""
|
||||
function device_types()
|
||||
return DEVICE_TYPES
|
||||
end
|
||||
|
||||
"""
|
||||
entry_device(machine::Machine)
|
||||
|
||||
Return the "entry" device, i.e., the device that starts CPU threads and GPU kernels, and takes input values and returns the output value.
|
||||
"""
|
||||
function entry_device(machine::Machine)
|
||||
return machine.devices[1]
|
||||
end
|
||||
|
"""
    strategies(t::Type{T}) where {T <: AbstractDevice}

Return the vector of [`CacheStrategy`](@ref)s registered in `CACHE_STRATEGIES` for the given [`AbstractDevice`](@ref) type. Throws an error if no strategies are registered for it.
The caching strategies are used in code generation.
"""
function strategies(t::Type{T}) where {T <: AbstractDevice}
    haskey(CACHE_STRATEGIES, t) || error("Trying to get strategies for $T, but it has no strategies defined!")

    return CACHE_STRATEGIES[t]
end
|
||||
|
||||
"""
    cache_strategy(device::AbstractDevice)

Return the [`CacheStrategy`](@ref) currently stored in the device's `cacheStrategy` field.
"""
cache_strategy(device::AbstractDevice) = device.cacheStrategy
|
||||
|
||||
"""
    set_cache_strategy(device::AbstractDevice, cacheStrategy::CacheStrategy)

Store `cacheStrategy` in the given device. Afterwards, [`cache_strategy`](@ref) returns `cacheStrategy` for
this device. Returns `nothing`.
"""
function set_cache_strategy(device::AbstractDevice, cacheStrategy::CacheStrategy)
    setproperty!(device, :cacheStrategy, cacheStrategy)
    return nothing
end
|
108
src/devices/interface.jl
Normal file
108
src/devices/interface.jl
Normal file
@ -0,0 +1,108 @@
|
||||
"""
    AbstractDevice

Abstract supertype of all compute devices, such as CPUs, GPUs or any other kind of compute device.

Each concrete device type has to implement the device interface functions and must have a field `cacheStrategy`.
"""
abstract type AbstractDevice end

# supertype for CPU-like devices
abstract type AbstractCPU <: AbstractDevice end

# supertype for GPU-like devices
abstract type AbstractGPU <: AbstractDevice end
|
||||
|
||||
"""
    Machine

Description of the machine to execute on. Holds information about its architecture (CPUs, GPUs, maybe more),
which can be used to predict the cost of a [`DAG`](@ref) state more accurately.

## Fields
`devices::Vector{AbstractDevice}`: The devices of the machine.

`transferRates::Matrix{Float64}`: One entry per pair of devices; presumably indexed like `devices`.

See also: [`Scheduler`](@ref)
"""
struct Machine
    devices::Vector{AbstractDevice}

    transferRates::Matrix{Float64}
end
|
||||
|
||||
"""
    CacheStrategy

Abstract supertype of all caching strategies.

See also: [`strategies`](@ref)
"""
abstract type CacheStrategy end

"""
    LocalVariables <: CacheStrategy

Caching strategy that uses nothing but local variables for every input and output.

Implements the [`CacheStrategy`](@ref) interface.
"""
struct LocalVariables <: CacheStrategy end

"""
    Dictionary <: CacheStrategy

Caching strategy that stores every input and output in a dictionary keyed by Symbols.

Implements the [`CacheStrategy`](@ref) interface.
"""
struct Dictionary <: CacheStrategy end
|
||||
|
||||
"""
    DEVICE_TYPES::Vector{Type}

Global vector of available and implemented device types. Each implementation of an [`AbstractDevice`](@ref)
should `push!` its concrete type to this vector.

The binding is `const` (the vector itself stays mutable), so accesses from functions are type-stable.

See also: [`device_types`](@ref), [`get_devices`](@ref)
"""
const DEVICE_TYPES = Vector{Type}()
|
||||
|
||||
"""
    CACHE_STRATEGIES::Dict{Type, Vector{CacheStrategy}}

Global dictionary mapping a device type to its available caching strategies. Each implementation of
[`AbstractDevice`](@ref) should add its available strategies to the dictionary.

The binding is `const` (the dictionary itself stays mutable), so accesses from functions are type-stable.

See also: [`strategies`](@ref)
"""
const CACHE_STRATEGIES = Dict{Type, Vector{CacheStrategy}}()
|
||||
|
||||
"""
    default_strategy(deviceType::Type{T}) where {T <: AbstractDevice}

Interface function, to be implemented for every subtype of [`AbstractDevice`](@ref). Return the
[`CacheStrategy`](@ref) that is used by default on the given device type.

See also: [`cache_strategy`](@ref), [`set_cache_strategy`](@ref)
"""
function default_strategy end
|
||||
|
||||
"""
    get_devices(t::Type{T}; verbose::Bool) where {T <: AbstractDevice}

Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref). Return a
`Vector{AbstractDevice}` containing the devices of the given type that are available on the current machine.

If `verbose` is true, implementations print additional information about the detected devices.
"""
function get_devices end
|
||||
|
||||
"""
    measure_device!(device::AbstractDevice; verbose::Bool)

Interface function, to be implemented for every subtype of [`AbstractDevice`](@ref). Measure the compute speed
of the given device and write the result into the device.
"""
function measure_device! end
|
||||
|
||||
"""
    gen_cache_init_code(device::AbstractDevice)

Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one
[`CacheStrategy`](@ref). Return an `Expr` initializing this device's variable cache.

The strategy to generate code for is the one currently set in the device, see [`cache_strategy`](@ref).
"""
function gen_cache_init_code end
|
||||
|
||||
"""
    gen_access_expr(device::AbstractDevice, symbol::Symbol)

Interface function, to be implemented for every subtype of [`AbstractDevice`](@ref) and at least one
[`CacheStrategy`](@ref).

Return an `Expr` or `QuoteNode` that accesses the variable identified by `symbol`.
"""
function gen_access_expr end
|
22
src/devices/measure.jl
Normal file
22
src/devices/measure.jl
Normal file
@ -0,0 +1,22 @@
|
||||
"""
    measure_devices!(machine::Machine; verbose::Bool = isinteractive())

Measure FLOPS, RAM, cache sizes and what other properties can be extracted for the devices in the given
machine, by calling [`measure_device!`](@ref) on each of them. Returns `nothing`.

## Keyword Arguments
`verbose::Bool`: Forwarded to [`measure_device!`](@ref). Defaults to `true` in interactive sessions.
"""
function measure_devices!(machine::Machine; verbose::Bool = isinteractive())
    # note: `Base.is_interactive` is a `Bool` global and not callable; `isinteractive()` is the public accessor
    for device in machine.devices
        measure_device!(device; verbose = verbose)
    end

    return nothing
end
|
||||
|
||||
"""
    measure_transfer_rates!(machine::Machine; verbose::Bool = isinteractive())

Measure the transfer rates between devices in the machine.

Currently not implemented: the function does nothing and returns `nothing`.

## Keyword Arguments
`verbose::Bool`: Currently unused. Defaults to `true` in interactive sessions.
"""
function measure_transfer_rates!(machine::Machine; verbose::Bool = isinteractive())
    # note: `Base.is_interactive` is a `Bool` global and not callable; `isinteractive()` is the public accessor
    # TODO implement
    return nothing
end
|
96
src/devices/numa/impl.jl
Normal file
96
src/devices/numa/impl.jl
Normal file
@ -0,0 +1,96 @@
|
||||
using NumaAllocators
|
||||
|
||||
"""
    NumaNode <: AbstractCPU

A specific CPU (NUMA node) that code can run on. Implements the [`AbstractDevice`](@ref) interface.

## Fields
`numaId::UInt16`: Index of the NUMA node.

`threads::UInt16`: Number of threads; presumably the threads available on this node.

`cacheStrategy::CacheStrategy`: The [`CacheStrategy`](@ref) used in code generation for this device.

`FLOPS::Float64`: Compute speed; detection initializes it with `-1`.

`id::UUID`: Unique identifier of the device, used to name its cache variable in generated code.
"""
mutable struct NumaNode <: AbstractCPU
    numaId::UInt16
    threads::UInt16
    cacheStrategy::CacheStrategy
    FLOPS::Float64
    id::UUID
end
|
||||
|
||||
# register the NumaNode device type and its available caching strategies
push!(DEVICE_TYPES, NumaNode)

CACHE_STRATEGIES[NumaNode] = [LocalVariables()]

# NumaNodes use local variables unless configured otherwise
default_strategy(::Type{<:NumaNode}) = LocalVariables()
|
||||
|
||||
# Measure the given NumaNode. Currently only prints a message if `verbose` is set, the measurement itself is not
# implemented yet.
function measure_device!(device::NumaNode; verbose::Bool)
    verbose && println("Measuring Numa Node $(device.numaId)")

    # TODO implement
    return nothing
end
|
||||
|
||||
"""
    get_devices(deviceType::Type{T}; verbose::Bool) where {T <: NumaNode}

Return a Vector of the [`NumaNode`](@ref)s available on the current machine, one for each NUMA node index from
`0` up to `highest_numa_node()`. If `verbose` is true, print the number of NUMA nodes found.
"""
function get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: NumaNode}
    highestNode = highest_numa_node()
    verbose && println("Found $(highestNode + 1) NUMA nodes")

    devices = Vector{AbstractDevice}()
    for numaId in 0:highestNode
        # every node starts with 1 thread, the default strategy and unmeasured (-1) FLOPS
        push!(devices, NumaNode(numaId, 1, default_strategy(NumaNode), -1, UUIDs.uuid1(rng[1])))
    end

    return devices
end
|
||||
|
||||
"""
    gen_cache_init_code(device::NumaNode)

Generate the code initializing the variable cache of a [`NumaNode`](@ref) for the [`CacheStrategy`](@ref) set
in the device. Supports [`LocalVariables`](@ref) and [`Dictionary`](@ref), any other strategy is an error.
"""
function gen_cache_init_code(device::NumaNode)
    strategy = device.cacheStrategy

    if strategy isa LocalVariables
        # local variables need no initialization
        return Expr(:block)
    end

    if strategy isa Dictionary
        # TODO: sizehint?
        return Meta.parse("cache_$(to_var_name(device.id)) = Dict{Symbol, Any}()")
    end

    return error("Unimplemented cache strategy \"$(device.cacheStrategy)\" for device \"$(device)\"")
end
|
||||
|
||||
"""
    gen_access_expr(device::NumaNode, symbol::Symbol)

Generate the code accessing the variable designated by `symbol` on a [`NumaNode`](@ref). Dispatches on the
[`CacheStrategy`](@ref) set in the device.
"""
gen_access_expr(device::NumaNode, symbol::Symbol) = _gen_access_expr(device, device.cacheStrategy, symbol)
|
||||
|
||||
"""
    _gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol)

Internal function for dispatch, used in [`gen_access_expr`](@ref).

Return a `QuoteNode` wrapping the symbol `data_<symbol>`, the name of the local variable holding the value.
"""
function _gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol)
    # build the node directly instead of parsing ":(data_...)", this also works for symbols that are not
    # valid identifiers
    return QuoteNode(Symbol("data_$symbol"))
end
|
||||
|
||||
"""
    _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol)

Internal function for dispatch, used in [`gen_access_expr`](@ref).

Return the quoted expression indexing the device's cache dictionary `cache_<id>` with `symbol`.
"""
function _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol)
    cacheName = "cache_$(to_var_name(device.id))"
    return Meta.parse(":($cacheName[:$symbol])")
end
|
53
src/devices/oneapi/impl.jl
Normal file
53
src/devices/oneapi/impl.jl
Normal file
@ -0,0 +1,53 @@
|
||||
using oneAPI
|
||||
|
||||
"""
    oneAPIGPU <: AbstractGPU

A specific Intel GPU that code can run on. Implements the [`AbstractDevice`](@ref) interface.

## Fields
`device::Any`: The device handle as returned by `oneAPI.devices()`.

`cacheStrategy::CacheStrategy`: The [`CacheStrategy`](@ref) used in code generation for this device.

`FLOPS::Float64`: Compute speed; detection initializes it with `-1`.
"""
mutable struct oneAPIGPU <: AbstractGPU
    device::Any
    cacheStrategy::CacheStrategy
    FLOPS::Float64
end
|
||||
|
||||
# register the oneAPIGPU device type and its available caching strategies
push!(DEVICE_TYPES, oneAPIGPU)

CACHE_STRATEGIES[oneAPIGPU] = [LocalVariables()]

# oneAPI GPUs use local variables unless configured otherwise
default_strategy(::Type{<:oneAPIGPU}) = LocalVariables()
|
||||
|
||||
# Measure the given oneAPI GPU. Currently only prints a message if `verbose` is set, the measurement itself is not
# implemented yet.
function measure_device!(device::oneAPIGPU; verbose::Bool)
    verbose && println("Measuring oneAPI GPU $(device.device)")

    # TODO implement
    return nothing
end
|
||||
|
||||
"""
    get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: oneAPIGPU}

Return a Vector of the [`oneAPIGPU`](@ref)s available on the current machine. The vector is empty if oneAPI is
not functional. If `verbose` is true, print some additional information.
"""
function get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: oneAPIGPU}
    devices = Vector{AbstractDevice}()

    if oneAPI.functional()
        oneAPIDevices = oneAPI.devices()
        verbose && println("Found $(length(oneAPIDevices)) oneAPI devices")

        for device in oneAPIDevices
            # FLOPS are unmeasured (-1) at detection time
            push!(devices, oneAPIGPU(device, default_strategy(oneAPIGPU), -1))
        end
    elseif verbose
        println("oneAPI is non-functional")
    end

    return devices
end
|
53
src/devices/rocm/impl.jl
Normal file
53
src/devices/rocm/impl.jl
Normal file
@ -0,0 +1,53 @@
|
||||
using AMDGPU
|
||||
|
||||
"""
    ROCmGPU <: AbstractGPU

A specific AMD GPU that code can run on. Implements the [`AbstractDevice`](@ref) interface.

## Fields
`device::Any`: The device handle as returned by `AMDGPU.devices()`.

`cacheStrategy::CacheStrategy`: The [`CacheStrategy`](@ref) used in code generation for this device.

`FLOPS::Float64`: Compute speed; detection initializes it with `-1`.
"""
mutable struct ROCmGPU <: AbstractGPU
    device::Any
    cacheStrategy::CacheStrategy
    FLOPS::Float64
end
|
||||
|
||||
# register the ROCmGPU device type and its available caching strategies
push!(DEVICE_TYPES, ROCmGPU)

CACHE_STRATEGIES[ROCmGPU] = [LocalVariables()]

# ROCm GPUs use local variables unless configured otherwise
default_strategy(::Type{<:ROCmGPU}) = LocalVariables()
|
||||
|
||||
# Measure the given ROCm GPU. Currently only prints a message if `verbose` is set, the measurement itself is not
# implemented yet.
function measure_device!(device::ROCmGPU; verbose::Bool)
    verbose && println("Measuring ROCm GPU $(device.device)")

    # TODO implement
    return nothing
end
|
||||
|
||||
"""
    get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: ROCmGPU}

Return a Vector of the [`ROCmGPU`](@ref)s available on the current machine. The vector is empty if AMDGPU is
not functional. If `verbose` is true, print some additional information.
"""
function get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: ROCmGPU}
    devices = Vector{AbstractDevice}()

    if AMDGPU.functional()
        AMDDevices = AMDGPU.devices()
        verbose && println("Found $(length(AMDDevices)) AMD devices")

        for device in AMDDevices
            # FLOPS are unmeasured (-1) at detection time
            push!(devices, ROCmGPU(device, default_strategy(ROCmGPU), -1))
        end
    elseif verbose
        println("AMDGPU is non-functional")
    end

    return devices
end
|
11
src/diff/print.jl
Normal file
11
src/diff/print.jl
Normal file
@ -0,0 +1,11 @@
|
||||
"""
    show(io::IO, diff::Diff)

Pretty-print a [`Diff`](@ref) as the number of changed (added plus removed) nodes and edges. Called via print,
println and co.
"""
function show(io::IO, diff::Diff)
    changedNodes = length(diff.addedNodes) + length(diff.removedNodes)
    changedEdges = length(diff.addedEdges) + length(diff.removedEdges)

    return print(io, "Nodes: ", changedNodes, ", Edges: ", changedEdges)
end
|
14
src/diff/properties.jl
Normal file
14
src/diff/properties.jl
Normal file
@ -0,0 +1,14 @@
|
||||
"""
    length(diff::Diff)

Return a named tuple with the numbers of added/removed nodes/edges of the diff.

The fields are `.addedNodes`, `.removedNodes`, `.addedEdges` and `.removedEdges`.
"""
function length(diff::Diff)
    addedNodes = length(diff.addedNodes)
    removedNodes = length(diff.removedNodes)
    addedEdges = length(diff.addedEdges)
    removedEdges = length(diff.removedEdges)

    return (; addedNodes, removedNodes, addedEdges, removedEdges)
end
|
21
src/diff/type.jl
Normal file
21
src/diff/type.jl
Normal file
@ -0,0 +1,21 @@
|
||||
"""
    Diff

A named tuple representing a difference on a [`DAG`](@ref): the added and removed nodes, the added and removed
edges, and the nodes whose task had its children updated (together with the task before the update).
"""
const Diff = NamedTuple{
    (:addedNodes, :removedNodes, :addedEdges, :removedEdges, :updatedChildren),
    Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, AbstractTask}}},
}
|
||||
|
||||
# Construct an empty Diff, i.e. one with no added/removed nodes or edges and no updated children.
function Diff()
    return (
        addedNodes = Node[],
        removedNodes = Node[],
        addedEdges = Edge[],
        removedEdges = Edge[],

        # for every entry, the second element is the node's task as it was before its children were updated
        updatedChildren = Tuple{Node, AbstractTask}[],
    )::Diff
end
|
90
src/graph.jl
90
src/graph.jl
@ -1,90 +0,0 @@
|
||||
using DataStructures
|
||||
|
||||
# A named tuple of the nodes and edges that were added to and removed from a DAG
const Diff = NamedTuple{
    (:addedNodes, :removedNodes, :addedEdges, :removedEdges),
    Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}}
}

# Construct an empty Diff
function Diff()
    return (
        addedNodes = Node[],
        removedNodes = Node[],
        addedEdges = Edge[],
        removedEdges = Edge[]
    )::Diff
end
|
||||
|
||||
# Abstract supertype of all operations.
# An operation can be applied to a DAG.
abstract type Operation end

# Abstract supertype of all operations that have already been applied.
# An applied operation can be reversed iff it is the last applied operation on the DAG.
abstract type AppliedOperation end
|
||||
|
||||
# Operation on a compute node, a data node and another compute node
struct NodeFusion <: Operation
    input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
end

# A NodeFusion that has been applied, together with the Diff recorded for it
struct AppliedNodeFusion <: AppliedOperation
    operation::NodeFusion
    diff::Diff
end
|
||||
|
||||
# Operation on a vector of nodes
struct NodeReduction <: Operation
    input::Vector{Node}
end

# A NodeReduction that has been applied, together with the Diff recorded for it
struct AppliedNodeReduction <: AppliedOperation
    operation::NodeReduction
    diff::Diff
end
|
||||
|
||||
# Operation on a single node
struct NodeSplit <: Operation
    input::Node
end

# A NodeSplit that has been applied, together with the Diff recorded for it
struct AppliedNodeSplit <: AppliedOperation
    operation::NodeSplit
    diff::Diff
end
|
||||
|
||||
# The sets of operations of each kind that are possible on a DAG
mutable struct PossibleOperations
    nodeFusions::Set{NodeFusion}
    nodeReductions::Set{NodeReduction}
    nodeSplits::Set{NodeSplit}
end

# Construct a PossibleOperations object with all three sets empty
PossibleOperations() = PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
|
||||
|
||||
# The actual state of the DAG is the initial state given by the set of nodes,
# with all the operations in the applied chain applied in order.
mutable struct DAG
    nodes::Set{Node}

    # operations that are currently applied to the set of nodes
    appliedOperations::Stack{AppliedOperation}

    # operations that are part of the current state of the DAG but have not been applied yet
    operationsToApply::Deque{Operation}

    # operations that are possible at the current state of the DAG
    possibleOperations::PossibleOperations

    # nodes whose possible operations have to be reevaluated
    dirtyNodes::Set{Node}

    # "snapshot" system: records added/removed nodes/edges since the last snapshot,
    # it is mutated in insert_node! etc.
    diff::Diff
end

# Construct an empty DAG without nodes or operations
function DAG()
    nodes = Set{Node}()
    dirtyNodes = Set{Node}()
    return DAG(nodes, Stack{AppliedOperation}(), Deque{Operation}(), PossibleOperations(), dirtyNodes, Diff())
end
|
37
src/graph/compare.jl
Normal file
37
src/graph/compare.jl
Normal file
@ -0,0 +1,37 @@
|
||||
"""
    in(node::Node, graph::DAG)

Return whether the node is contained in the graph's set of nodes.
"""
function in(node::Node, graph::DAG)
    return in(node, graph.nodes)
end
|
||||
|
||||
"""
    in(edge::Edge, graph::DAG)

Return whether the edge is part of the graph, i.e., both of its nodes are in the graph and the first node is
among the children of the second one.
"""
function in(edge::Edge, graph::DAG)
    child, parent = edge.edge[1], edge.edge[2]

    return (child in graph) && (parent in graph) && (child in parent.children)
end
|
||||
|
||||
"""
    ==(n1::Node, n2::Node, g::DAG)

Return whether two nodes are equal in a graph: they have the same type, are both part of the graph, and have
equal tasks and equal children.
"""
function ==(n1::Node, n2::Node, g::DAG)
    typeof(n1) == typeof(n2) || return false
    ((n1 in g) && (n2 in g)) || return false

    return n1.task == n2.task && children(n1) == children(n2)
end
|
54
src/graph/interface.jl
Normal file
54
src/graph/interface.jl
Normal file
@ -0,0 +1,54 @@
|
||||
"""
    push_operation!(graph::DAG, operation::Operation)

Add a new operation to the graph. The operation is only queued in the graph's `operationsToApply` here.

See also: [`DAG`](@ref), [`pop_operation!`](@ref)
"""
function push_operation!(graph::DAG, operation::Operation)
    queue = graph.operationsToApply
    push!(queue, operation)

    return nothing
end
|
||||
|
||||
"""
    pop_operation!(graph::DAG)

Revert the latest operation on the graph. An operation that is queued but not yet applied is simply dropped,
otherwise the last applied operation is reverted. Throws an error if there is no operation left.

See also: [`DAG`](@ref), [`push_operation!`](@ref)
"""
function pop_operation!(graph::DAG)
    # operations that were never applied don't need to be reverted
    if !isempty(graph.operationsToApply)
        pop!(graph.operationsToApply)
        return nothing
    end

    isempty(graph.appliedOperations) && error("No more operations to pop!")

    revert_operation!(graph, pop!(graph.appliedOperations))
    return nothing
end
|
||||
|
||||
"""
    can_pop(graph::DAG)

Return `true` if the graph has any queued or applied operation left, i.e., [`pop_operation!`](@ref) is
possible, and `false` otherwise.
"""
function can_pop(graph::DAG)
    return !(isempty(graph.operationsToApply) && isempty(graph.appliedOperations))
end
|
||||
|
||||
"""
    reset_graph!(graph::DAG)

Pop operations from the graph until none are left, which resets it to its initial state with no operations
applied.
"""
function reset_graph!(graph::DAG)
    while can_pop(graph)
        pop_operation!(graph)
    end
    return nothing
end
|
321
src/graph/mute.jl
Normal file
321
src/graph/mute.jl
Normal file
@ -0,0 +1,321 @@
|
||||
# for graph mutating functions we need to do a few things
|
||||
# 1: mutate the graph
|
||||
# 2: keep track of what was changed for the diff (if track == true)
|
||||
# 3: invalidate operation caches
|
||||
|
||||
"""
    insert_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)

Insert the node into the graph and return it.

## Keyword Arguments
`track::Bool`: Whether the change is recorded in the [`DAG`](@ref)'s [`Diff`](@ref). Should be set `false` in parsing or graph creation functions for performance.

`invalidate_cache::Bool`: Whether the node is marked as dirty. Should also be turned off for graph creation or parsing.

See also: [`remove_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
function insert_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
    # mutate the graph
    push!(graph.nodes, node)

    # record the change for the diff
    track && push!(graph.diff.addedNodes, node)

    # a new node has no cached operations yet, it only needs to be marked dirty
    invalidate_cache && push!(graph.dirtyNodes, node)

    return node
end
|
||||
|
||||
"""
    insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)

Insert the edge from node1 (child) to node2 (parent) into the graph. The edge must not exist yet.

## Keyword Arguments
`track::Bool`: Whether the change is recorded in the [`DAG`](@ref)'s [`Diff`](@ref). Should be set `false` in parsing or graph creation functions for performance.

`invalidate_cache::Bool`: Whether the operation caches of both nodes are invalidated and the nodes marked as dirty. Should also be turned off for graph creation or parsing.

See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
"""
function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
    @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"

    # mutate the graph, the edge points from child to parent
    push!(node1.parents, node2)
    push!(node2.children, node1)

    # record the change for the diff
    track && push!(graph.diff.addedEdges, make_edge(node1, node2))

    if invalidate_cache
        invalidate_operation_caches!(graph, node1)
        invalidate_operation_caches!(graph, node2)

        push!(graph.dirtyNodes, node1)
        push!(graph.dirtyNodes, node2)
    end

    return nothing
end
|
||||
|
||||
"""
    remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)

Remove the node from the graph. The node has to be part of the graph.

## Keyword Arguments
`track::Bool`: Whether the change is recorded in the [`DAG`](@ref)'s [`Diff`](@ref). Should be set `false` in parsing or graph creation functions for performance.

`invalidate_cache::Bool`: Whether the node's operation caches are invalidated and the node is removed from the dirty nodes. Should also be turned off for graph creation or parsing.

See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
    @assert node in graph.nodes "Trying to remove a node that's not in the graph"

    # mutate the graph
    delete!(graph.nodes, node)

    # record the change for the diff
    track && push!(graph.diff.removedNodes, node)

    if invalidate_cache
        invalidate_operation_caches!(graph, node)
        # a removed node does not need to be reevaluated anymore
        delete!(graph.dirtyNodes, node)
    end

    return nothing
end
|
||||
|
||||
"""
    remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)

Remove the edge from node1 (child) to node2 (parent) from the graph.

## Keyword Arguments
`track::Bool`: Whether the change is recorded in the [`DAG`](@ref)'s [`Diff`](@ref). Should be set `false` in parsing or graph creation functions for performance.

`invalidate_cache::Bool`: Whether the operation caches of both nodes are invalidated and the nodes still in the graph are marked as dirty. Should also be turned off for graph creation or parsing.

See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`insert_edge!`](@ref)
"""
function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
    parentsBefore = length(node1.parents)
    childrenBefore = length(node2.children)

    # mutate the graph
    #TODO: filter is very slow
    filter!(!=(node2), node1.parents)
    filter!(!=(node1), node2.children)

    # an edge may exist at most once, so at most one entry can have been removed from each vector
    @assert parentsBefore - length(node1.parents) <= 1 "removed more than one node from node1's parents"
    @assert childrenBefore - length(node2.children) <= 1 "removed more than one node from node2's children"

    # record the change for the diff
    track && push!(graph.diff.removedEdges, make_edge(node1, node2))

    if invalidate_cache
        invalidate_operation_caches!(graph, node1)
        invalidate_operation_caches!(graph, node2)

        # only nodes that are still part of the graph need to be reevaluated
        (node1 in graph) && push!(graph.dirtyNodes, node1)
        (node2 in graph) && push!(graph.dirtyNodes, node2)
    end

    return nothing
end
|
||||
|
||||
"""
    replace_children!(task::FusedComputeTask, before, after)

Replace every occurrence of `before` with `after` in the inputs of both parts of the fused task. Descend
recursively into those of the two fused tasks whose inputs contained `before`.

At least one of the two input vectors has to contain `before`. Returns `nothing`.
"""
function replace_children!(task::FusedComputeTask, before, after)
    # count the occurrences directly instead of allocating index vectors with findall
    replacedIn1 = count(==(before), task.t1_inputs)
    replacedIn2 = count(==(before), task.t2_inputs)

    @assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"

    replace!(task.t1_inputs, before => after)
    replace!(task.t2_inputs, before => after)

    # recursively descend down the tree, but only in the tasks where we're replacing things
    if replacedIn1 > 0
        replace_children!(task.first_task, before, after)
    end
    if replacedIn2 > 0
        replace_children!(task.second_task, before, after)
    end

    return nothing
end
|
||||
|
||||
# Fallback for all tasks that are not fused: they have no inputs to replace, so this does nothing.
replace_children!(task::AbstractTask, before, after) = nothing
|
||||
|
||||
"""
    update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)

Replace the input `child_before` with `child_after` in the task of node `n`. Only nodes with a
`FusedComputeTask` are affected, for any other task this does nothing.

If `child_before` is not an input of the task, or `child_after` is not an input after the replacement,
diagnostics are printed and an assertion fails.

## Keyword Arguments
`track::Bool`: Whether the node and a copy of its task before the update are recorded in the [`DAG`](@ref)'s [`Diff`](@ref).

Always returns `nothing`.
"""
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
    # only need to update fused compute tasks
    if !(n.task isa FusedComputeTask)
        return nothing
    end

    # keep a copy so the update can be recorded in the diff
    taskBefore = copy(n.task)

    # print the given headline and the involved nodes, then fail
    function fail(headline::String)
        println(headline)
        child_ids = String["$(child.id)" for child in n.children]
        println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
        @assert false
    end

    if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
        fail("------------------ Nothing to replace!! ------------------")
    end

    replace_children!(n.task, child_before, child_after)

    if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
        fail("------------------ Did not replace anything!! ------------------")
    end

    # keep track
    if (track)
        push!(graph.diff.updatedChildren, (n, taskBefore))
    end

    return nothing
end
|
||||
|
||||
"""
    get_snapshot_diff(graph::DAG)

Return the [`Diff`](@ref) currently stored in the graph and replace it with a new, empty one. The returned diff
therefore contains the changes since the last call of this function.

See also: [`revert_diff!`](@ref), [`AppliedOperation`](@ref) and [`revert_operation!`](@ref)
"""
function get_snapshot_diff(graph::DAG)
    emptyDiff = Diff()
    return swapfield!(graph, :diff, emptyDiff)
end
|
||||
|
||||
"""
    invalidate_caches!(graph::DAG, operation::NodeFusion)

Invalidate the operation caches for a given [`NodeFusion`](@ref).

The operation is deleted from the graph's possible operations and from the operation caches of its three
input nodes.
"""
function invalidate_caches!(graph::DAG, operation::NodeFusion)
    delete!(graph.possibleOperations, operation)

    firstNode, middleNode, lastNode = operation.input

    # the first and third input keep a collection of fusions, the second only a single one
    # TODO: filter is very slow
    filter!(!=(operation), firstNode.nodeFusions)
    filter!(!=(operation), lastNode.nodeFusions)

    middleNode.nodeFusion = missing

    return nothing
end
|
||||
|
||||
"""
    invalidate_caches!(graph::DAG, operation::NodeReduction)

Invalidate the operation caches for a given [`NodeReduction`](@ref).

The operation is deleted from the graph's possible operations and the cached node reduction of every input
node is reset to `missing`.
"""
function invalidate_caches!(graph::DAG, operation::NodeReduction)
    delete!(graph.possibleOperations, operation)

    for inputNode in operation.input
        inputNode.nodeReduction = missing
    end

    return nothing
end
|
||||
|
||||
"""
    invalidate_caches!(graph::DAG, operation::NodeSplit)

Invalidate the operation caches for a given [`NodeSplit`](@ref).

The operation is deleted from the graph's possible operations and the cached node split of its input node is
reset to `missing`.
"""
function invalidate_caches!(graph::DAG, operation::NodeSplit)
    delete!(graph.possibleOperations, operation)

    # a node split involves a single node only
    splitNode = operation.input
    splitNode.nodeSplit = missing

    return nothing
end
|
||||
|
||||
"""
    invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)

Invalidate all operations cached in the given node (its node reduction, its node split and all of its node
fusions) through calls to the respective [`invalidate_caches!`](@ref) functions.
"""
function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
    ismissing(node.nodeReduction) || invalidate_caches!(graph, node.nodeReduction)
    ismissing(node.nodeSplit) || invalidate_caches!(graph, node.nodeSplit)

    # a compute node can be part of several node fusions
    while !isempty(node.nodeFusions)
        fusion = pop!(node.nodeFusions)
        invalidate_caches!(graph, fusion)
    end

    return nothing
end
|
||||
|
||||
"""
    invalidate_operation_caches!(graph::DAG, node::DataTaskNode)

Invalidate all operations cached in the given node (its node reduction, its node split and its node fusion)
through calls to the respective [`invalidate_caches!`](@ref) functions.
"""
function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
    ismissing(node.nodeReduction) || invalidate_caches!(graph, node.nodeReduction)
    ismissing(node.nodeSplit) || invalidate_caches!(graph, node.nodeSplit)
    # a data node caches at most one node fusion
    ismissing(node.nodeFusion) || invalidate_caches!(graph, node.nodeFusion)

    return nothing
end
|
66
src/graph/print.jl
Normal file
66
src/graph/print.jl
Normal file
@ -0,0 +1,66 @@
|
||||
"""
    show_nodes(io::IO, graph::DAG)

Print all nodes of the graph to `io` as a comma-separated list in square brackets. Should only be used for
small graphs, since every single node is printed.
"""
function show_nodes(io::IO, graph::DAG)
    print(io, "[")
    join(io, graph.nodes, ", ")
    return print(io, "]")
end
|
||||
"""
    show(io::IO, graph::DAG)

Print the given graph to `io`. If there are too many nodes it will print only a summary of them.

Graphs with at most 20 nodes are listed node by node using [`show_nodes`](@ref). Larger graphs are
summarized as the total node count followed by the number of nodes per task type. Afterwards the
number of edges and the values returned by [`get_properties`](@ref) are printed.

Note that `apply_all!` is called first, so printing a graph applies its pending operations.
"""
function show(io::IO, graph::DAG)
    apply_all!(graph)
    println(io, "Graph:")
    print(io, " Nodes: ")

    # count nodes per task type and the total number of edges in a single pass
    nodeDict = Dict{Type, Int64}()
    noEdges = 0
    for node in graph.nodes
        taskType = typeof(node.task)
        nodeDict[taskType] = get(nodeDict, taskType, 0) + 1
        noEdges += length(parents(node))
    end

    if length(graph.nodes) <= 20
        show_nodes(io, graph)
    else
        # all output has to go to `io`; printing to stdout here would break e.g. sprint/repr
        print(io, "Total: ", length(graph.nodes), ", ")
        for (i, (type, number)) in enumerate(nodeDict)
            if i > 1
                print(io, ", ")
            end
            # line break before every third entry
            if (i % 3 == 0)
                print(io, "\n ")
            end
            print(io, type, ": ", number)
        end
    end
    println(io)
    println(io, " Edges: ", noEdges)
    properties = get_properties(graph)
    println(io, " Total Compute Effort: ", properties.computeEffort)
    println(io, " Total Data Transfer: ", properties.data)
    return println(io, " Total Compute Intensity: ", properties.computeIntensity)
end
|
45
src/graph/properties.jl
Normal file
45
src/graph/properties.jl
Normal file
@ -0,0 +1,45 @@
"""
    get_properties(graph::DAG)

Return the graph's [`GraphProperties`](@ref).

All pending operations are applied first. The cached properties are rebuilt from the graph whenever
their `computeEffort` is exactly `0.0`; otherwise the cached value is returned as is.
"""
function get_properties(graph::DAG)
    # make sure the graph is fully generated
    apply_all!(graph)

    needsRebuild = graph.properties.computeEffort == 0.0
    if needsRebuild
        graph.properties = GraphProperties(graph)
    end

    return graph.properties
end
|
||||
"""
    get_exit_node(graph::DAG)

Return the graph's exit node. This assumes the graph only has a single exit node. If the graph has
multiple exit nodes, the one encountered first will be returned.

Throws an `AssertionError` if no node of the graph is an exit node, i.e., the graph is either empty
or not acyclic.
"""
function get_exit_node(graph::DAG)
    for node in graph.nodes
        if is_exit_node(node)
            return node
        end
    end
    # thrown explicitly instead of `@assert false`: an assert may be disabled, which would make this
    # function silently return `nothing`
    throw(AssertionError("The given graph has no exit node! It is either empty or not acyclic!"))
end
|
||||
"""
    get_entry_nodes(graph::DAG)

Return a `Vector{Node}` containing all entry nodes of the graph. Pending operations are applied
before the nodes are collected.
"""
function get_entry_nodes(graph::DAG)
    apply_all!(graph)
    return Node[node for node in graph.nodes if is_entry_node(node)]
end
|
73
src/graph/type.jl
Normal file
73
src/graph/type.jl
Normal file
@ -0,0 +1,73 @@
|
||||
using DataStructures
|
||||
"""
    PossibleOperations

A struct storing all possible operations on a [`DAG`](@ref), with one set per kind of operation.
To get the [`PossibleOperations`](@ref) on a [`DAG`](@ref), use [`get_operations`](@ref).
"""
mutable struct PossibleOperations
    "The possible [`NodeFusion`](@ref)s."
    nodeFusions::Set{NodeFusion}
    "The possible [`NodeReduction`](@ref)s."
    nodeReductions::Set{NodeReduction}
    "The possible [`NodeSplit`](@ref)s."
    nodeSplits::Set{NodeSplit}
end
|
||||
"""
    DAG

The representation of the graph as a set of [`Node`](@ref)s.

A DAG can be loaded using the appropriate parse function, e.g. [`parse_dag`](@ref).

[`Operation`](@ref)s can be applied on it using [`push_operation!`](@ref) and reverted using [`pop_operation!`](@ref) like a stack.
To get the set of possible operations, use [`get_operations`](@ref).
The members of the object should not be manually accessed, instead always use the provided interface functions.
"""
mutable struct DAG
    "The nodes of the graph."
    nodes::Set{Node}

    "The operations currently applied to the set of nodes."
    appliedOperations::Stack{AppliedOperation}

    "The operations not currently applied but part of the current state of the DAG."
    operationsToApply::Deque{Operation}

    "The possible operations at the current state of the DAG."
    possibleOperations::PossibleOperations

    "The set of nodes whose possible operations need to be reevaluated."
    dirtyNodes::Set{Node}

    "Snapshot system: the nodes/edges added or removed since the last snapshot. Mutated in `insert_node!` etc."
    diff::Diff

    "The cached properties of the DAG."
    properties::GraphProperties
end
|
||||
"""
    PossibleOperations()

Construct and return a [`PossibleOperations`](@ref) object whose three operation sets are all empty.
"""
function PossibleOperations()
    fusions = Set{NodeFusion}()
    reductions = Set{NodeReduction}()
    splits = Set{NodeSplit}()
    return PossibleOperations(fusions, reductions, splits)
end
|
||||
"""
    DAG()

Construct and return an empty [`DAG`](@ref): no nodes, no applied or pending operations, no possible
operations, no dirty nodes, an empty diff and default-constructed properties.
"""
function DAG()
    nodes = Set{Node}()
    appliedOperations = Stack{AppliedOperation}()
    operationsToApply = Deque{Operation}()
    possibleOperations = PossibleOperations()
    dirtyNodes = Set{Node}()
    return DAG(nodes, appliedOperations, operationsToApply, possibleOperations, dirtyNodes, Diff(), GraphProperties())
end
|
77
src/graph/validate.jl
Normal file
77
src/graph/validate.jl
Normal file
@ -0,0 +1,77 @@
"""
    is_connected(graph::DAG)

Return whether the given graph is connected.

The check starts at the node returned by [`get_exit_node`](@ref) and follows the `children` of every
reached node. The graph counts as connected if the number of nodes reached this way equals the
number of nodes in the graph.
"""
function is_connected(graph::DAG)
    nodeQueue = Deque{Node}()
    exitNode = get_exit_node(graph)
    push!(nodeQueue, exitNode)
    seenNodes = Set{Node}()
    push!(seenNodes, exitNode)

    while !isempty(nodeQueue)
        current = pop!(nodeQueue)

        for child in current.children
            # mark nodes on discovery, so a node shared by several parents is expanded only once
            # instead of once per path leading to it
            if !(child in seenNodes)
                push!(seenNodes, child)
                push!(nodeQueue, child)
            end
        end
    end

    return length(seenNodes) == length(graph.nodes)
end
|
||||
"""
    is_valid(graph::DAG)

Validate the entire graph using asserts. Intended for testing with `@assert is_valid(graph)`.

Checked in this order: every node, every operation waiting to be applied, every possible operation
(reductions, splits, fusions), every dirty node (must be part of the graph and must not have any
cached operation) and finally the connectedness of the graph. Return `true` if no assert failed.
"""
function is_valid(graph::DAG)
    for node in graph.nodes
        @assert is_valid(graph, node)
    end

    for op in graph.operationsToApply
        @assert is_valid(graph, op)
    end

    possible = graph.possibleOperations
    for nr in possible.nodeReductions
        @assert is_valid(graph, nr)
    end
    for ns in possible.nodeSplits
        @assert is_valid(graph, ns)
    end
    for nf in possible.nodeFusions
        @assert is_valid(graph, nf)
    end

    # dirty nodes still have to be reevaluated, so none of their operation caches may be filled
    for node in graph.dirtyNodes
        @assert node in graph "Dirty Node is not part of the graph!"
        @assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!"
        @assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!"
        if node isa DataTaskNode
            @assert ismissing(node.nodeFusion) "Dirty DataTaskNode has a Node Fusion!"
        elseif node isa ComputeTaskNode
            @assert isempty(node.nodeFusions) "Dirty ComputeTaskNode has Node Fusions!"
        end
    end

    @assert is_connected(graph) "Graph is not connected!"

    return true
end
|
||||
"""
    is_scheduled(graph::DAG)

Validate that the entire graph has been scheduled, i.e., every [`ComputeTaskNode`](@ref) has its `.device` set.

The check uses an assert for every node that is not a [`DataTaskNode`](@ref) and returns `true` if none failed.
"""
function is_scheduled(graph::DAG)
    for node in graph.nodes
        # data nodes are not checked for a device
        if !(node isa DataTaskNode)
            @assert !ismissing(node.device)
        end
    end

    return true
end
|
@ -1,354 +0,0 @@
|
||||
using DataStructures
|
||||
|
||||
# membership test of a node in the graph's node set
function in(node::Node, graph::DAG)
    return node in graph.nodes
end

# membership test of an edge in the graph's edge collection
function in(edge::Edge, graph::DAG)
    return edge in graph.edges
end
|
||||
|
||||
# return whether potential_parent is contained in node's parents
function is_parent(potential_parent, node)
    nodeParents = node.parents
    return in(potential_parent, nodeParents)
end
|
||||
|
||||
# return whether potential_child is contained in node's children
function is_child(potential_child, node)
    nodeChildren = node.children
    return in(potential_child, nodeChildren)
end
|
||||
|
||||
# two nodes are equal with respect to a graph if they have the same type, are both part of the graph,
# and have equal tasks and equal children
function ==(n1::Node, n2::Node, g::DAG)
    typeof(n1) == typeof(n2) || return false
    (n1 in g && n2 in g) || return false

    return n1.task == n2.task && children(n1) == children(n2)
end
|
||||
|
||||
# children = prerequisite nodes, nodes that need to execute before the task, edges point into this task
# returns a copy, so the caller can not modify the node's own collection through the result
function children(node::Node)
    childrenCopy = copy(node.children)
    return childrenCopy
end
|
||||
|
||||
# parents = subsequent nodes, nodes that need this node to execute, edges point from this task
# returns a copy, so the caller can not modify the node's own collection through the result
function parents(node::Node)
    parentsCopy = copy(node.parents)
    return parentsCopy
end
|
||||
|
||||
# siblings = all children of any parents, no duplicates, includes the node itself
function siblings(node::Node)
    result = Set{Node}()
    push!(result, node)
    foreach(parent -> union!(result, parent.children), node.parents)

    return result
end
|
||||
|
||||
# partners = all parents of any children, no duplicates, includes the node itself
function partners(node::Node)
    result = Set{Node}()
    push!(result, node)
    foreach(child -> union!(result, child.parents), node.children)

    return result
end
|
||||
|
||||
# alternative version to partners(Node), avoiding allocation of a new set
# adds the node and all parents of its children to the given set and returns nothing
function partners(node::Node, set::Set{Node})
    push!(set, node)
    foreach(child -> union!(set, child.parents), node.children)
    return nothing
end
|
||||
|
||||
# an entry node has no children
function is_entry_node(node::Node)
    return isempty(node.children)
end

# an exit node has no parents
function is_exit_node(node::Node)
    return isempty(node.parents)
end
|
||||
|
||||
# function to invalidate the operation caches for a given NodeFusion
# removes the fusion from the graph's possible operations and from the caches of its three input nodes
function invalidate_caches!(graph::DAG, operation::NodeFusion)
    delete!(graph.possibleOperations, operation)

    # inputs 1 and 3 keep a list of fusions, input 2 caches a single fusion
    for idx in (1, 3)
        filter!(!=(operation), operation.input[idx].nodeFusions)
    end
    operation.input[2].nodeFusion = missing

    return nothing
end
|
||||
|
||||
# function to invalidate the operation caches for a given NodeReduction
# removes the reduction from the graph's possible operations and clears it on every input node
function invalidate_caches!(graph::DAG, operation::NodeReduction)
    delete!(graph.possibleOperations, operation)

    for reducedNode in operation.input
        reducedNode.nodeReduction = missing
    end

    return nothing
end
|
||||
|
||||
# function to invalidate the operation caches for a given NodeSplit
# removes the split from the graph's possible operations and clears it on its single input node
function invalidate_caches!(graph::DAG, operation::NodeSplit)
    delete!(graph.possibleOperations, operation)

    splitNode = operation.input
    splitNode.nodeSplit = missing

    return nothing
end
|
||||
|
||||
# function to invalidate the operation caches of a ComputeTaskNode
# invalidates the cached reduction, the cached split and all cached fusions of the node
function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
    ismissing(node.nodeReduction) || invalidate_caches!(graph, node.nodeReduction)
    ismissing(node.nodeSplit) || invalidate_caches!(graph, node.nodeSplit)

    while !isempty(node.nodeFusions)
        fusion = pop!(node.nodeFusions)
        invalidate_caches!(graph, fusion)
    end

    return nothing
end
|
||||
|
||||
# function to invalidate the operation caches of a DataTaskNode
# invalidates the cached reduction, the cached split and the single cached fusion of the node
function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
    ismissing(node.nodeReduction) || invalidate_caches!(graph, node.nodeReduction)
    ismissing(node.nodeSplit) || invalidate_caches!(graph, node.nodeSplit)
    ismissing(node.nodeFusion) || invalidate_caches!(graph, node.nodeFusion)

    return nothing
end
|
||||
|
||||
# for graph mutating functions we need to do a few things
|
||||
# 1: mutate the graph (duh)
|
||||
# 2: keep track of what was changed for the diff (if track == true)
|
||||
# 3: invalidate operation caches
|
||||
|
||||
# insert the node into the graph and return it
# track: record the insertion in the graph's diff
# invalidate_cache: mark the new node as dirty
function insert_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
    # 1: mutate
    push!(graph.nodes, node)

    # 2: keep track
    if track
        push!(graph.diff.addedNodes, node)
    end

    # 3: invalidate caches
    if invalidate_cache
        push!(graph.dirtyNodes, node)
    end

    return node
end
|
||||
|
||||
# insert an edge pointing from node1 (child) to node2 (parent)
# track: record the insertion in the graph's diff
# invalidate_cache: invalidate the operation caches of both nodes and mark them as dirty
function insert_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
    # 1: mutate, the edge points from child to parent
    push!(node1.parents, node2)
    push!(node2.children, node1)

    # 2: keep track
    if track
        push!(graph.diff.addedEdges, make_edge(node1, node2))
    end

    # 3: invalidate caches
    invalidate_cache || return nothing

    endpoints = (node1, node2)
    for endpoint in endpoints
        invalidate_operation_caches!(graph, endpoint)
    end
    for endpoint in endpoints
        push!(graph.dirtyNodes, endpoint)
    end

    return nothing
end
|
||||
|
||||
# remove the node from the graph
# track: record the removal in the graph's diff
# invalidate_cache: invalidate the node's operation caches and remove it from the dirty nodes
function remove_node!(graph::DAG, node::Node, track=true, invalidate_cache=true)
    # 1: mutate
    delete!(graph.nodes, node)

    # 2: keep track
    track && push!(graph.diff.removedNodes, node)

    # 3: invalidate caches
    invalidate_cache || return nothing

    invalidate_operation_caches!(graph, node)
    delete!(graph.dirtyNodes, node)

    return nothing
end
|
||||
|
||||
# remove the edge pointing from node1 (child) to node2 (parent)
# track: record the removal in the graph's diff
# invalidate_cache: invalidate the operation caches of both nodes and mark those still in the graph as dirty
function remove_edge!(graph::DAG, node1::Node, node2::Node, track=true, invalidate_cache=true)
    # 1: mutate
    # note: this removes every occurrence of the other node, it is not checked that exactly one was removed
    filter!(!=(node2), node1.parents)
    filter!(!=(node1), node2.children)

    # 2: keep track
    if track
        push!(graph.diff.removedEdges, make_edge(node1, node2))
    end

    # 3: invalidate caches
    if !invalidate_cache
        return nothing
    end

    invalidate_operation_caches!(graph, node1)
    invalidate_operation_caches!(graph, node2)

    # a node that was already removed from the graph must not become dirty again
    if node1 in graph
        push!(graph.dirtyNodes, node1)
    end
    if node2 in graph
        push!(graph.dirtyNodes, node2)
    end

    return nothing
end
|
||||
|
||||
# return the graph "difference" since last time this function was called
# the graph's diff is swapped for a fresh, empty one and the previous diff is returned
function get_snapshot_diff(graph::DAG)
    freshDiff = Diff()
    return swapfield!(graph, :diff, freshDiff)
end
|
||||
|
||||
# return a named tuple (data, compute_effort, compute_intensity, nodes, edges) describing the graph
# the data of a node's task is counted once per parent of the node, compute_intensity is effort / data
function graph_properties(graph::DAG)
    # make sure the graph is fully generated
    apply_all!(graph)

    totalData = 0
    totalEffort = 0
    edgeCount = 0
    for node in graph.nodes
        totalData += data(node.task) * length(node.parents)
        totalEffort += compute_effort(node.task)
        edgeCount += length(node.parents)
    end

    return (
        data = totalData,
        compute_effort = totalEffort,
        compute_intensity = totalEffort / totalData,
        nodes = length(graph.nodes),
        edges = edgeCount,
    )
end
|
||||
|
||||
# return the first node of the graph for which is_exit_node is true
# throws an AssertionError if there is no such node
function get_exit_node(graph::DAG)
    for node in graph.nodes
        if is_exit_node(node)
            return node
        end
    end
    # thrown explicitly instead of `@assert false`: an assert may be disabled, which would make this
    # function silently return `nothing`
    throw(AssertionError("The given graph has no exit node! It is either empty or not acyclic!"))
end
|
||||
|
||||
# check whether the given graph is connected
|
||||
# starting at the exit node, follow the children of every reached node and return whether
# all nodes of the graph were reached this way
function is_valid(graph::DAG)
    nodeQueue = Deque{Node}()
    exitNode = get_exit_node(graph)
    push!(nodeQueue, exitNode)
    seenNodes = Set{Node}()
    push!(seenNodes, exitNode)

    while !isempty(nodeQueue)
        current = pop!(nodeQueue)

        # the field is `children` (it was misspelled before, which made this function throw)
        for child in current.children
            # mark nodes on discovery, so a node shared by several parents is expanded only once
            if !(child in seenNodes)
                push!(seenNodes, child)
                push!(nodeQueue, child)
            end
        end
    end

    return length(seenNodes) == length(graph.nodes)
end
|
||||
|
||||
# print all nodes of the graph to io as a comma separated list in square brackets
function show_nodes(io, graph::DAG)
    print(io, "[")
    # join prints each node with `print` and puts the delimiter between consecutive nodes
    join(io, graph.nodes, ", ")
    return print(io, "]")
end
|
||||
|
||||
# print the graph to io: the nodes (a full list for at most 20 nodes, otherwise the total count and
# the count per task type), the number of edges and the values returned by graph_properties
function show(io::IO, graph::DAG)
    println(io, "Graph:")
    print(io, " Nodes: ")

    # count nodes per task type and the total number of edges in a single pass
    nodeDict = Dict{Type, Int64}()
    noEdges = 0
    for node in graph.nodes
        taskType = typeof(node.task)
        nodeDict[taskType] = get(nodeDict, taskType, 0) + 1
        noEdges += length(parents(node))
    end

    if length(graph.nodes) <= 20
        show_nodes(io, graph)
    else
        # all output has to go to `io`; printing to stdout here would break e.g. sprint/repr
        print(io, "Total: ", length(graph.nodes), ", ")
        for (i, (type, number)) in enumerate(nodeDict)
            if i > 1
                print(io, ", ")
            end
            # line break before every third entry
            if (i % 3 == 0)
                print(io, "\n ")
            end
            print(io, type, ": ", number)
        end
    end
    println(io)
    println(io, " Edges: ", noEdges)
    properties = graph_properties(graph)
    println(io, " Total Compute Effort: ", properties.compute_effort)
    println(io, " Total Data Transfer: ", properties.data)
    return println(io, " Total Compute Intensity: ", properties.compute_intensity)
end
|
||||
|
||||
# print the number of changed (added plus removed) nodes and edges of the diff
function show(io::IO, diff::Diff)
    changedNodes = length(diff.addedNodes) + length(diff.removedNodes)
    changedEdges = length(diff.addedEdges) + length(diff.removedEdges)
    return print(io, "Nodes: ", changedNodes, " Edges: ", changedEdges)
end
|
||||
|
||||
# return a namedtuple of the lengths of the added/removed nodes/edges
function length(diff::Diff)
    addedNodes = length(diff.addedNodes)
    removedNodes = length(diff.removedNodes)
    addedEdges = length(diff.addedEdges)
    removedEdges = length(diff.removedEdges)
    return (; addedNodes, removedNodes, addedEdges, removedEdges)
end
|
@ -1,34 +0,0 @@
|
||||
# user interface on the DAG
|
||||
|
||||
# queue a new operation at the end of the graph's operations that are waiting to be applied
function push_operation!(graph::DAG, operation::Operation)
    pending = graph.operationsToApply
    push!(pending, operation)

    return nothing
end
|
||||
|
||||
# reverts the latest operation, essentially like a ctrl+z for the graph
# an operation that is still waiting to be applied is simply dropped, otherwise the most recently
# applied operation is reverted; errors if there is neither
function pop_operation!(graph::DAG)
    if !isempty(graph.operationsToApply)
        pop!(graph.operationsToApply)
        return nothing
    end

    if isempty(graph.appliedOperations)
        error("No more operations to pop!")
    end

    appliedOp = pop!(graph.appliedOperations)
    revert_operation!(graph, appliedOp)
    return nothing
end
|
||||
|
||||
# return whether there is any operation left, waiting or applied, that pop_operation! could pop
function can_pop(graph::DAG)
    return !(isempty(graph.operationsToApply) && isempty(graph.appliedOperations))
end
|
||||
|
||||
# reset the graph to its initial state with no operations applied
# pops operations one by one until can_pop reports that none are left
function reset_graph!(graph::DAG)
    while can_pop(graph)
        pop_operation!(graph)
    end

    return nothing
end
|
154
src/models/abc/compute.jl
Normal file
154
src/models/abc/compute.jl
Normal file
@ -0,0 +1,154 @@
|
||||
using AccurateArithmetic
|
||||
"""
    compute(::ComputeTaskP, data::ParticleValue)

Return the given particle value unchanged.

0 FLOP.
"""
function compute(::ComputeTaskP, data::ParticleValue)
    return data
end
|
||||
"""
    compute(::ComputeTaskU, data::ParticleValue)

Compute an outer edge. Return a [`ParticleValue`](@ref) with the same particle and the value
multiplied by the `outer_edge` factor of that particle.

1 FLOP.
"""
function compute(::ComputeTaskU, data::ParticleValue)
    scaledValue = data.v * outer_edge(data.p)
    return ParticleValue(data.p, scaledValue)
end
|
||||
"""
    compute(::ComputeTaskV, data1::ParticleValue, data2::ParticleValue)

Compute a vertex. The resulting particle is created from both input particles by `preserve_momentum`
(preserving momentum and particle types, AB->C etc.), the resulting value is the product of both
input values and the `vertex` factor.

6 FLOP.
"""
function compute(::ComputeTaskV, data1::ParticleValue, data2::ParticleValue)
    p3 = preserve_momentum(data1.p, data2.p)
    return ParticleValue(p3, data1.v * vertex() * data2.v)
end
|
||||
"""
    compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)

Compute a final inner edge (2 input particles, no output particle). Return the product of both
input values and the `inner_edge` factor of the first particle.

For valid inputs, both input particles should have the same momenta at this point. This is not checked.

12 FLOP.
"""
function compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)
    # only the first particle is used for the inner edge factor; there is no runtime check that the
    # momenta of both particles match
    return data1.v * inner_edge(data1.p) * data2.v
end
|
||||
"""
    compute(::ComputeTaskS1, data::ParticleValue)

Compute an inner edge (1 input particle, 1 output particle). Return a [`ParticleValue`](@ref) with
the same particle and the value multiplied by the `inner_edge` factor of that particle.

11 FLOP.
"""
function compute(::ComputeTaskS1, data::ParticleValue)
    scaledValue = data.v * inner_edge(data.p)
    return ParticleValue(data.p, scaledValue)
end
|
||||
"""
    compute(::ComputeTaskSum, data::Vector{Float64})

Compute the sum over the vector using `sum_kbn`, an algorithm that accounts for accumulated errors
in long sums with potentially large differences in magnitude of the summands.

Linearly many FLOP with growing data.
"""
function compute(::ComputeTaskSum, data::Vector{Float64})
    return sum_kbn(data)
end
|
||||
"""
    get_expression(::ComputeTaskP, device::AbstractDevice, inExprs::Vector, outExpr)

Generate and return code evaluating [`ComputeTaskP`](@ref).

`inExprs[1]` and `outExpr` are evaluated; their results are used as the input and the output of the
generated assignment `<out> = compute(ComputeTaskP(), <in>)`. `device` is not used.
"""
function get_expression(::ComputeTaskP, device::AbstractDevice, inExprs::Vector, outExpr)
    inSym = eval(inExprs[1])
    outSym = eval(outExpr)

    return Meta.parse("$outSym = compute(ComputeTaskP(), $inSym)")
end
|
||||
"""
    get_expression(::ComputeTaskU, device::AbstractDevice, inExprs::Vector, outExpr)

Generate and return code evaluating [`ComputeTaskU`](@ref).

`inExprs[1]` and `outExpr` are evaluated; their results are used as the input and the output of the
generated assignment `<out> = compute(ComputeTaskU(), <in>)`. At run time of the generated code, the
input should be a [`ParticleValue`](@ref) and the output will be a [`ParticleValue`](@ref).
`device` is not used.
"""
function get_expression(::ComputeTaskU, device::AbstractDevice, inExprs::Vector, outExpr)
    inSym = eval(inExprs[1])
    outSym = eval(outExpr)

    return Meta.parse("$outSym = compute(ComputeTaskU(), $inSym)")
end
|
||||
"""
    get_expression(::ComputeTaskV, device::AbstractDevice, inExprs::Vector, outExpr)

Generate and return code evaluating [`ComputeTaskV`](@ref).

`inExprs[1]`, `inExprs[2]` and `outExpr` are evaluated; their results are used as the two inputs and
the output of the generated assignment `<out> = compute(ComputeTaskV(), <in1>, <in2>)`. At run time
of the generated code, both inputs should be [`ParticleValue`](@ref)s and the output will be a
[`ParticleValue`](@ref). `device` is not used.
"""
function get_expression(::ComputeTaskV, device::AbstractDevice, inExprs::Vector, outExpr)
    inSyms = [eval(inExprs[1]), eval(inExprs[2])]
    outSym = eval(outExpr)

    return Meta.parse("$outSym = compute(ComputeTaskV(), $(inSyms[1]), $(inSyms[2]))")
end
|
||||
"""
    get_expression(::ComputeTaskS2, device::AbstractDevice, inExprs::Vector, outExpr)

Generate and return code evaluating [`ComputeTaskS2`](@ref).

`inExprs[1]`, `inExprs[2]` and `outExpr` are evaluated; their results are used as the two inputs and
the output of the generated assignment `<out> = compute(ComputeTaskS2(), <in1>, <in2>)`. At run time
of the generated code, both inputs should be [`ParticleValue`](@ref)s and the output will be a
`Float64`. `device` is not used.
"""
function get_expression(::ComputeTaskS2, device::AbstractDevice, inExprs::Vector, outExpr)
    inSyms = [eval(inExprs[1]), eval(inExprs[2])]
    outSym = eval(outExpr)

    return Meta.parse("$outSym = compute(ComputeTaskS2(), $(inSyms[1]), $(inSyms[2]))")
end
|
||||
"""
    get_expression(::ComputeTaskS1, device::AbstractDevice, inExprs::Vector, outExpr)

Generate and return code evaluating [`ComputeTaskS1`](@ref).

`inExprs[1]` and `outExpr` are evaluated; their results are used as the input and the output of the
generated assignment `<out> = compute(ComputeTaskS1(), <in>)`. At run time of the generated code, the
input should be a [`ParticleValue`](@ref) and the output will be a [`ParticleValue`](@ref).
`device` is not used.
"""
function get_expression(::ComputeTaskS1, device::AbstractDevice, inExprs::Vector, outExpr)
    inSym = eval(inExprs[1])
    outSym = eval(outExpr)

    return Meta.parse("$outSym = compute(ComputeTaskS1(), $inSym)")
end
|
||||
"""
    get_expression(::ComputeTaskSum, device::AbstractDevice, inExprs::Vector, outExpr)

Generate and return code evaluating [`ComputeTaskSum`](@ref).

All elements of `inExprs` and `outExpr` are evaluated; the input results are unrolled into a vector
literal by `unroll_symbol_vector` and used in the generated assignment
`<out> = compute(ComputeTaskSum(), [<inputs>])`. At run time of the generated code, the inputs
should be `Float64`s and the output will be a `Float64`. `device` is not used.
"""
function get_expression(::ComputeTaskSum, device::AbstractDevice, inExprs::Vector, outExpr)
    inSyms = eval.(inExprs)
    outSym = eval(outExpr)

    return Meta.parse("$outSym = compute(ComputeTaskSum(), [$(unroll_symbol_vector(inSyms))])")
end
|
198
src/models/abc/create.jl
Normal file
198
src/models/abc/create.jl
Normal file
@ -0,0 +1,198 @@
|
||||
using QEDbase
|
||||
using Random
|
||||
using Roots
|
||||
using ForwardDiff
|
||||
|
||||
# convenience constructor: a ComputeTaskSum constructed with the argument 0
function ComputeTaskSum()
    return ComputeTaskSum(0)
end
|
||||
"""
    gen_process_input(processDescription::ABCProcessDescription)

Return a ProcessInput of randomly generated [`ABCParticle`](@ref)s from a [`ABCProcessDescription`](@ref). The process description can be created manually or parsed from a string using [`parse_process`](@ref).

The total energy is the sum of all particle masses plus a random extra amount, so that there is some
momentum to distribute. The momenta of the output particles are stored with negated energy.

Note: This uses RAMBO to create a valid process with conservation of momentum and energy.
Note: The initial momenta come from `generate_initial_moms`, which only produces two momenta, so
descriptions with more than two incoming particles are presumably not supported.
"""
function gen_process_input(processDescription::ABCProcessDescription)
    # accumulate as Float64 from the start, the masses are stored as Float64 as well
    massSum = 0.0
    inputMasses = Vector{Float64}()
    for (particle, n) in processDescription.inParticles
        for _ in 1:n
            massSum += mass(particle)
            push!(inputMasses, mass(particle))
        end
    end
    outputMasses = Vector{Float64}()
    for (particle, n) in processDescription.outParticles
        for _ in 1:n
            massSum += mass(particle)
            push!(outputMasses, mass(particle))
        end
    end

    # add some extra random mass to allow for some momentum
    massSum += rand(rng[threadid()]) * (length(inputMasses) + length(outputMasses))

    inputParticles = Vector{ABCParticle}()
    initialMomenta = generate_initial_moms(massSum, inputMasses)
    index = 1
    for (particle, n) in processDescription.inParticles
        for _ in 1:n
            mom = initialMomenta[index]
            push!(inputParticles, particle(mom))
            index += 1
        end
    end

    outputParticles = Vector{ABCParticle}()
    final_momenta = generate_physical_massive_moms(rng[threadid()], massSum, outputMasses)
    index = 1
    for (particle, n) in processDescription.outParticles
        for _ in 1:n
            mom = final_momenta[index]
            # outgoing particles carry the negated energy
            push!(outputParticles, particle(SFourMomentum(-mom.E, mom.px, mom.py, mom.pz)))
            index += 1
        end
    end

    return ABCProcessInput(processDescription, inputParticles, outputParticles)
end
|
||||
|
||||
####################
|
||||
# CODE FROM HERE BORROWED FROM SOURCE: https://codebase.helmholtz.cloud/qedsandbox/QEDphasespaces.jl/
|
||||
# use qedphasespaces directly once released
|
||||
#
|
||||
# quick and dirty implementation of the RAMBO algorithm
|
||||
#
|
||||
# reference:
|
||||
# * https://cds.cern.ch/record/164736/files/198601282.pdf
|
||||
# * https://www.sciencedirect.com/science/article/pii/0010465586901190
|
||||
####################
|
||||
|
||||
# generate the momenta of two incoming particles with the total energy ss
# only masses[1] and masses[2] are used; the two momenta point in opposite directions along the z axis
function generate_initial_moms(ss, masses)
    m1 = masses[1]
    m2 = masses[2]

    E1 = (ss^2 + m1^2 - m2^2) / (2 * ss)
    E2 = (ss^2 + m2^2 - m1^2) / (2 * ss)

    rho1 = sqrt(E1^2 - m1^2)
    rho2 = sqrt(E2^2 - m2^2)

    return [SFourMomentum(E1, 0, 0, rho1), SFourMomentum(E2, 0, 0, -rho2)]
end
|
||||
|
||||
|
||||
# sample an SFourMomentum whose four components are drawn with rand(rng, 4)
function Random.rand(rng::AbstractRNG, ::Random.SamplerType{SFourMomentum})
    return SFourMomentum(rand(rng, 4))
end

# sample a tuple of N Float64 values drawn with rand(rng, N)
function Random.rand(rng::AbstractRNG, ::Random.SamplerType{NTuple{N, Float64}}) where {N}
    return Tuple(rand(rng, N))
end
|
||||
|
||||
|
||||
# transform four numbers (presumably uniformly distributed in (0, 1)) into a massless four momentum:
# u1 and u2 determine the direction, u3 and u4 the energy
function _transform_uni_to_mom(u1, u2, u3, u4)
    cosTheta = 2 * u1 - 1
    sinTheta = sqrt(1 - cosTheta^2)
    phi = 2 * pi * u2
    energy = -log(u3 * u4)

    qx = energy * sinTheta * cos(phi)
    qy = energy * sinTheta * sin(phi)
    qz = energy * cosTheta

    return SFourMomentum(energy, qx, qy, qz)
end
|
||||
|
||||
# transform the four components of uni_mom into a massless four momentum and return it
# NOTE: despite the `!`, `dest` is not modified (it never was, the old assignment only rebound the
# local name); the parameter is kept so existing callers keep working, use the return value instead
function _transform_uni_to_mom!(uni_mom, dest)
    u1, u2, u3, u4 = Tuple(uni_mom)
    # the actual transformation lives in the four argument method
    return _transform_uni_to_mom(u1, u2, u3, u4)
end
|
||||
|
||||
# convenience method: splat a tuple of four numbers into the four argument method
function _transform_uni_to_mom(u1234::Tuple)
    return _transform_uni_to_mom(u1234...)
end

# convenience method: use the four components of a momentum as the four input numbers
function _transform_uni_to_mom(u1234::SFourMomentum)
    return _transform_uni_to_mom(Tuple(u1234))
end
|
||||
|
||||
# generate n massless momenta: fill a vector with random SFourMomentum values and transform each of them
function generate_massless_moms(rng, n::Int)
    uniformMoms = Vector{SFourMomentum}(undef, n)
    rand!(rng, uniformMoms)
    return map(_transform_uni_to_mom, uniformMoms)
end
|
||||
|
||||
# generate n massless momenta and transform each of them using the sum Q of all momenta:
# boost with b = -Q / M (M = sqrt(Q * Q)) and scale with x = ss / M
function generate_physical_massless_moms(rng, ss, n)
    r_moms = generate_massless_moms(rng, n)
    Q = sum(r_moms)
    M = sqrt(Q * Q)
    fac = -1 / M

    bx = fac * getX(Q)
    by = fac * getY(Q)
    bz = fac * getZ(Q)
    gamma = getT(Q) / M
    a = 1 / (1 + gamma)
    x = ss / M

    for i in 1:n
        mom = r_moms[i]
        mom0 = getT(mom)
        mom1 = getX(mom)
        mom2 = getY(mom)
        mom3 = getZ(mom)

        bq = bx * mom1 + by * mom2 + bz * mom3

        p0 = x * (gamma * mom0 + bq)
        px = x * (mom1 + bx * mom0 + a * bq * bx)
        py = x * (mom2 + by * mom0 + a * bq * by)
        pz = x * (mom3 + bz * mom0 + a * bq * bz)

        # the transformed momentum replaces the original one in place
        r_moms[i] = SFourMomentum(p0, px, py, pz)
    end
    return r_moms
end
|
||||
|
||||
# Residual whose root in `xi` is searched by the caller:
#
#     sum_i sqrt(masses[i]^2 + xi^2 * p0s[i]^2) - ss
#
# `masses[i]` is looked up by position for every entry of `p0s`.
function _to_be_solved(xi, masses, p0s, ss)
    # Start the accumulator in the type that `0.0 + xi` would have. This keeps the result identical
    # to the former `0.0`-initialised loop for plain numbers, but avoids the accumulator changing its
    # type in the first iteration when `xi` is a dual number (as passed by ForwardDiff).
    # The name `sum` is no longer shadowed by a local variable.
    acc_init = zero(promote_type(typeof(xi), Float64))
    total = sum(sqrt(masses[i]^2 + xi^2 * E^2) for (i, E) in enumerate(p0s); init = acc_init)
    return total - ss
end
|
||||
|
||||
# Build one new momentum per entry of `massless_moms`: the spatial components are scaled by `xi`,
# and the time component is sqrt(getT(p)^2 * xi^2 + masses[i]^2) with the mass at the same position.
function _build_massive_momenta(xi, masses, massless_moms)
    return SFourMomentum[
        SFourMomentum(
            sqrt(getT(p)^2 * xi^2 + masses[i]^2),
            xi * getX(p),
            xi * getY(p),
            xi * getZ(p),
        ) for (i, p) in enumerate(massless_moms)
    ]
end
|
||||
|
||||
# Return a closure that evaluates the derivative of `func` at `x` using `ForwardDiff.derivative`.
# The argument is converted with `float` before differentiation.
function first_derivative(func)
    return x -> ForwardDiff.derivative(func, float(x))
end
|
||||
|
||||
|
||||
# Generate one momentum per entry of `masses`:
#   1. generate `length(masses)` momenta with `generate_physical_massless_moms`,
#   2. find the root `xi` of `_to_be_solved` with Newton's method, starting at `x0`,
#   3. build the final momenta with `_build_massive_momenta`.
function generate_physical_massive_moms(rng, ss, masses; x0 = 0.1)
    massless_moms = generate_physical_massless_moms(rng, ss, length(masses))
    energies = map(getT, massless_moms)

    residual = x -> _to_be_solved(x, masses, energies, ss)
    xi = find_zero((residual, first_derivative(residual)), x0, Roots.Newton())

    return _build_massive_momenta(xi, masses, massless_moms)
end
|
248
src/models/abc/parse.jl
Normal file
248
src/models/abc/parse.jl
Normal file
@ -0,0 +1,248 @@
|
||||
# functions for importing DAGs from a file
|
||||
# The patterns are declared `const` so that the functions using them do not access untyped,
# reassignable globals.
const regex_a = r"^[A-C]\d+$"                        # Regex for the initial particles
const regex_c = r"^[A-C]\(([^']*),([^']*)\)$"        # Regex for the combinations of 2 particles
const regex_m = r"^M\(([^']*),([^']*),([^']*)\)$"    # Regex for the combinations of 3 particles
const regex_plus = r"^\+$"                           # Regex for the sum

# Sizes in bytes used for the data nodes created by the parser.
const PARTICLE_VALUE_SIZE::Int = 48
const FLOAT_SIZE::Int = 8
|
||||
|
||||
"""
|
||||
parse_nodes(input::AbstractString)
|
||||
|
||||
Parse the given string into a vector of strings containing each node.
|
||||
"""
|
||||
function parse_nodes(input::AbstractString)
    # a node name is everything between a pair of single quotes
    quoted_name = r"'([^']*)'"
    return collect(m.captures[1] for m in eachmatch(quoted_name, input))
end
|
||||
|
||||
"""
|
||||
parse_edges(input::AbstractString)
|
||||
|
||||
Parse the given string into a vector of strings containing each edge. Currently unused since the entire graph can be read from just the node names.
|
||||
"""
|
||||
function parse_edges(input::AbstractString)
    # an edge is written as ('<first>', '<second>'); both quoted names are captured
    edge_pattern = r"\('([^']*)', '([^']*)'\)"
    return collect((m.captures[1], m.captures[2]) for m in eachmatch(edge_pattern, input))
end
|
||||
|
||||
"""
|
||||
parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = false)
|
||||
|
||||
Read an abc-model process from the given file. If `verbose` is set to true, print some progress information to stdout.
|
||||
|
||||
Returns a valid [`DAG`](@ref).
|
||||
"""
|
||||
function parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = false)
    # NOTE: `model` is not used in the body; it only selects this method.
    # `verbose` is a positional argument.

    # Only the first line of the file (the list of node names) is read. The do-block form closes
    # the file even when reading or parsing throws.
    nodes = open(filename, "r") do file
        if verbose
            println("Opened file")
        end
        return parse_nodes(readline(file))
    end
    if verbose
        println("Read file")
    end

    graph = DAG()

    # estimate total number of nodes
    # try to slightly overestimate so no resizing is necessary
    # data nodes are not included in length(nodes) and there are a few more than compute nodes
    estimate_no_nodes = round(Int, length(nodes) * 4)
    if verbose
        println("Estimating ", estimate_no_nodes, " Nodes")
    end
    sizehint!(graph.nodes, estimate_no_nodes)

    # Local shorthands: every insertion in this function is untracked and does not invalidate caches.
    untracked_node!(task_args...) =
        insert_node!(graph, make_node(task_args...), track = false, invalidate_cache = false)
    untracked_edge!(from, to) = insert_edge!(graph, from, to, track = false, invalidate_cache = false)

    sum_node = untracked_node!(ComputeTaskSum(0))
    global_data_out = untracked_node!(DataTask(FLOAT_SIZE))
    untracked_edge!(sum_node, global_data_out)

    # remember the data out nodes for connection
    dataOutNodes = Dict()

    # Connect the output of the node called `in_name` to `compute_v`. If that input is itself a
    # combination of two particles, an S1 task and its data node are put in between.
    function connect_v_input!(in_name, compute_v)
        if occursin(regex_c, in_name)
            # put an S node after this input
            compute_S = untracked_node!(ComputeTaskS1())
            data_S_v = untracked_node!(DataTask(PARTICLE_VALUE_SIZE))

            untracked_edge!(dataOutNodes[in_name], compute_S)
            untracked_edge!(compute_S, data_S_v)

            untracked_edge!(data_S_v, compute_v)
        else
            untracked_edge!(dataOutNodes[in_name], compute_v)
        end
        return nothing
    end

    if verbose
        println("Building graph")
    end
    nodesToRead = length(nodes)
    for (noNodes, node) in enumerate(nodes)
        if verbose && noNodes % 100 == 0
            percent = string(round(100.0 * noNodes / nodesToRead, digits = 2), "%")
            print("\rReading Nodes... $percent")
        end

        if occursin(regex_a, node)
            # add nodes and edges for the state reading to u(P(Particle))
            data_in = untracked_node!(DataTask(PARTICLE_VALUE_SIZE), string(node)) # read particle data node
            compute_P = untracked_node!(ComputeTaskP()) # compute P node
            data_Pu = untracked_node!(DataTask(PARTICLE_VALUE_SIZE)) # transfer data from P to u (one ParticleValue object)
            compute_u = untracked_node!(ComputeTaskU()) # compute U node
            data_out = untracked_node!(DataTask(PARTICLE_VALUE_SIZE)) # transfer data out from u (one ParticleValue object)

            untracked_edge!(data_in, compute_P)
            untracked_edge!(compute_P, data_Pu)
            untracked_edge!(data_Pu, compute_u)
            untracked_edge!(compute_u, data_out)

            # remember the data_out node for future edges
            dataOutNodes[node] = data_out
        elseif occursin(regex_c, node)
            capt = match(regex_c, node)
            in1 = capt.captures[1]
            in2 = capt.captures[2]

            compute_v = untracked_node!(ComputeTaskV())
            data_out = untracked_node!(DataTask(PARTICLE_VALUE_SIZE))

            connect_v_input!(in1, compute_v)
            # the current generator probably only puts combined particles in the first slot, so the
            # S1 branch might never be taken for the second input
            connect_v_input!(in2, compute_v)

            untracked_edge!(compute_v, data_out)
            dataOutNodes[node] = data_out
        elseif occursin(regex_m, node)
            # assume for now that only the first particle of the three is combined and the other two are "original" ones
            capt = match(regex_m, node)
            in1 = capt.captures[1]
            in2 = capt.captures[2]
            in3 = capt.captures[3]

            # in2 + in3 with a v
            compute_v = untracked_node!(ComputeTaskV())
            data_v = untracked_node!(DataTask(PARTICLE_VALUE_SIZE))

            untracked_edge!(dataOutNodes[in2], compute_v)
            untracked_edge!(dataOutNodes[in3], compute_v)
            untracked_edge!(compute_v, data_v)

            # combine with the v of the combined other input
            compute_S2 = untracked_node!(ComputeTaskS2())
            data_out = untracked_node!(DataTask(FLOAT_SIZE)) # output of a S2 task is only a float

            untracked_edge!(data_v, compute_S2)
            untracked_edge!(dataOutNodes[in1], compute_S2)
            untracked_edge!(compute_S2, data_out)

            untracked_edge!(data_out, sum_node)
            add_child!(sum_node.task)
        elseif occursin(regex_plus, node)
            if verbose
                println("\rReading Nodes Complete ")
                println("Added ", length(graph.nodes), " nodes")
            end
        else
            # Same exception type and message as the former `@assert false (...)`, but thrown
            # explicitly so the input check does not depend on assertions being enabled.
            throw(AssertionError("Unknown node '$node' while reading from file $filename"))
        end
    end

    # put all nodes into dirty nodes set
    graph.dirtyNodes = copy(graph.nodes)

    if verbose
        println("Generating the graph's properties")
    end
    graph.properties = GraphProperties(graph)

    if verbose
        println("Done")
    end

    # don't actually need to read the edges
    return graph
end
|
||||
|
||||
"""
|
||||
parse_process(str::AbstractString, model::ABCModel)
|
||||
|
||||
Parse a string representation of a process, such as "AB->ABBB" into the corresponding [`ABCProcessDescription`](@ref).
|
||||
"""
|
||||
# Parse a process string of the form "<incoming>-><outgoing>", where both sides consist of the
# one-character names of the particle types of `model`.
# Throws a `String` describing the problem when the arrow is missing or occurs more than once,
# when one side is empty, or when a side contains characters that are not particle names.
function parse_process(str::AbstractString, model::ABCModel)
    inParticles = Dict{Type, Int}()
    outParticles = Dict{Type, Int}()

    if !(contains(str, "->"))
        throw("Did not find -> while parsing process \"$str\"")
    end

    parts = split(str, "->")
    if length(parts) != 2
        # previously everything after a second "->" was silently dropped
        throw("Found more than one -> while parsing process \"$str\"")
    end
    (inStr, outStr) = parts

    if (isempty(inStr) || isempty(outStr))
        throw("Process (\"$str\") input or output part is empty!")
    end

    for t in types(model)
        # one-character name of the particle type, computed once per type
        symbol = String(t)[1]
        inCount = count(==(symbol), inStr)
        outCount = count(==(symbol), outStr)
        if inCount != 0
            inParticles[t] = inCount
        end
        if outCount != 0
            outParticles[t] = outCount
        end
    end

    # every character must have been counted for exactly one particle type
    if length(inStr) != sum(values(inParticles))
        throw("Encountered unknown characters in the input part of process \"$str\"")
    elseif length(outStr) != sum(values(outParticles))
        throw("Encountered unknown characters in the output part of process \"$str\"")
    end

    return ABCProcessDescription(inParticles, outParticles)
end
|
209
src/models/abc/particle.jl
Normal file
209
src/models/abc/particle.jl
Normal file
@ -0,0 +1,209 @@
|
||||
using QEDbase
|
||||
|
||||
"""
|
||||
ABCModel <: AbstractPhysicsModel
|
||||
|
||||
Singleton definition for identification of the ABC-Model.
|
||||
"""
|
||||
struct ABCModel <: AbstractPhysicsModel end
|
||||
|
||||
"""
|
||||
ABCParticle
|
||||
|
||||
Base type for all particles in the [`ABCModel`](@ref).
|
||||
"""
|
||||
abstract type ABCParticle <: AbstractParticle end
|
||||
|
||||
"""
|
||||
ParticleA <: ABCParticle
|
||||
|
||||
An 'A' particle in the ABC Model.
|
||||
"""
|
||||
struct ParticleA <: ABCParticle
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
"""
|
||||
ParticleB <: ABCParticle
|
||||
|
||||
A 'B' particle in the ABC Model.
|
||||
"""
|
||||
struct ParticleB <: ABCParticle
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
"""
|
||||
ParticleC <: ABCParticle
|
||||
|
||||
A 'C' particle in the ABC Model.
|
||||
"""
|
||||
struct ParticleC <: ABCParticle
|
||||
momentum::SFourMomentum
|
||||
end
|
||||
|
||||
"""
|
||||
ABCProcessDescription <: AbstractProcessDescription
|
||||
|
||||
A description of a process in the ABC-Model. Contains the input and output particles.
|
||||
|
||||
See also: [`in_particles`](@ref), [`out_particles`](@ref), [`parse_process`](@ref)
|
||||
"""
|
||||
struct ABCProcessDescription <: AbstractProcessDescription
|
||||
inParticles::Dict{Type, Int}
|
||||
outParticles::Dict{Type, Int}
|
||||
end
|
||||
|
||||
"""
|
||||
ABCProcessInput <: AbstractProcessInput
|
||||
|
||||
Input for a ABC Process. Contains the [`ABCProcessDescription`](@ref) of the process it is an input for, and the values of the in and out particles.
|
||||
|
||||
See also: [`gen_process_input`](@ref)
|
||||
"""
|
||||
struct ABCProcessInput <: AbstractProcessInput
|
||||
process::ABCProcessDescription
|
||||
inParticles::Vector{ABCParticle}
|
||||
outParticles::Vector{ABCParticle}
|
||||
end
|
||||
|
||||
"""
|
||||
PARTICLE_MASSES
|
||||
|
||||
A constant dictionary containing the masses of the different [`ABCParticle`](@ref)s.
|
||||
"""
|
||||
const PARTICLE_MASSES = Dict{Type, Float64}(ParticleA => 1.0, ParticleB => 1.0, ParticleC => 0.0)
|
||||
|
||||
"""
|
||||
mass(t::Type{T}) where {T <: ABCParticle}
|
||||
|
||||
Return the mass (at rest) of the given particle type.
|
||||
"""
|
||||
mass(t::Type{T}) where {T <: ABCParticle} = PARTICLE_MASSES[t]
|
||||
|
||||
"""
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ABCParticle, T2 <: ABCParticle}
|
||||
|
||||
For 2 given (non-equal) particle types, return the third of ABC.
|
||||
"""
|
||||
# Return the particle type that is neither `t1` nor `t2`. The two types must differ.
function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ABCParticle, T2 <: ABCParticle}
    @assert t1 != t2

    # `t1` and `t2` are the particle types themselves, so they have to be compared against
    # `ParticleA`/`ParticleB` directly. The former comparison against `Type{ParticleA}` was true for
    # every particle type, which made the function return `ParticleA` unconditionally.
    if t1 != ParticleA && t2 != ParticleA
        return ParticleA
    elseif t1 != ParticleB && t2 != ParticleB
        return ParticleB
    else
        return ParticleC
    end
end
|
||||
|
||||
"""
|
||||
types(::ABCModel)
|
||||
|
||||
Return a Vector of the possible types of particle in the [`ABCModel`](@ref).
|
||||
"""
|
||||
function types(::ABCModel)
|
||||
return [ParticleA, ParticleB, ParticleC]
|
||||
end
|
||||
|
||||
"""
|
||||
square(p::ABCParticle)
|
||||
|
||||
Return the square of the particle's momentum as a `Float` value.
|
||||
|
||||
Takes 7 effective FLOP.
|
||||
"""
|
||||
function square(p::ABCParticle)
|
||||
return getMass2(p.momentum)
|
||||
end
|
||||
|
||||
"""
|
||||
inner_edge(p::ABCParticle)
|
||||
|
||||
Return the factor of the inner edge with the given (virtual) particle.
|
||||
|
||||
Takes 10 effective FLOP. (3 here + 7 in square(p))
|
||||
"""
|
||||
function inner_edge(p::ABCParticle)
    # 1 / (square(p) - m * m), with m the mass registered for the particle's type
    rest_mass = mass(typeof(p))
    return 1.0 / (square(p) - rest_mass * rest_mass)
end
|
||||
|
||||
"""
|
||||
outer_edge(p::ABCParticle)
|
||||
|
||||
Return the factor of the outer edge with the given (real) particle.
|
||||
|
||||
Takes 0 effective FLOP.
|
||||
"""
|
||||
function outer_edge(p::ABCParticle)
|
||||
return 1.0
|
||||
end
|
||||
|
||||
"""
|
||||
vertex()
|
||||
|
||||
Return the factor of a vertex.
|
||||
|
||||
Takes 0 effective FLOP since it's constant.
|
||||
"""
|
||||
function vertex()
|
||||
i = 1.0
|
||||
lambda = 1.0 / 137.0
|
||||
return i * lambda
|
||||
end
|
||||
|
||||
"""
|
||||
preserve_momentum(p1::ABCParticle, p2::ABCParticle)
|
||||
|
||||
Calculate and return a new particle from two given interacting ones at a vertex.
|
||||
|
||||
Takes 4 effective FLOP.
|
||||
"""
|
||||
function preserve_momentum(p1::ABCParticle, p2::ABCParticle)
    # the type of the new particle follows from the types of the two inputs;
    # its momentum is the sum of both input momenta
    result_type = interaction_result(typeof(p1), typeof(p2))
    return result_type(p1.momentum + p2.momentum)
end
|
||||
|
||||
"""
|
||||
type_from_name(name::String)
|
||||
|
||||
For a name of a particle, return the particle's [`Type`].
|
||||
"""
|
||||
# Map a particle name to its type by its leading character ("A", "B" or "C").
# Any `AbstractString` is accepted, so substrings (e.g. regex captures) work without conversion.
# Throws a `String` for names that start with any other character, including the empty name.
function type_from_name(name::AbstractString)
    if startswith(name, "A")
        return ParticleA
    elseif startswith(name, "B")
        return ParticleB
    elseif startswith(name, "C")
        return ParticleC
    else
        throw("Invalid name for a particle in the ABC model")
    end
end
|
||||
|
||||
function String(::Type{ParticleA})
|
||||
return "A"
|
||||
end
|
||||
function String(::Type{ParticleB})
|
||||
return "B"
|
||||
end
|
||||
function String(::Type{ParticleC})
|
||||
return "C"
|
||||
end
|
||||
|
||||
function in_particles(process::ABCProcessDescription)
|
||||
return process.inParticles
|
||||
end
|
||||
|
||||
function in_particles(input::ABCProcessInput)
|
||||
return input.inParticles
|
||||
end
|
||||
|
||||
function out_particles(process::ABCProcessDescription)
|
||||
return process.outParticles
|
||||
end
|
||||
|
||||
function out_particles(input::ABCProcessInput)
|
||||
return input.outParticles
|
||||
end
|
58
src/models/abc/print.jl
Normal file
58
src/models/abc/print.jl
Normal file
@ -0,0 +1,58 @@
|
||||
|
||||
"""
|
||||
show(io::IO, process::ABCProcessDescription)
|
||||
|
||||
Pretty print an [`ABCProcessDescription`](@ref) (no newlines).
|
||||
|
||||
```jldoctest
|
||||
julia> using MetagraphOptimization
|
||||
|
||||
julia> print(parse_process("AB->ABBB", ABCModel()))
|
||||
ABC Process: 'AB->ABBB'
|
||||
```
|
||||
"""
|
||||
function show(io::IO, process::ABCProcessDescription)
    # Print the name of every particle type as often as it occurs in `counts`.
    # types() gives the types in order (ABC) instead of random like keys() would
    function print_particles(counts)
        for particle_type in types(ABCModel()), _ in 1:get(counts, particle_type, 0)
            print(io, String(particle_type))
        end
        return nothing
    end

    print(io, "ABC Process: \'")
    print_particles(process.inParticles)
    print(io, "->")
    print_particles(process.outParticles)
    print(io, "'")
    return nothing
end
|
||||
|
||||
"""
|
||||
show(io::IO, processInput::ABCProcessInput)
|
||||
|
||||
Pretty print an [`ABCProcessInput`](@ref) (with newlines).
|
||||
"""
|
||||
function show(io::IO, processInput::ABCProcessInput)
    incoming = processInput.inParticles
    outgoing = processInput.outParticles

    # header line, followed by one section per direction with one particle per line
    println(io, "Input for $(processInput.process):")

    println(io, " $(length(incoming)) Incoming particles:")
    foreach(particle -> println(io, " $particle"), incoming)

    println(io, " $(length(outgoing)) Outgoing Particles:")
    foreach(particle -> println(io, " $particle"), outgoing)

    return nothing
end
|
||||
|
||||
"""
|
||||
show(io::IO, particle::T) where {T <: ABCParticle}
|
||||
|
||||
Pretty print an [`ABCParticle`](@ref) (no newlines).
|
||||
"""
|
||||
function show(io::IO, particle::T) where {T <: ABCParticle}
|
||||
print(io, "$(String(typeof(particle))): $(particle.momentum)")
|
||||
return nothing
|
||||
end
|
166
src/models/abc/properties.jl
Normal file
166
src/models/abc/properties.jl
Normal file
@ -0,0 +1,166 @@
|
||||
"""
|
||||
compute_effort(t::ComputeTaskS1)
|
||||
|
||||
Return the compute effort of an S1 task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskS1) = 11
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskS2)
|
||||
|
||||
Return the compute effort of an S2 task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskS2) = 12
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskU)
|
||||
|
||||
Return the compute effort of a U task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskU) = 1
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskV)
|
||||
|
||||
Return the compute effort of a V task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskV) = 6
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskP)
|
||||
|
||||
Return the compute effort of a P task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskP) = 0
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskSum)
|
||||
|
||||
Return the compute effort of a Sum task.
|
||||
|
||||
Note: This is a constant compute effort, even though sum scales with the number of its inputs. Since there is only ever a single sum node in a graph generated from the ABC-Model,
|
||||
this doesn't matter.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskSum) = 1
|
||||
|
||||
"""
|
||||
show(io::IO, t::DataTask)
|
||||
|
||||
Print the data task to io.
|
||||
"""
|
||||
function show(io::IO, t::DataTask)
|
||||
return print(io, "Data", t.data)
|
||||
end
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskS1)
|
||||
|
||||
Print the S1 task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskS1) = print(io, "ComputeS1")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskS2)
|
||||
|
||||
Print the S2 task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskS2) = print(io, "ComputeS2")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskP)
|
||||
|
||||
Print the P task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskP) = print(io, "ComputeP")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskU)
|
||||
|
||||
Print the U task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskU) = print(io, "ComputeU")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskV)
|
||||
|
||||
Print the V task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskV) = print(io, "ComputeV")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskSum)
|
||||
|
||||
Print the sum task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskSum) = print(io, "ComputeSum")
|
||||
|
||||
"""
|
||||
copy(t::DataTask)
|
||||
|
||||
Copy the data task and return it.
|
||||
"""
|
||||
copy(t::DataTask) = DataTask(t.data)
|
||||
|
||||
"""
|
||||
children(::DataTask)
|
||||
|
||||
Return the number of children of a data task (always 1).
|
||||
"""
|
||||
children(::DataTask) = 1
|
||||
|
||||
"""
|
||||
children(::ComputeTaskS1)
|
||||
|
||||
Return the number of children of a ComputeTaskS1 (always 1).
|
||||
"""
|
||||
children(::ComputeTaskS1) = 1
|
||||
|
||||
"""
|
||||
children(::ComputeTaskS2)
|
||||
|
||||
Return the number of children of a ComputeTaskS2 (always 2).
|
||||
"""
|
||||
children(::ComputeTaskS2) = 2
|
||||
|
||||
"""
|
||||
children(::ComputeTaskP)
|
||||
|
||||
Return the number of children of a ComputeTaskP (always 1).
|
||||
"""
|
||||
children(::ComputeTaskP) = 1
|
||||
|
||||
"""
|
||||
children(::ComputeTaskU)
|
||||
|
||||
Return the number of children of a ComputeTaskU (always 1).
|
||||
"""
|
||||
children(::ComputeTaskU) = 1
|
||||
|
||||
"""
|
||||
children(::ComputeTaskV)
|
||||
|
||||
Return the number of children of a ComputeTaskV (always 2).
|
||||
"""
|
||||
children(::ComputeTaskV) = 2
|
||||
|
||||
|
||||
"""
|
||||
children(::ComputeTaskSum)
|
||||
|
||||
Return the number of children of a ComputeTaskSum.
|
||||
"""
|
||||
children(t::ComputeTaskSum) = t.children_number
|
||||
|
||||
"""
|
||||
children(t::FusedComputeTask)
|
||||
|
||||
Return the number of children of a FusedComputeTask.
|
||||
"""
|
||||
function children(t::FusedComputeTask)
|
||||
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
|
||||
end
|
||||
|
||||
function add_child!(t::ComputeTaskSum)
|
||||
t.children_number += 1
|
||||
return nothing
|
||||
end
|
59
src/models/abc/types.jl
Normal file
59
src/models/abc/types.jl
Normal file
@ -0,0 +1,59 @@
|
||||
"""
|
||||
DataTask <: AbstractDataTask
|
||||
|
||||
Task representing a specific data transfer in the ABC Model.
|
||||
"""
|
||||
struct DataTask <: AbstractDataTask
|
||||
data::UInt64
|
||||
end
|
||||
|
||||
"""
|
||||
ComputeTaskS1 <: AbstractComputeTask
|
||||
|
||||
S task with a single child.
|
||||
"""
|
||||
struct ComputeTaskS1 <: AbstractComputeTask end
|
||||
|
||||
"""
|
||||
ComputeTaskS2 <: AbstractComputeTask
|
||||
|
||||
S task with two children.
|
||||
"""
|
||||
struct ComputeTaskS2 <: AbstractComputeTask end
|
||||
|
||||
"""
|
||||
ComputeTaskP <: AbstractComputeTask
|
||||
|
||||
P task with no children.
|
||||
"""
|
||||
struct ComputeTaskP <: AbstractComputeTask end
|
||||
|
||||
"""
|
||||
ComputeTaskV <: AbstractComputeTask
|
||||
|
||||
v task with two children.
|
||||
"""
|
||||
struct ComputeTaskV <: AbstractComputeTask end
|
||||
|
||||
"""
|
||||
ComputeTaskU <: AbstractComputeTask
|
||||
|
||||
u task with a single child.
|
||||
"""
|
||||
struct ComputeTaskU <: AbstractComputeTask end
|
||||
|
||||
"""
|
||||
ComputeTaskSum <: AbstractComputeTask
|
||||
|
||||
Task that sums all its inputs, n children.
|
||||
"""
|
||||
mutable struct ComputeTaskSum <: AbstractComputeTask
|
||||
children_number::Int
|
||||
end
|
||||
|
||||
"""
|
||||
ABC_TASKS
|
||||
|
||||
Constant vector of all tasks of the ABC-Model.
|
||||
"""
|
||||
# Declared `const` to match its documentation as a constant vector and to avoid an untyped global.
const ABC_TASKS = [DataTask, ComputeTaskS1, ComputeTaskS2, ComputeTaskP, ComputeTaskV, ComputeTaskU, ComputeTaskSum]
|
109
src/models/interface.jl
Normal file
109
src/models/interface.jl
Normal file
@ -0,0 +1,109 @@
|
||||
|
||||
"""
|
||||
AbstractPhysicsModel
|
||||
|
||||
Base type for a model, e.g. ABC-Model or QED. This is used to dispatch many functions.
|
||||
"""
|
||||
abstract type AbstractPhysicsModel end
|
||||
|
||||
"""
|
||||
AbstractParticle
|
||||
|
||||
Base type for particles belonging to a certain [`AbstractPhysicsModel`](@ref).
|
||||
"""
|
||||
abstract type AbstractParticle end
|
||||
|
||||
"""
|
||||
ParticleValue{ParticleType <: AbstractParticle}
|
||||
|
||||
A struct describing a particle during a calculation of a Feynman Diagram, together with the value that's being calculated.
|
||||
|
||||
`sizeof(ParticleValue())` = 48 Byte
|
||||
"""
|
||||
struct ParticleValue{ParticleType <: AbstractParticle}
|
||||
p::ParticleType
|
||||
v::Float64
|
||||
end
|
||||
|
||||
"""
|
||||
AbstractProcessDescription
|
||||
|
||||
Base type for process descriptions. An object of this type of a corresponding [`AbstractPhysicsModel`](@ref) should uniquely identify a process in that model.
|
||||
|
||||
See also: [`parse_process`](@ref)
|
||||
"""
|
||||
abstract type AbstractProcessDescription end
|
||||
|
||||
"""
|
||||
AbstractProcessInput
|
||||
|
||||
Base type for process inputs. An object of this type contains the input values (e.g. momenta) of the particles in a process.
|
||||
|
||||
See also: [`gen_process_input`](@ref)
|
||||
"""
|
||||
abstract type AbstractProcessInput end
|
||||
|
||||
"""
|
||||
mass(t::Type{T}) where {T <: AbstractParticle}
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the particles mass at rest.
|
||||
"""
|
||||
function mass end
|
||||
|
||||
"""
|
||||
interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: AbstractParticle, T2 <: AbstractParticle}
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractParticle`](@ref), returning the result particle type when the two given particles interact.
|
||||
"""
|
||||
function interaction_result end
|
||||
|
||||
"""
|
||||
types(::AbstractPhysicsModel)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref), returning a `Vector` of the available particle types in the model.
|
||||
"""
|
||||
function types end
|
||||
|
||||
"""
|
||||
in_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of incoming particles for the process per particle type.
|
||||
|
||||
|
||||
in_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all incoming particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function in_particles end
|
||||
|
||||
"""
|
||||
out_particles(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessDescription`](@ref).
|
||||
Returns a `<: Dict{Type{AbstractParticle}, Int}` object, representing the number of outgoing particles for the process per particle type.
|
||||
|
||||
|
||||
out_particles(::AbstractProcessInput)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractProcessInput`](@ref).
|
||||
Returns a `<: Vector{AbstractParticle}` object with the values of all outgoing particles for the corresponding `ProcessDescription`.
|
||||
"""
|
||||
function out_particles end
|
||||
|
||||
"""
|
||||
parse_process(::AbstractString, ::AbstractPhysicsModel)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractPhysicsModel`](@ref).
|
||||
Returns a `ProcessDescription` object.
|
||||
"""
|
||||
function parse_process end
|
||||
|
||||
"""
|
||||
gen_process_input(::AbstractProcessDescription)
|
||||
|
||||
Interface function that must be implemented for every specific [`AbstractProcessDescription`](@ref).
|
||||
Returns a randomly generated and valid corresponding `ProcessInput`.
|
||||
"""
|
||||
function gen_process_input end
|
10
src/models/print.jl
Normal file
10
src/models/print.jl
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
"""
|
||||
show(io::IO, particleValue::ParticleValue)
|
||||
|
||||
Pretty print a [`ParticleValue`](@ref), no newlines.
|
||||
"""
|
||||
function show(io::IO, particleValue::ParticleValue)
|
||||
print(io, "($(particleValue.p), value: $(particleValue.v))")
|
||||
return nothing
|
||||
end
|
35
src/node/compare.jl
Normal file
35
src/node/compare.jl
Normal file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
==(e1::Edge, e2::Edge)
|
||||
|
||||
Equality comparison between two edges.
|
||||
"""
|
||||
function ==(e1::Edge, e2::Edge)
|
||||
return e1.edge[1] == e2.edge[1] && e1.edge[2] == e2.edge[2]
|
||||
end
|
||||
|
||||
"""
|
||||
==(n1::Node, n2::Node)
|
||||
|
||||
Fallback equality comparison between two nodes. For equal node types, the more specific versions of this function will be called.
|
||||
"""
|
||||
function ==(n1::Node, n2::Node)
|
||||
return false
|
||||
end
|
||||
|
||||
"""
|
||||
==(n1::ComputeTaskNode, n2::ComputeTaskNode)
|
||||
|
||||
Equality comparison between two [`ComputeTaskNode`](@ref)s.
|
||||
"""
|
||||
function ==(n1::ComputeTaskNode, n2::ComputeTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
||||
"""
|
||||
==(n1::DataTaskNode, n2::DataTaskNode)
|
||||
|
||||
Equality comparison between two [`DataTaskNode`](@ref)s.
|
||||
"""
|
||||
function ==(n1::DataTaskNode, n2::DataTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
71
src/node/create.jl
Normal file
71
src/node/create.jl
Normal file
@ -0,0 +1,71 @@
|
||||
|
||||
DataTaskNode(t::AbstractDataTask, name = "") =
|
||||
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
|
||||
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
t, # task
|
||||
Vector{Node}(), # parents
|
||||
Vector{Node}(), # children
|
||||
UUIDs.uuid1(rng[threadid()]), # id
|
||||
missing, # node reduction
|
||||
missing, # node split
|
||||
Vector{NodeFusion}(), # node fusions
|
||||
missing, # device
|
||||
)
|
||||
|
||||
copy(m::Missing) = missing
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), n.name)
|
||||
|
"""
    make_node(t::AbstractTask)

Fallback method of `make_node` for an [`AbstractTask`](@ref) that has no more specific method. Always throws an error.
"""
function make_node(t::AbstractTask)
    return throw(ErrorException("Cannot make a node from this task type"))
end
|
||||
|
"""
    make_node(t::AbstractDataTask, name::String = "")

Create a new [`DataTaskNode`](@ref) from the given task and the given `name` (empty by default) and return it.
"""
function make_node(t::AbstractDataTask, name::String = "")
    node = DataTaskNode(t, name)
    return node
end
|
||||
|
"""
    make_node(t::AbstractComputeTask)

Create a new [`ComputeTaskNode`](@ref) from the given task and return it.
"""
function make_node(t::AbstractComputeTask)
    node = ComputeTaskNode(t)
    return node
end
|
||||
|
"""
    make_edge(n1::Node, n2::Node)

Fallback method of `make_edge`, which always throws an error. Hitting it most likely means that an edge between two nodes of the same type was requested.
"""
function make_edge(n1::Node, n2::Node)
    return throw(ErrorException("Can only create edges from compute to data node or reverse"))
end
|
||||
|
"""
    make_edge(n1::ComputeTaskNode, n2::DataTaskNode)

Create and return a new [`Edge`](@ref) that points from `n1` (the child) to `n2` (the parent).
"""
function make_edge(n1::ComputeTaskNode, n2::DataTaskNode)
    endpoints = (n1, n2)
    return Edge(endpoints)
end
|
||||
|
"""
    make_edge(n1::DataTaskNode, n2::ComputeTaskNode)

Create and return a new [`Edge`](@ref) that points from `n1` (the child) to `n2` (the parent).
"""
function make_edge(n1::DataTaskNode, n2::ComputeTaskNode)
    endpoints = (n1, n2)
    return Edge(endpoints)
end
|
27
src/node/print.jl
Normal file
27
src/node/print.jl
Normal file
@ -0,0 +1,27 @@
|
"""
    show(io::IO, n::Node)

Write a short textual form of the node to `io`: the node's task wrapped in `Node(...)`.
"""
function show(io::IO, n::Node)
    task = n.task
    return print(io, "Node(", task, ')')
end
|
||||
|
"""
    show(io::IO, e::Edge)

Write a short textual form of the edge to `io`: both of its entries wrapped in `Edge(...)`.
"""
function show(io::IO, e::Edge)
    from = e.edge[1]
    to = e.edge[2]
    return print(io, "Edge(", from, ", ", to, ')')
end
|
||||
|
"""
    to_var_name(id::UUID)

Turn the uuid into a string that can be used as a variable name in code generation: every `-` is replaced by `_` and a leading `_` is prepended.
"""
function to_var_name(id::UUID)
    underscored = replace(string(id), '-' => '_')
    return string("_", underscored)
end
|
115
src/node/properties.jl
Normal file
115
src/node/properties.jl
Normal file
@ -0,0 +1,115 @@
|
"""
    is_entry_node(node::Node)

Return `true` if the node has no children, i.e., it is an entry node of its graph.
"""
is_entry_node(node::Node) = isempty(node.children)
|
||||
|
"""
    is_exit_node(node::Node)

Return `true` if the node has no parents, i.e., it is an exit node of its graph.
"""
is_exit_node(node::Node) = isempty(node.parents)
|
||||
|
"""
    data(edge::Edge)

Return the data transferred by this edge: `0.0` if the edge's child is a [`ComputeTaskNode`](@ref), otherwise the `data()` of the child's task.
"""
function data(edge::Edge)
    child = edge.edge[1]

    # only data nodes carry data along the edge
    if !(child isa DataTaskNode)
        return 0.0
    end
    return data(child.task)
end
|
||||
|
"""
    children(node::Node)

Return a copy of the node's children, so the result can be mutated without changing the node in the graph.

The children of a node are its prerequisite nodes, i.e., the nodes that have to execute before this node's task.
"""
function children(node::Node)
    childVector = node.children
    return copy(childVector)
end
|
||||
|
"""
    parents(node::Node)

Return a copy of the node's parents, so the result can be mutated without changing the node in the graph.

The parents of a node are its subsequent nodes, i.e., the nodes that need this node to execute.
"""
function parents(node::Node)
    parentVector = node.parents
    return copy(parentVector)
end
|
||||
|
"""
    siblings(node::Node)

Return a `Set{Node}` of all siblings of this node.

The siblings of a node are all children of any of its parents. The result contains no duplicates and includes the node itself.
"""
function siblings(node::Node)
    found = Set{Node}()
    push!(found, node)

    # gather the children of every parent
    foreach(parent -> union!(found, parent.children), node.parents)

    return found
end
|
||||
|
"""
    partners(node::Node)

Return a `Set{Node}` of all partners of this node.

The partners of a node are all parents of any of its children. The result contains no duplicates and includes the node itself.

Note: This is very slow when there are multiple children with many parents.
This is less of a problem in [`siblings(node::Node)`](@ref) because (depending on the model) there are no nodes with a large number of children, or only a single one.
"""
function partners(node::Node)
    found = Set{Node}()
    push!(found, node)

    # gather the parents of every child
    foreach(child -> union!(found, child.parents), node.children)

    return found
end
|
||||
|
"""
    partners(node::Node, set::Set{Node})

In-place variant of [`partners(node::Node)`](@ref) that avoids allocating a new set. The node and its partners are added to the given `set`; the function returns `nothing`.
"""
function partners(node::Node, set::Set{Node})
    push!(set, node)

    # gather the parents of every child directly into the caller's set
    foreach(child -> union!(set, child.parents), node.children)

    return nothing
end
|
||||
|
"""
    is_parent(potential_parent::Node, node::Node)

Return `true` if `potential_parent` is contained in the parents of `node`.
"""
function is_parent(potential_parent::Node, node::Node)
    return potential_parent ∈ node.parents
end
|
||||
|
"""
    is_child(potential_child::Node, node::Node)

Return `true` if `potential_child` is contained in the children of `node`.
"""
function is_child(potential_child::Node, node::Node)
    return potential_child ∈ node.children
end
|
104
src/node/type.jl
Normal file
104
src/node/type.jl
Normal file
@ -0,0 +1,104 @@
|
||||
using Random
|
||||
using UUIDs
|
||||
using Base.Threads
|
||||
|
||||
# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/)
# Until then, a fixed pool of 32 separate generators is created, each seeded with 0.
rng = map(_ -> Random.MersenneTwister(0), 1:32)
|
||||
|
"""
    Node

Abstract supertype of all node types.

See also [`DataTaskNode`](@ref), [`ComputeTaskNode`](@ref) and [`make_node`](@ref).
"""
abstract type Node end
|
||||
|
||||
# Abstract supertype of all operations. It is declared here because the node types below
# refer to it in their fields; the specific operations are declared in graph.jl.
abstract type Operation end
|
||||
|
"""
    DataTaskNode <: Node

A node that transfers data and does no computation.

# Fields
- `.task`: The node's data task type. Usually [`DataTask`](@ref).
- `.parents`: A vector of the node's parents (i.e. nodes that depend on this one).
- `.children`: A vector of the node's children (i.e. nodes that this one depends on).
- `.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.
- `.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.
- `.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.
- `.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.
- `.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.
"""
mutable struct DataTaskNode <: Node
    task::AbstractDataTask

    # vectors instead of sets, because sets have way too much memory overhead
    parents::Vector{Node}
    children::Vector{Node}

    # identifier that is unique to every *constructed* node;
    # it can, however, be copied when splitting a node
    id::Base.UUID

    # the NodeReduction involving this node, if it exists
    # (declared as Operation because the NodeReduction type is not yet defined here)
    nodeReduction::Union{Operation, Missing}

    # the NodeSplit involving this node, if it exists
    nodeSplit::Union{Operation, Missing}

    # the NodeFusion involving this node, if it exists
    nodeFusion::Union{Operation, Missing}

    # name used to tell input nodes apart
    name::String
end
|
||||
|
"""
    ComputeTaskNode <: Node

A node that computes a result from inputs using an [`AbstractComputeTask`](@ref).

# Fields
- `.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).
- `.parents`: A vector of the node's parents (i.e. nodes that depend on this one).
- `.children`: A vector of the node's children (i.e. nodes that this one depends on).
- `.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.
- `.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.
- `.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.
- `.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.
- `.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
"""
mutable struct ComputeTaskNode <: Node
    task::AbstractComputeTask
    parents::Vector{Node}
    children::Vector{Node}
    id::Base.UUID

    nodeReduction::Union{Operation, Missing}
    nodeSplit::Union{Operation, Missing}

    # unlike data nodes, a compute node can take part in several fusions
    nodeFusions::Vector{Operation}

    # the device this node is assigned to execute on; `missing` while unassigned
    device::Union{AbstractDevice, Missing}
end
|
||||
|
"""
    Edge

The type of an edge in the graph. An edge connects a [`DataTaskNode`](@ref) with a [`ComputeTaskNode`](@ref) or vice versa, never two nodes of the same type.

Every edge points from child to parent: `child = e.edge[1]` and `parent = e.edge[2]`.

The child is the prerequisite node of the parent.
"""
struct Edge
    # (child, parent): the edge points from child to parent
    edge::Union{Tuple{DataTaskNode, ComputeTaskNode}, Tuple{ComputeTaskNode, DataTaskNode}}
end
|
76
src/node/validate.jl
Normal file
76
src/node/validate.jl
Normal file
@ -0,0 +1,76 @@
|
"""
    is_valid_node(graph::DAG, node::Node)

Check that the given node is valid in the graph. Call like `@test is_valid_node(g, n)`. Invalid state is reported through `@assert`, which fails with a descriptive error message; otherwise `true` is returned.

This function is very performance intensive and should only be used when testing or debugging.

See also this function's specific versions for the concrete Node types [`is_valid(graph::DAG, node::ComputeTaskNode)`](@ref) and [`is_valid(graph::DAG, node::DataTaskNode)`](@ref).
"""
function is_valid_node(graph::DAG, node::Node)
    @assert node in graph "Node is not part of the given graph!"

    nodeType = typeof(node)

    # parents: different node type, part of the graph, and linked back to this node
    for parent in node.parents
        @assert typeof(parent) != nodeType "Node's type is the same as its parent's!"
        @assert parent in graph "Node's parent is not in the same graph!"
        @assert node in parent.children "Node is not a child of its parent!"
    end

    # children: the same three checks in the other direction
    for child in node.children
        @assert typeof(child) != nodeType "Node's type is the same as its child's!"
        @assert child in graph "Node's child is not in the same graph!"
        @assert node in child.parents "Node is not a parent of its child!"
    end

    #=if !ismissing(node.nodeReduction)
        @assert is_valid(graph, node.nodeReduction)
    end
    if !ismissing(node.nodeSplit)
        @assert is_valid(graph, node.nodeSplit)
    end=#

    # the remaining checks are only necessary for fused compute tasks
    if !(node.task isa FusedComputeTask)
        return true
    end

    # every child must be in some input of the task
    for child in node.children
        str = Symbol(to_var_name(child.id))
        @assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
    end

    return true
end
|
||||
|
"""
    is_valid(graph::DAG, node::ComputeTaskNode)

Check that the given compute node is valid in the graph. Call with `@assert` or `@test` when testing or debugging.

The check is delegated to [`is_valid_node(graph::DAG, node::Node)`](@ref).
"""
function is_valid(graph::DAG, node::ComputeTaskNode)
    # generic node checks
    @assert is_valid_node(graph, node)

    #=for nf in node.nodeFusions
        @assert is_valid(graph, nf)
    end=#

    return true
end
|
||||
|
"""
    is_valid(graph::DAG, node::DataTaskNode)

Check that the given data node is valid in the graph. Call with `@assert` or `@test` when testing or debugging.

The check is delegated to [`is_valid_node(graph::DAG, node::Node)`](@ref).
"""
function is_valid(graph::DAG, node::DataTaskNode)
    # generic node checks
    @assert is_valid_node(graph, node)

    #=if !ismissing(node.nodeFusion)
        @assert is_valid(graph, node.nodeFusion)
    end=#

    return true
end
|
@ -1,51 +0,0 @@
|
||||
function make_node(t::AbstractTask)
|
||||
error("Cannot make a node from this task type")
|
||||
end
|
||||
|
||||
function make_node(t::AbstractDataTask)
|
||||
return DataTaskNode(t)
|
||||
end
|
||||
|
||||
function make_node(t::AbstractComputeTask)
|
||||
return ComputeTaskNode(t)
|
||||
end
|
||||
|
||||
function make_edge(n1::Node, n2::Node)
|
||||
error("Can only create edges from compute to data node or reverse")
|
||||
end
|
||||
|
||||
function make_edge(n1::ComputeTaskNode, n2::DataTaskNode)
|
||||
return Edge((n1, n2))
|
||||
end
|
||||
|
||||
function make_edge(n1::DataTaskNode, n2::ComputeTaskNode)
|
||||
return Edge((n1, n2))
|
||||
end
|
||||
|
||||
function show(io::IO, n::Node)
|
||||
print(io, "Node(", n.task, ")")
|
||||
end
|
||||
|
||||
function show(io::IO, e::Edge)
|
||||
print(io, "Edge(", e.edge[1], ", ", e.edge[2], ")")
|
||||
end
|
||||
|
||||
function ==(e1::Edge, e2::Edge)
|
||||
return e1.edge[1] == e2.edge[1] && e1.edge[2] == e2.edge[2]
|
||||
end
|
||||
|
||||
function ==(n1::Node, n2::Node)
|
||||
return false
|
||||
end
|
||||
|
||||
function ==(n1::ComputeTaskNode, n2::ComputeTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
||||
function ==(n1::DataTaskNode, n2::DataTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
||||
copy(m::Missing) = missing
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusions))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), copy(n.parents), copy(n.children), UUIDs.uuid1(rng[threadid()]), copy(n.nodeReduction), copy(n.nodeSplit), copy(n.nodeFusion))
|
56
src/nodes.jl
56
src/nodes.jl
@ -1,56 +0,0 @@
|
||||
using Random
|
||||
using UUIDs
|
||||
using Base.Threads
|
||||
|
||||
# TODO: reliably find out how many threads we're running with (nthreads() returns 1 when precompiling :/)
|
||||
rng = [Random.MersenneTwister(0) for _ in 1:32]
|
||||
|
||||
abstract type Node end
|
||||
|
||||
# declare this type here because it's needed
|
||||
# the specific operations are declared in graph.jl
|
||||
abstract type Operation end
|
||||
|
||||
mutable struct DataTaskNode <: Node
|
||||
task::AbstractDataTask
|
||||
|
||||
# use vectors as sets have way too much memory overhead
|
||||
parents::Vector{Node}
|
||||
children::Vector{Node}
|
||||
|
||||
# need a unique identifier unique to every *constructed* node
|
||||
# however, it can be copied when splitting a node
|
||||
id::Base.UUID
|
||||
|
||||
# the NodeReduction involving this node, if it exists
|
||||
# Can't use the NodeReduction type here because it's not yet defined
|
||||
nodeReduction::Union{Operation, Missing}
|
||||
|
||||
# the NodeSplit involving this node, if it exists
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# the node fusion involving this node, if it exists
|
||||
nodeFusion::Union{Operation, Missing}
|
||||
end
|
||||
|
||||
# same as DataTaskNode
|
||||
mutable struct ComputeTaskNode <: Node
|
||||
task::AbstractComputeTask
|
||||
parents::Vector{Node}
|
||||
children::Vector{Node}
|
||||
id::Base.UUID
|
||||
|
||||
nodeReduction::Union{Operation, Missing}
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
|
||||
nodeFusions::Vector{Operation}
|
||||
end
|
||||
|
||||
DataTaskNode(t::AbstractDataTask) = DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing)
|
||||
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, Vector{NodeFusion}())
|
||||
|
||||
struct Edge
|
||||
# edge points from child to parent
|
||||
edge::Union{Tuple{DataTaskNode, ComputeTaskNode}, Tuple{ComputeTaskNode, DataTaskNode}}
|
||||
end
|
312
src/operation/apply.jl
Normal file
312
src/operation/apply.jl
Normal file
@ -0,0 +1,312 @@
|
"""
    apply_all!(graph::DAG)

Apply every operation that is still queued in the DAG's `operationsToApply`. Is automatically called in all functions that require the latest state of the [`DAG`](@ref).
"""
function apply_all!(graph::DAG)
    while true
        if isempty(graph.operationsToApply)
            break
        end

        # operations are taken from the front of the queue ...
        nextOp = popfirst!(graph.operationsToApply)
        applied = apply_operation!(graph, nextOp)

        # ... and their applied counterparts are appended to the end of appliedOperations
        push!(graph.appliedOperations, applied)
    end
    return nothing
end
|
||||
|
"""
    apply_operation!(graph::DAG, operation::Operation)

Fallback method of `apply_operation!` for operation types without a specific implementation. Always throws an error.
"""
function apply_operation!(graph::DAG, operation::Operation)
    return throw(ErrorException("Unknown operation type!"))
end
|
||||
|
"""
    apply_operation!(graph::DAG, operation::NodeFusion)

Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node_fusion!`](@ref).

The graph's properties are updated from the resulting [`Diff`](@ref), and an [`AppliedNodeFusion`](@ref) built from the operation and that diff is returned.
"""
function apply_operation!(graph::DAG, operation::NodeFusion)
    inputs = operation.input
    diff = node_fusion!(graph, inputs[1], inputs[2], inputs[3])

    # account for the change in the cached graph properties
    graph.properties = graph.properties + GraphProperties(diff)

    return AppliedNodeFusion(operation, diff)
end
|
||||
|
"""
    apply_operation!(graph::DAG, operation::NodeReduction)

Apply the given [`NodeReduction`](@ref) to the graph. Generic wrapper around [`node_reduction!`](@ref).

The graph's properties are updated from the resulting [`Diff`](@ref), and an [`AppliedNodeReduction`](@ref) built from the operation and that diff is returned.
"""
function apply_operation!(graph::DAG, operation::NodeReduction)
    inputs = operation.input
    diff = node_reduction!(graph, inputs)

    # account for the change in the cached graph properties
    graph.properties = graph.properties + GraphProperties(diff)

    return AppliedNodeReduction(operation, diff)
end
|
||||
|
"""
    apply_operation!(graph::DAG, operation::NodeSplit)

Apply the given [`NodeSplit`](@ref) to the graph. Generic wrapper around [`node_split!`](@ref).

The graph's properties are updated from the resulting [`Diff`](@ref), and an [`AppliedNodeSplit`](@ref) built from the operation and that diff is returned.
"""
function apply_operation!(graph::DAG, operation::NodeSplit)
    inputs = operation.input
    diff = node_split!(graph, inputs)

    # account for the change in the cached graph properties
    graph.properties = graph.properties + GraphProperties(diff)

    return AppliedNodeSplit(operation, diff)
end
|
||||
|
"""
    revert_operation!(graph::DAG, operation::AppliedOperation)

Fallback method of `revert_operation!` for applied operation types without a specific implementation. Always throws an error.
"""
function revert_operation!(graph::DAG, operation::AppliedOperation)
    return throw(ErrorException("Unknown operation type!"))
end
|
||||
|
"""
    revert_operation!(graph::DAG, operation::AppliedNodeFusion)

Undo the applied node fusion on the graph by reverting its diff. Return the original [`NodeFusion`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
    appliedDiff = operation.diff
    revert_diff!(graph, appliedDiff)
    return operation.operation
end
|
||||
|
"""
    revert_operation!(graph::DAG, operation::AppliedNodeReduction)

Undo the applied node reduction on the graph by reverting its diff. Return the original [`NodeReduction`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
    appliedDiff = operation.diff
    revert_diff!(graph, appliedDiff)
    return operation.operation
end
|
||||
|
"""
    revert_operation!(graph::DAG, operation::AppliedNodeSplit)

Undo the applied node split on the graph by reverting its diff. Return the original [`NodeSplit`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
    appliedDiff = operation.diff
    revert_diff!(graph, appliedDiff)
    return operation.operation
end
|
||||
|
"""
    revert_diff!(graph::DAG, diff::Diff)

Undo the given diff on the graph. Used to revert the individual [`AppliedOperation`](@ref)s with [`revert_operation!`](@ref).
"""
function revert_diff!(graph::DAG, diff::Diff)
    # Everything the diff added is removed and everything it removed is re-inserted.
    # The order matters: added edges go before added nodes, removed nodes come back before removed edges.
    for added in diff.addedEdges
        from, to = added.edge[1], added.edge[2]
        remove_edge!(graph, from, to; track = false)
    end
    for added in diff.addedNodes
        remove_node!(graph, added; track = false)
    end

    for removed in diff.removedNodes
        insert_node!(graph, removed; track = false)
    end
    for removed in diff.removedEdges
        from, to = removed.edge[1], removed.edge[2]
        insert_edge!(graph, from, to; track = false)
    end

    # restore the tasks stored in updatedChildren
    for (node, task) in diff.updatedChildren
        # node must be fused compute task at this point
        @assert typeof(node.task) <: FusedComputeTask

        node.task = task
    end

    graph.properties = graph.properties - GraphProperties(diff)

    return nothing
end
|
||||
|
"""
    node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Fuse the chain n1 -> n2 -> n3 into a single node and return the difference that was applied to the graph.

For details see [`NodeFusion`](@ref).
"""
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    @assert is_valid_node_fusion_input(graph, n1, n2, n3)

    # clear snapshot
    get_snapshot_diff(graph)

    # remember the surroundings of the chain and copies of both compute tasks
    n1Children = children(n1)
    n3Parents = parents(n3)

    n1Task = copy(n1.task)
    n3Task = copy(n3.task)

    # variable names of n1's inputs, stored in the FusedComputeTask
    n1Inputs = Symbol[Symbol(to_var_name(child.id)) for child in n1Children]

    # remove the edges and nodes that will be replaced by the fused node
    remove_edge!(graph, n1, n2)
    remove_edge!(graph, n2, n3)
    remove_node!(graph, n1)
    remove_node!(graph, n2)

    # n3's children are read only now, so that n2 is automatically excluded
    n3Children = children(n3)

    # variable names of n3's remaining inputs, stored in the FusedComputeTask
    n3Inputs = Symbol[Symbol(to_var_name(child.id)) for child in n3Children]

    remove_node!(graph, n3)

    # create the new node carrying the fused compute task
    fusedTask = FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs)
    newNode = ComputeTaskNode(fusedTask)
    insert_node!(graph, newNode)

    # rewire n1's children to the new node
    for child in n1Children
        remove_edge!(graph, child, n1)
        insert_edge!(graph, child, newNode)
    end

    # rewire n3's children to the new node, skipping those already connected through n1
    for child in n3Children
        remove_edge!(graph, child, n3)
        if child ∉ n1Children
            insert_edge!(graph, child, newNode)
        end
    end

    # rewire n3's parents to the new node
    for parent in n3Parents
        remove_edge!(graph, n3, parent)
        insert_edge!(graph, newNode, parent)

        # important! update the parent node's child names in case they are fused compute tasks
        # needed for compute generation so the fused compute task can correctly match inputs to its component tasks
        update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(newNode.id)))
    end

    return get_snapshot_diff(graph)
end
|
||||
|
"""
    node_reduction!(graph::DAG, nodes::Vector{Node})

Reduce the given nodes into the first one of them and return the difference that was applied to the graph.

For details see [`NodeReduction`](@ref).
"""
function node_reduction!(graph::DAG, nodes::Vector{Node})
    @assert is_valid_node_reduction_input(graph, nodes)

    # clear snapshot
    get_snapshot_diff(graph)

    # the first node is the one that is kept
    n1 = nodes[1]
    n1Children = children(n1)
    n1Parents = Set(n1.parents)

    # parents collected from the removed nodes, which n1 has to serve afterwards
    newParents = Set{Node}()

    # per collected parent: the name of the previous child that n1 now replaces
    newParentsChildNames = Dict{Node, Symbol}()

    # detach and remove every node except the first one
    for i in 2:length(nodes)
        n = nodes[i]

        for child in n1Children
            remove_edge!(graph, child, n)
        end

        for parent in parents(n)
            remove_edge!(graph, n, parent)

            # collect all parents
            push!(newParents, parent)
            newParentsChildNames[parent] = Symbol(to_var_name(n.id))
        end

        remove_node!(graph, n)
    end

    for parent in newParents
        # connect n1 to the collected parents; parents n1 already had are skipped to avoid double edges
        if parent ∉ n1Parents
            insert_edge!(graph, n1, parent)
        end

        # this has to be done for all parents, even the ones of n1 because they can be duplicate
        prevChild = newParentsChildNames[parent]
        update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
    end

    return get_snapshot_diff(graph)
end
|
||||
|
"""
    node_split!(graph::DAG, n1::Node)

Split the given node into one copy per parent and return the difference that was applied to the graph.

For details see [`NodeSplit`](@ref).
"""
function node_split!(graph::DAG, n1::Node)
    @assert is_valid_node_split_input(graph, n1)

    # clear snapshot
    get_snapshot_diff(graph)

    n1Parents = parents(n1)
    n1Children = children(n1)

    # detach the node from all of its neighbours and take it out of the graph
    for parent in n1Parents
        remove_edge!(graph, n1, parent)
    end
    for child in n1Children
        remove_edge!(graph, child, n1)
    end
    remove_node!(graph, n1)

    # every former parent receives its own copy, connected to all former children
    for parent in n1Parents
        nCopy = copy(n1)

        insert_node!(graph, nCopy)
        insert_edge!(graph, nCopy, parent)

        for child in n1Children
            insert_edge!(graph, child, nCopy)
        end

        # the parent now has to refer to the copy instead of the original node
        oldName = Symbol(to_var_name(n1.id))
        newName = Symbol(to_var_name(nCopy.id))
        update_child!(graph, parent, oldName, newName)
    end

    return get_snapshot_diff(graph)
end
|
134
src/operation/clean.jl
Normal file
134
src/operation/clean.jl
Normal file
@ -0,0 +1,134 @@
|
||||
# These are functions for "cleaning" nodes, i.e. regenerating the possible operations for a node
|
||||
|
"""
    find_fusions!(graph::DAG, node::DataTaskNode)

Look for a node fusion involving the given data node. A found [`NodeFusion`](@ref) is pushed to the graph's possible operations and to the caches of the three involved nodes; the function returns nothing.

Does nothing if the node already has a node fusion set. Since it's a data node, only one node fusion can be possible with it.
"""
function find_fusions!(graph::DAG, node::DataTaskNode)
    # a fusion is already cached here: skip to avoid duplicates
    if !ismissing(node.nodeFusion)
        return nothing
    end

    # the data node needs exactly one parent and exactly one child
    if !(length(node.parents) == 1 && length(node.children) == 1)
        return nothing
    end

    child_node = node.children[1]
    parent_node = node.parents[1]

    if !((child_node in graph) && (parent_node in graph))
        error("Parents/Children that are not in the graph!!!")
    end

    # the child must have this data node as its only parent
    if length(child_node.parents) != 1
        return nothing
    end

    fusion = NodeFusion((child_node, node, parent_node))
    push!(graph.possibleOperations.nodeFusions, fusion)
    push!(child_node.nodeFusions, fusion)
    node.nodeFusion = fusion
    push!(parent_node.nodeFusions, fusion)

    return nothing
end
|
||||
|
"""
    find_fusions!(graph::DAG, node::ComputeTaskNode)

Look for node fusions involving the given compute node by running `find_fusions!` on each of its children and then on each of its parents. Returns nothing.
"""
function find_fusions!(graph::DAG, node::ComputeTaskNode)
    # the search itself happens in the neighbouring nodes: children first, then parents
    foreach(child -> find_fusions!(graph, child), node.children)
    foreach(parent -> find_fusions!(graph, parent), node.parents)

    return nothing
end
|
||||
|
"""
    find_reductions!(graph::DAG, node::Node)

Look for a node reduction involving the given node. A found [`NodeReduction`](@ref) is pushed to the graph's possible operations and set on every node it involves; the function returns nothing.
"""
function find_reductions!(graph::DAG, node::Node)
    # there can only be one reduction per node, avoid adding duplicates
    if !ismissing(node.nodeReduction)
        return nothing
    end

    # candidates for a reduction are the node's partners, i.e. parents of its children
    candidates = partners(node)
    delete!(candidates, node)

    reductionVector = nothing
    for partner in candidates
        @assert partner in graph.nodes
        if !can_reduce(node, partner)
            continue
        end

        # the vector is only created once the first reduction partner shows up
        if reductionVector === nothing
            reductionVector = Node[node]
        end
        push!(reductionVector, partner)
    end

    if reductionVector === nothing
        return nothing
    end

    nr = NodeReduction(reductionVector)
    push!(graph.possibleOperations.nodeReductions, nr)
    for member in reductionVector
        if !ismissing(member.nodeReduction)
            # it can happen that the dirty node becomes part of an existing NodeReduction and overrides it now;
            # the existing NodeReduction then has to be invalidated, which also covers the possibleOperations
            invalidate_caches!(graph, member.nodeReduction)
        end
        member.nodeReduction = nr
    end

    return nothing
end
|
||||
|
"""
    find_splits!(graph::DAG, node::Node)

Look for the node split of the given node. A found [`NodeSplit`](@ref) is pushed to the graph's possible operations and set on the node; the function returns nothing.
"""
function find_splits!(graph::DAG, node::Node)
    # a split is already cached for this node
    if !ismissing(node.nodeSplit)
        return nothing
    end

    if !can_split(node)
        return nothing
    end

    split = NodeSplit(node)
    push!(graph.possibleOperations.nodeSplits, split)
    node.nodeSplit = split

    return nothing
end
|
||||
|
||||
"""
    clean_node!(graph::DAG, node::Node)

Sort the node via `sort_node!`, then search for fusions, reductions and splits involving it (in that order). Needs to be called after the node was changed in some way. Returns nothing.
"""
function clean_node!(graph::DAG, node::Node)
    sort_node!(node)

    # the finders run in a fixed order: fusions, reductions, splits
    for find! in (find_fusions!, find_reductions!, find_splits!)
        find!(graph, node)
    end

    return nothing
end
|
247
src/operation/find.jl
Normal file
247
src/operation/find.jl
Normal file
@ -0,0 +1,247 @@
|
||||
# functions that find operations on the initial graph
|
||||
|
||||
using Base.Threads
|
||||
|
||||
"""
    insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})

Register the node fusion `nf` in the operation caches of its three input nodes.

The two compute nodes (`nf.input[1]` and `nf.input[3]`) each keep a collection of fusions; pushing into it is guarded by that node's lock from `locks` so several threads can insert concurrently. The data node (`nf.input[2]`) stores the fusion in its `nodeFusion` field without locking. Returns nothing.
"""
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
    first_compute, data_node, second_compute = nf.input

    first_lock = locks[first_compute]
    lock(first_lock)
    try
        push!(first_compute.nodeFusions, nf)
    finally
        unlock(first_lock)
    end

    data_node.nodeFusion = nf

    second_lock = locks[second_compute]
    lock(second_lock)
    try
        push!(second_compute.nodeFusions, nf)
    finally
        unlock(second_lock)
    end

    return nothing
end
|
||||
|
||||
"""
    insert_operation!(nr::NodeReduction)

Store the node reduction `nr` in the `nodeReduction` cache of each of its input nodes. No locking is performed; the original documentation states this is thread-safe. Returns nothing.
"""
function insert_operation!(nr::NodeReduction)
    foreach(nr.input) do member
        member.nodeReduction = nr
    end
    return nothing
end
|
||||
|
||||
"""
    insert_operation!(ns::NodeSplit)

Store the node split `ns` in the `nodeSplit` cache of its input node. No locking is performed; the original documentation states this is thread-safe. Returns nothing.
"""
function insert_operation!(ns::NodeSplit)
    target = ns.input
    target.nodeSplit = ns
    return nothing
end
|
||||
|
||||
"""
    nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})

Insert the given node reductions into `operations.nodeReductions` and into the caches of their input nodes.

The union into `operations` runs in a separate task while the per-node caches are filled in a `@threads` loop over the outer vector. Returns nothing once both have finished.
"""
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
    # reserve space for all operations up front
    expected = mapreduce(length, +, nodeReductions; init = 0)
    sizehint!(operations.nodeReductions, expected)

    # collect everything into the graph-wide container in the background
    collector = @task begin
        for batch in nodeReductions
            union!(operations.nodeReductions, Set(batch))
        end
    end
    schedule(collector)

    # meanwhile, write the operations into the node caches
    @threads for batch in nodeReductions
        for nr in batch
            insert_operation!(nr)
        end
    end

    wait(collector)

    return nothing
end
|
||||
|
||||
"""
    nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})

Insert the given node fusions into `operations.nodeFusions` and into the caches of their input nodes.

The union into `operations` runs in a separate task. The node caches are filled in a `@threads` loop, using one `SpinLock` per compute node of `graph` to guard the compute nodes' fusion collections. Returns nothing once both have finished.
"""
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
    # reserve space for all operations up front
    expected = mapreduce(length, +, nodeFusions; init = 0)
    sizehint!(operations.nodeFusions, expected)

    # collect everything into the graph-wide container in the background
    collector = @task begin
        for batch in nodeFusions
            union!(operations.nodeFusions, Set(batch))
        end
    end
    schedule(collector)

    # every compute node gets its own lock, since several fusions can touch the same compute node
    locks = Dict{ComputeTaskNode, SpinLock}()
    for candidate in graph.nodes
        if candidate isa ComputeTaskNode
            locks[candidate] = SpinLock()
        end
    end

    @threads for batch in nodeFusions
        for nf in batch
            insert_operation!(nf, locks)
        end
    end

    wait(collector)

    return nothing
end
|
||||
|
||||
"""
    ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})

Insert the given node splits into `operations.nodeSplits` and into the caches of their input nodes.

The union into `operations` runs in a separate task while the per-node caches are filled in a `@threads` loop over the outer vector. Returns nothing once both have finished.
"""
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
    # reserve space for all operations up front
    expected = mapreduce(length, +, nodeSplits; init = 0)
    sizehint!(operations.nodeSplits, expected)

    # collect everything into the graph-wide container in the background
    collector = @task begin
        for batch in nodeSplits
            union!(operations.nodeSplits, Set(batch))
        end
    end
    schedule(collector)

    # meanwhile, write the operations into the node caches
    @threads for batch in nodeSplits
        for ns in batch
            insert_operation!(ns)
        end
    end

    wait(collector)

    return nothing
end
|
||||
|
||||
"""
    generate_operations(graph::DAG)

Generate all possible operations on the graph. Used initially when the graph is freshly assembled or parsed. Uses multithreading for speedup.

All pending operations are applied first (`apply_all!`) and every node is sorted. Node reductions, node fusions and node splits are then searched in three parallel passes over the nodes; after each pass, the insertion of the found operations into the graph and its nodes is started as a background task. The graph's dirty nodes are emptied, and the function returns nothing once all insertion tasks have finished.
"""
function generate_operations(graph::DAG)
    # make sure the graph is fully generated through
    apply_all!(graph)

    nodeArray = collect(graph.nodes)

    # sort all nodes
    @threads for node in nodeArray
        sort_node!(node)
    end

    # Split the nodes into at most nthreads() contiguous chunks; every chunk owns exactly one result buffer.
    # The buffers are indexed by chunk and not by threadid(): threadid() is not guaranteed to stay constant
    # within a task and can be larger than nthreads(), which would make per-thread buffers unsafe.
    chunkSize = max(1, cld(length(nodeArray), nthreads()))
    chunks = collect(Iterators.partition(nodeArray, chunkSize))

    generatedFusions = Vector{NodeFusion}[Vector{NodeFusion}() for _ in chunks]
    generatedReductions = Vector{NodeReduction}[Vector{NodeReduction}() for _ in chunks]
    generatedSplits = Vector{NodeSplit}[Vector{NodeSplit}() for _ in chunks]

    checkedNodes = Set{Node}()
    checkedNodesLock = SpinLock()

    # --- find possible node reductions ---
    @threads for i in eachindex(chunks)
        reductions = generatedReductions[i]
        for node in chunks[i]
            # we're looking for nodes with multiple parents, those parents can then potentially reduce with one another
            if length(node.parents) <= 1
                continue
            end

            # sort the parents into equivalence classes
            trie = NodeTrie()
            for candidate in node.parents
                insert!(trie, candidate)
            end

            for nrVec in collect(trie)
                # parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is
                # uniquely identifiable by its first element; this prevents duplicate nodeReductions being generated
                isNew = lock(checkedNodesLock) do
                    nrVec[1] in checkedNodes && return false
                    push!(checkedNodes, nrVec[1])
                    return true
                end
                isNew || continue

                push!(reductions, NodeReduction(nrVec))
            end
        end
    end

    # launch task for node reduction insertion
    nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
    schedule(nr_task)

    # --- find possible node fusions ---
    @threads for i in eachindex(chunks)
        fusions = generatedFusions[i]
        for node in chunks[i]
            node isa DataTaskNode || continue

            # data node can only have a single parent
            length(node.parents) == 1 || continue
            parent_node = first(node.parents)

            # otherwise this node is an entry node or has multiple children which should not be possible
            length(node.children) == 1 || continue
            child_node = first(node.children)

            # the child must not have any parent besides the data node
            length(child_node.parents) == 1 || continue

            push!(fusions, NodeFusion((child_node, node, parent_node)))
        end
    end

    # launch task for node fusion insertion
    nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
    schedule(nf_task)

    # --- find possible node splits ---
    @threads for i in eachindex(chunks)
        splits = generatedSplits[i]
        for node in chunks[i]
            if can_split(node)
                push!(splits, NodeSplit(node))
            end
        end
    end

    # launch task for node split insertion
    ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
    schedule(ns_task)

    empty!(graph.dirtyNodes)

    wait(nr_task)
    wait(nf_task)
    wait(ns_task)

    return nothing
end
|
23
src/operation/get.jl
Normal file
23
src/operation/get.jl
Normal file
@ -0,0 +1,23 @@
|
||||
# function to return the possible operations of a graph
|
||||
|
||||
using Base.Threads
|
||||
|
||||
"""
    get_operations(graph::DAG)

Return the [`PossibleOperations`](@ref) of the graph at the current state.

Pending operations are applied first. If no possible operations are stored yet, they are generated from scratch with `generate_operations`. Afterwards every dirty node is cleaned with `clean_node!` and the set of dirty nodes is emptied.
"""
function get_operations(graph::DAG)
    apply_all!(graph)

    # nothing cached yet: build the operations for the whole graph
    isempty(graph.possibleOperations) && generate_operations(graph)

    # bring the operations of changed nodes up to date
    foreach(graph.dirtyNodes) do dirty
        clean_node!(graph, dirty)
    end
    empty!(graph.dirtyNodes)

    return graph.possibleOperations
end
|
@ -1,3 +1,8 @@
|
||||
"""
|
||||
show(io::IO, ops::PossibleOperations)
|
||||
|
||||
Print a string representation of the set of possible operations to io.
|
||||
"""
|
||||
function show(io::IO, ops::PossibleOperations)
|
||||
print(io, length(ops.nodeFusions))
|
||||
println(io, " Node Fusions: ")
|
||||
@ -16,23 +21,38 @@ function show(io::IO, ops::PossibleOperations)
|
||||
end
|
||||
end
|
||||
|
||||
"""
    show(io::IO, op::NodeReduction)

Print a string representation of the node reduction to `io`, in the form `NR: <number of inputs>x<task of the first input>`.
"""
function show(io::IO, op::NodeReduction)
    # the task is printed exactly once; the diff view showed both the old and the new print line
    return print(io, "NR: ", length(op.input), "x", op.input[1].task)
end
|
||||
|
||||
"""
    show(io::IO, op::NodeSplit)

Print a string representation of the node split to `io`, in the form `NS: <task of the input node>`.
"""
function show(io::IO, op::NodeSplit)
    # the task is printed exactly once; the diff view showed both the old and the new print line
    return print(io, "NS: ", op.input.task)
end
|
||||
|
||||
"""
    show(io::IO, op::NodeFusion)

Print a string representation of the node fusion to `io`, in the form `NF: <task 1>-><task 2>-><task 3>`, where the tasks are those of the three input nodes in order.
"""
function show(io::IO, op::NodeFusion)
    # the last task is printed exactly once; the diff view showed both the old and the new print line
    return print(io, "NF: ", op.input[1].task, "->", op.input[2].task, "->", op.input[3].task)
end
|
117
src/operation/type.jl
Normal file
117
src/operation/type.jl
Normal file
@ -0,0 +1,117 @@
|
||||
"""
    Operation

Abstract supertype of all graph operations. Applying an operation to a [`DAG`](@ref) changes the graph's nodes and edges.

The operations currently possible on a [`DAG`](@ref) are returned by [`get_operations`](@ref).

See also: [`push_operation!`](@ref), [`pop_operation!`](@ref)
"""
abstract type Operation end
|
||||
|
||||
"""
    AppliedOperation

Abstract supertype of operations that have already been applied.

An applied operation can be reversed iff it is the last applied operation on the DAG. To make that possible, every applied operation stores the [`Diff`](@ref) produced when it was initially applied.

See also: [`revert_operation!`](@ref).
"""
abstract type AppliedOperation end
|
||||
|
||||
"""
    NodeFusion <: Operation

The NodeFusion operation: fuses a chain of compute node -> data node -> compute node into a single node.

After the node fusion is applied, the graph has 2 fewer nodes and edges, and contains a new [`FusedComputeTask`](@ref) with the two input compute nodes as parts.

# Fields

- `input`: the chain `(n1, n2, n3)` of compute node, data node and compute node.

# Requirements for successful application

A chain of (n1, n2, n3) can be fused if:
- All nodes are in the graph.
- (n1, n2) is an edge in the graph.
- (n2, n3) is an edge in the graph.
- n2 has exactly one parent (n3) and exactly one child (n1).
- n1 has exactly one parent (n2).

[`is_valid_node_fusion_input`](@ref) can be used to `@assert` these requirements.

See also: [`can_fuse`](@ref)
"""
struct NodeFusion <: Operation
    input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
end
|
||||
|
||||
"""
    AppliedNodeFusion <: AppliedOperation

The applied version of the [`NodeFusion`](@ref). Stores the original `operation` together with the `diff` it caused.
"""
struct AppliedNodeFusion <: AppliedOperation
    operation::NodeFusion
    diff::Diff
end
|
||||
|
||||
"""
    NodeReduction <: Operation

The NodeReduction operation: reduces two or more nodes with one another.
Only one of the input nodes is kept. All others are deleted, and their parents are accumulated in the kept node's parents instead.

After the node reduction is applied, the graph has `length(nr.input) - 1` fewer nodes.

# Fields

- `input`: the nodes that are reduced with one another.

# Requirements for successful application

A vector of nodes can be reduced if:
- All nodes are in the graph.
- All nodes have the same task type.
- All nodes have the same set of children.

[`is_valid_node_reduction_input`](@ref) can be used to `@assert` these requirements.

See also: [`can_reduce`](@ref)
"""
struct NodeReduction <: Operation
    input::Vector{Node}
end
|
||||
|
||||
"""
    AppliedNodeReduction <: AppliedOperation

The applied version of the [`NodeReduction`](@ref). Stores the original `operation` together with the `diff` it caused.
"""
struct AppliedNodeReduction <: AppliedOperation
    operation::NodeReduction
    diff::Diff
end
|
||||
|
||||
"""
    NodeSplit <: Operation

The NodeSplit operation: splits its input node into one node for each of its parents. It is the reverse operation to the [`NodeReduction`](@ref).

# Fields

- `input`: the node to split.

# Requirements for successful application

A node can be split if:
- It is in the graph.
- It has at least 2 parents.

[`is_valid_node_split_input`](@ref) can be used to `@assert` these requirements.

See also: [`can_split`](@ref)
"""
struct NodeSplit <: Operation
    input::Node
end
|
||||
|
||||
"""
    AppliedNodeSplit <: AppliedOperation

The applied version of the [`NodeSplit`](@ref). Stores the original `operation` together with the `diff` it caused.
"""
struct AppliedNodeSplit <: AppliedOperation
    operation::NodeSplit
    diff::Diff
end
|
163
src/operation/utility.jl
Normal file
163
src/operation/utility.jl
Normal file
@ -0,0 +1,163 @@
|
||||
"""
    isempty(operations::PossibleOperations)

Return `true` if `operations` contains no node fusions, no node reductions and no node splits.
"""
function isempty(operations::PossibleOperations)
    isempty(operations.nodeFusions) || return false
    isempty(operations.nodeReductions) || return false
    return isempty(operations.nodeSplits)
end
|
||||
|
||||
"""
    length(operations::PossibleOperations)

Return a named tuple with the number of operations of each type. The fields are named the same as the fields of [`PossibleOperations`](@ref): `nodeFusions`, `nodeReductions` and `nodeSplits`.
"""
function length(operations::PossibleOperations)
    fusionCount = length(operations.nodeFusions)
    reductionCount = length(operations.nodeReductions)
    splitCount = length(operations.nodeSplits)
    return (nodeFusions = fusionCount, nodeReductions = reductionCount, nodeSplits = splitCount)
end
|
||||
|
||||
"""
    delete!(operations::PossibleOperations, op::NodeFusion)

Remove the node fusion `op` from `operations.nodeFusions` and return `operations`.
"""
function delete!(operations::PossibleOperations, op::NodeFusion)
    fusions = operations.nodeFusions
    delete!(fusions, op)
    return operations
end
|
||||
|
||||
"""
    delete!(operations::PossibleOperations, op::NodeReduction)

Remove the node reduction `op` from `operations.nodeReductions` and return `operations`.
"""
function delete!(operations::PossibleOperations, op::NodeReduction)
    reductions = operations.nodeReductions
    delete!(reductions, op)
    return operations
end
|
||||
|
||||
"""
    delete!(operations::PossibleOperations, op::NodeSplit)

Remove the node split `op` from `operations.nodeSplits` and return `operations`.
"""
function delete!(operations::PossibleOperations, op::NodeSplit)
    splits = operations.nodeSplits
    delete!(splits, op)
    return operations
end
|
||||
|
||||
"""
    can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Return whether the chain `n1 -> n2 -> n3` can be fused. See [`NodeFusion`](@ref) for the requirements.
"""
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    # the connectivity checks are redundant but maybe a good sanity check
    connected = is_child(n1, n2) && is_child(n2, n3)
    connected || return false

    # the data node must have exactly one parent and one child, and n1 must have no other parent
    return length(n2.parents) == 1 && length(n2.children) == 1 && length(n1.parents) == 1
end
|
||||
|
||||
"""
    can_reduce(n1::Node, n2::Node)

Return whether the given two nodes can be reduced, i.e. whether they have equal tasks and the same children. See [`NodeReduction`](@ref) for the requirements.
"""
function can_reduce(n1::Node, n2::Node)
    n1.task == n2.task || return false

    childCount = length(n1.children)
    childCount == length(n2.children) || return false

    # two children seems to be the most common case, so it is handled first;
    # comparing manually is a lot faster than building sets for the general solution
    if childCount == 2
        a1, a2 = n1.children[1], n1.children[2]
        b1, b2 = n2.children[1], n2.children[2]
        if a1 == b1
            # same order: the second children have to match as well
            return a2 == b2
        end
        # otherwise the children have to match crosswise
        return a1 == b2 && a2 == b1
    end

    # a single child only needs one comparison
    if childCount == 1
        return n1.children[1] == n2.children[1]
    end

    # general case, this takes a long time
    return Set(n1.children) == Set(n2.children)
end
|
||||
|
||||
"""
    can_split(n::Node)

Return whether the given node can be split, i.e. whether it has more than one parent. See [`NodeSplit`](@ref) for the requirements.
"""
function can_split(n::Node)
    parentCount = length(parents(n))
    return parentCount > 1
end
|
||||
|
||||
"""
    ==(op1::Operation, op2::Operation)

Fallback for comparing two operations: always return `false`. The actual comparisons are implemented by the methods taking two operations of the same concrete type.
"""
function ==(op1::Operation, op2::Operation)
    # reaching this method means no same-type method matched
    return false
end
|
||||
|
||||
"""
    ==(op1::NodeFusion, op2::NodeFusion)

Equality comparison between two node fusions. Only the data nodes (second input) are compared: there can only be one node fusion on a given data task, so the same data task implies the same fusion.
"""
function ==(op1::NodeFusion, op2::NodeFusion)
    data1 = op1.input[2]
    data2 = op2.input[2]
    return data1 == data2
end
|
||||
|
||||
"""
    ==(op1::NodeReduction, op2::NodeReduction)

Equality comparison between two node reductions. Two node reductions are considered equal exactly if the `id`s of their first input nodes are equal.
"""
function ==(op1::NodeReduction, op2::NodeReduction)
    first1 = op1.input[1]
    first2 = op2.input[1]
    return first1.id == first2.id
end
|
||||
|
||||
"""
    ==(op1::NodeSplit, op2::NodeSplit)

Equality comparison between two node splits. Two node splits are considered equal if their input nodes are equal.
"""
function ==(op1::NodeSplit, op2::NodeSplit)
    node1 = op1.input
    node2 = op2.input
    return node1 == node2
end
|
146
src/operation/validate.jl
Normal file
146
src/operation/validate.jl
Normal file
@ -0,0 +1,146 @@
|
||||
# functions to throw assertion errors for inconsistent or wrong node operations
|
||||
# should be called with @assert
|
||||
# the functions throw their own errors though, to still have helpful error messages
|
||||
|
||||
"""
    is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Assert for a given node fusion input whether the nodes can be fused. For the requirements of a node fusion see [`NodeFusion`](@ref).

Returns `true` if all checks pass and throws an `AssertionError` with a descriptive message otherwise.

Intended for use with `@assert` or `@test`.
"""
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    allInGraph = (n1 in graph) && (n2 in graph) && (n3 in graph)
    allInGraph || throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))

    # both edges have to be registered in both directions
    connected = is_child(n1, n2) && is_child(n2, n3) && is_parent(n3, n2) && is_parent(n2, n1)
    if !connected
        throw(
            AssertionError(
                "[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
            ),
        )
    end

    length(n2.parents) > 1 && throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
    length(n2.children) > 1 && throw(AssertionError("[Node Fusion] The given data node has more than one child"))
    length(n1.parents) > 1 && throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))

    @assert is_valid(graph, n1)
    @assert is_valid(graph, n2)
    @assert is_valid(graph, n3)

    return true
end
|
||||
|
||||
"""
    is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})

Assert for a given node reduction input whether the nodes can be reduced. For the requirements of a node reduction see [`NodeReduction`](@ref).

Returns `true` if all checks pass and throws an `AssertionError` with a descriptive message otherwise.

Intended for use with `@assert` or `@test`.
"""
function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
    # every node has to be a valid member of the graph
    for n in nodes
        n in graph || throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
        @assert is_valid(graph, n)
    end

    # all nodes are compared against the first one
    reference = nodes[1]
    expectedTaskType = typeof(reference.task)
    for n in nodes
        if !(typeof(n.task) == expectedTaskType)
            throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
        end

        # data nodes additionally have to agree on their name
        if n isa DataTaskNode && n.name != reference.name
            throw(AssertionError("[Node Reduction] The given nodes do not have the same name"))
        end
    end

    expectedChildren = Set(reference.children)
    for n in nodes
        if !(expectedChildren == Set(n.children))
            throw(
                AssertionError(
                    "[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction",
                ),
            )
        end
    end

    return true
end
|
||||
|
||||
"""
    is_valid_node_split_input(graph::DAG, n1::Node)

Assert for a given node split input whether the node can be split. For the requirements of a node split see [`NodeSplit`](@ref).

Returns `true` if all checks pass and throws an `AssertionError` with a descriptive message otherwise.

Intended for use with `@assert` or `@test`.
"""
function is_valid_node_split_input(graph::DAG, n1::Node)
    n1 in graph || throw(AssertionError("[Node Split] The given node is not part of the given graph"))

    # a split creates one copy per parent, so it needs at least two of them
    hasMultipleParents = length(n1.parents) > 1
    if !hasMultipleParents
        throw(
            AssertionError(
                "[Node Split] The given node does not have multiple parents which is required for node split",
            ),
        )
    end

    @assert is_valid(graph, n1)

    return true
end
|
||||
|
||||
"""
    is_valid(graph::DAG, nr::NodeReduction)

Assert for a given [`NodeReduction`](@ref) whether it is a valid operation in the graph: its input has to be a valid node reduction input and the operation has to be part of the graph's possible operations. Returns `true` on success.

Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, nr::NodeReduction)
    @assert is_valid_node_reduction_input(graph, nr.input)

    known = graph.possibleOperations.nodeReductions
    @assert nr in known "NodeReduction is not part of the graph's possible operations!"

    return true
end
|
||||
|
||||
"""
    is_valid(graph::DAG, ns::NodeSplit)

Assert for a given [`NodeSplit`](@ref) whether it is a valid operation in the graph: its input has to be a valid node split input and the operation has to be part of the graph's possible operations. Returns `true` on success.

Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, ns::NodeSplit)
    @assert is_valid_node_split_input(graph, ns.input)

    known = graph.possibleOperations.nodeSplits
    @assert ns in known "NodeSplit is not part of the graph's possible operations!"

    return true
end
|
||||
|
||||
"""
    is_valid(graph::DAG, nf::NodeFusion)

Assert for a given [`NodeFusion`](@ref) whether it is a valid operation in the graph: its input has to be a valid node fusion input and the operation has to be part of the graph's possible operations. Returns `true` on success.

Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, nf::NodeFusion)
    n1, n2, n3 = nf.input
    @assert is_valid_node_fusion_input(graph, n1, n2, n3)

    known = graph.possibleOperations.nodeFusions
    @assert nf in known "NodeFusion is not part of the graph's possible operations!"

    return true
end
|
@ -1,198 +0,0 @@
|
||||
# functions that apply graph operations
|
||||
|
||||
# Apply all operations that are queued in `graph.operationsToApply`, oldest first.
# Every applied operation is appended to `graph.appliedOperations`. Returns nothing.
function apply_all!(graph::DAG)
    while true
        isempty(graph.operationsToApply) && break

        # take the next operation from the front of the deque and apply it
        next = popfirst!(graph.operationsToApply)
        applied = apply_operation!(graph, next)

        # remember it at the end of the appliedOperations deque
        push!(graph.appliedOperations, applied)
    end
    return nothing
end
|
||||
|
||||
# Fallback for operation types without a dedicated method: always raises an error.
function apply_operation!(graph::DAG, operation::Operation)
    return error("Unknown operation type!")
end
|
||||
|
||||
# Apply a node fusion to the graph and return it as an AppliedNodeFusion together with the resulting diff.
function apply_operation!(graph::DAG, operation::NodeFusion)
    n1, n2, n3 = operation.input
    return AppliedNodeFusion(operation, node_fusion!(graph, n1, n2, n3))
end
|
||||
|
||||
# Apply a node reduction to the graph and return it as an AppliedNodeReduction together with the resulting diff.
function apply_operation!(graph::DAG, operation::NodeReduction)
    return AppliedNodeReduction(operation, node_reduction!(graph, operation.input))
end
|
||||
|
||||
# Apply a node split to the graph and return it as an AppliedNodeSplit together with the resulting diff.
function apply_operation!(graph::DAG, operation::NodeSplit)
    return AppliedNodeSplit(operation, node_split!(graph, operation.input))
end
|
||||
|
||||
|
||||
# Fallback for applied operation types without a dedicated method: always raises an error.
function revert_operation!(graph::DAG, operation::AppliedOperation)
    return error("Unknown operation type!")
end
|
||||
|
||||
# Undo an applied node fusion by reverting its diff; returns the original NodeFusion.
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
    original = operation.operation
    revert_diff!(graph, operation.diff)
    return original
end
|
||||
|
||||
# Undo an applied node reduction by reverting its diff; returns the original NodeReduction.
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
    original = operation.operation
    revert_diff!(graph, operation.diff)
    return original
end
|
||||
|
||||
# Undo an applied node split by reverting its diff; returns the original NodeSplit.
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
    original = operation.operation
    revert_diff!(graph, operation.diff)
    return original
end
|
||||
|
||||
|
||||
# Revert the given diff on the graph: added edges and nodes are removed, removed nodes and edges are re-inserted.
# The order matters: edges are removed before their nodes, and nodes are inserted before their edges.
# All graph modifications are called with `false` as their last argument.
function revert_diff!(graph::DAG, diff::Diff)
    for added in diff.addedEdges
        from, to = added.edge[1], added.edge[2]
        remove_edge!(graph, from, to, false)
    end
    for added in diff.addedNodes
        remove_node!(graph, added, false)
    end

    for removed in diff.removedNodes
        insert_node!(graph, removed, false)
    end
    for removed in diff.removedEdges
        from, to = removed.edge[1], removed.edge[2]
        insert_edge!(graph, from, to, false)
    end
    return nothing
end
|
||||
|
||||
# Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph.
# The new node carries a FusedComputeTask built from the task types of n1 and n3. It takes over
# the children of n1 and n3 (without duplicates) and the parents of n3.
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    # @assert is_valid_node_fusion_input(graph, n1, n2, n3)

    # clear snapshot
    get_snapshot_diff(graph)

    # save the children of n1 and the parents of n3; n3's children are fetched further below
    n1_children = children(n1)
    n3_parents = parents(n3)

    # remove the edges and nodes that will be replaced by the fused node
    remove_edge!(graph, n1, n2)
    remove_edge!(graph, n2, n3)
    remove_node!(graph, n1)
    remove_node!(graph, n2)

    # get n3's children now so it automatically excludes n2
    n3_children = children(n3)
    remove_node!(graph, n3)

    # create new node with the fused compute task
    new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
    insert_node!(graph, new_node)

    # use a set for combined children of n1 and n3 to not get duplicates
    n1and3_children = Set{Node}()

    # remove edges from n1 children to n1
    for child in n1_children
        remove_edge!(graph, child, n1)
        push!(n1and3_children, child)
    end

    # remove edges from n3 children to n3
    for child in n3_children
        remove_edge!(graph, child, n3)
        push!(n1and3_children, child)
    end

    for child in n1and3_children
        insert_edge!(graph, child, new_node)
    end

    # "repoint" parents of n3 to the new node
    for parent in n3_parents
        remove_edge!(graph, n3, parent)
        insert_edge!(graph, new_node, parent)
    end

    return get_snapshot_diff(graph)
end
|
||||
|
||||
# Reduce the given nodes into the first one and return the applied difference to the graph.
# All nodes but the first are removed; their parents that the first node does not have yet are attached to it.
function node_reduction!(graph::DAG, nodes::Vector{Node})
    # @assert is_valid_node_reduction_input(graph, nodes)

    # clear snapshot
    get_snapshot_diff(graph)

    kept = nodes[1]
    shared_children = children(kept)
    kept_parents = Set(kept.parents)

    collected_parents = Set{Node}()

    # detach and remove every node except the kept one
    for obsolete in Iterators.drop(nodes, 1)
        for child in shared_children
            remove_edge!(graph, child, obsolete)
        end

        for parent in parents(obsolete)
            remove_edge!(graph, obsolete, parent)
            # collect all parents
            push!(collected_parents, parent)
        end

        remove_node!(graph, obsolete)
    end

    # parents the kept node already has must not be added a second time
    setdiff!(collected_parents, kept_parents)

    for parent in collected_parents
        insert_edge!(graph, kept, parent)
    end

    return get_snapshot_diff(graph)
end
|
||||
|
||||
# Split the given node into one copy per parent and return the applied difference to the graph.
# Each copy is connected to exactly one of the original parents and to all of the original children.
function node_split!(graph::DAG, n1::Node)
    # @assert is_valid_node_split_input(graph, n1)

    # clear snapshot
    get_snapshot_diff(graph)

    original_parents = parents(n1)
    original_children = children(n1)

    # cut the node out of the graph completely
    for parent in original_parents
        remove_edge!(graph, n1, parent)
    end
    for child in original_children
        remove_edge!(graph, child, n1)
    end
    remove_node!(graph, n1)

    # insert one copy per former parent
    for parent in original_parents
        replica = copy(n1)
        insert_node!(graph, replica)
        insert_edge!(graph, replica, parent)

        for child in original_children
            insert_edge!(graph, child, replica)
        end
    end

    return get_snapshot_diff(graph)
end
|
@ -1,115 +0,0 @@
|
||||
# functions for "cleaning" nodes, i.e. regenerating the possible operations for a node
|
||||
|
||||
# Find the node fusion around the given data node, if there is one.
# A found fusion is pushed to the graph's possible operations and to the caches of all three
# involved nodes. Returns nothing.
function find_fusions!(graph::DAG, node::DataTaskNode)
    # if there is already a fusion here, skip
    ismissing(node.nodeFusion) || return nothing

    # a fusable data node has exactly one parent and exactly one child
    if !(length(node.parents) == 1 && length(node.children) == 1)
        return nothing
    end

    child_node = first(node.children)
    parent_node = first(node.parents)

    if !(child_node in graph) || !(parent_node in graph)
        error("Parents/Children that are not in the graph!!!")
    end

    # the child must not have any parent besides the data node
    length(child_node.parents) == 1 || return nothing

    fusion = NodeFusion((child_node, node, parent_node))
    push!(graph.possibleOperations.nodeFusions, fusion)
    push!(child_node.nodeFusions, fusion)
    node.nodeFusion = fusion
    push!(parent_node.nodeFusions, fusion)

    return nothing
end
|
||||
|
||||
|
||||
# Find fusions around a compute node by delegating to its neighbours:
# first every child, then every parent. Always returns nothing.
function find_fusions!(graph::DAG, node::ComputeTaskNode)
    # just find fusions in neighbouring DataTaskNodes
    foreach(neighbour -> find_fusions!(graph, neighbour), node.children)
    foreach(neighbour -> find_fusions!(graph, neighbour), node.parents)

    return nothing
end
|
||||
|
||||
# Find a node reduction containing the given node and register it.
# The node is grouped with every partner for which `can_reduce` holds; the resulting
# NodeReduction is pushed to the graph's possible operations and stored on every member.
# Always returns nothing.
function find_reductions!(graph::DAG, node::Node)
    # there can only be one reduction per node, avoid adding duplicates
    ismissing(node.nodeReduction) || return nothing

    # possible reductions are with nodes that are partners, i.e. parents of children
    candidates = partners(node)
    delete!(candidates, node)

    # stays `nothing` until the first reduction partner is found
    group = nothing
    for candidate in candidates
        candidate in graph.nodes || error("Partner is not part of the graph")

        can_reduce(node, candidate) || continue

        Set(node.children) == Set(candidate.children) || error("Not equal children")

        if group === nothing
            # only when there's at least one reduction partner, insert the vector
            group = Vector{Node}()
            push!(group, node)
        end
        push!(group, candidate)
    end

    group === nothing && return nothing

    reduction = NodeReduction(group)
    push!(graph.possibleOperations.nodeReductions, reduction)
    for member in group
        if !ismissing(member.nodeReduction)
            # it can happen that the dirty node becomes part of an existing NodeReduction and
            # overrides it now; the existing NodeReduction then has to be invalidated so it is
            # also replaced in the possibleOperations
            invalidate_caches!(graph, member.nodeReduction)
        end
        member.nodeReduction = reduction
    end

    return nothing
end
|
||||
|
||||
# Register a NodeSplit for the given node if it has none yet and `can_split` allows it.
# Always returns nothing.
function find_splits!(graph::DAG, node::Node)
    ismissing(node.nodeSplit) || return nothing
    can_split(node) || return nothing

    op = NodeSplit(node)
    push!(graph.possibleOperations.nodeSplits, op)
    node.nodeSplit = op

    return nothing
end
|
||||
|
||||
# "clean" the operations on a dirty node
|
||||
# Regenerate the possible operations of a dirty node: sort it, then search for fusions,
# reductions and splits (in this order). Returns the result of the last search.
function clean_node!(graph::DAG, node::Node)
    sort_node!(node)

    find_fusions!(graph, node)
    find_reductions!(graph, node)
    return find_splits!(graph, node)
end
|
@ -1,205 +0,0 @@
|
||||
# functions that find operations on the inital graph
|
||||
|
||||
using Base.Threads
|
||||
|
||||
# Store the node fusion on its three input nodes.
# The fusion sets of the two outer nodes are modified while holding that node's lock from
# `locks`; the middle node's single fusion slot is assigned without a lock.
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
    lower = nf.input[1]
    middle = nf.input[2]
    upper = nf.input[3]

    lock(locks[lower]) do
        push!(lower.nodeFusions, nf)
    end
    middle.nodeFusion = nf
    lock(locks[upper]) do
        push!(upper.nodeFusions, nf)
    end
    return nothing
end
|
||||
|
||||
# Store the node reduction on every one of its input nodes.
function insert_operation!(nr::NodeReduction)
    foreach(member -> (member.nodeReduction = nr), nr.input)
    return nothing
end

# Store the node split on its single input node.
function insert_operation!(ns::NodeSplit)
    target = ns.input
    target.nodeSplit = ns
    return nothing
end
|
||||
|
||||
# Insert all generated node reductions (one vector per generating thread) into
# `operations` and store each of them on its input nodes.
function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
    sizehint!(operations.nodeReductions, mapreduce(length, +, nodeReductions; init = 0))

    # fill the shared set in a separate task while the nodes are updated below
    fillTask = @task for chunk in nodeReductions
        union!(operations.nodeReductions, Set(chunk))
    end
    schedule(fillTask)

    @threads for chunk in nodeReductions
        foreach(insert_operation!, chunk)
    end

    wait(fillTask)

    return nothing
end
|
||||
|
||||
# Insert all generated node fusions (one vector per generating thread) into `operations`
# and store each of them on its input nodes, guarding the compute nodes with one lock each.
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
    sizehint!(operations.nodeFusions, mapreduce(length, +, nodeFusions; init = 0))

    # fill the shared set in a separate task while the nodes are updated below
    fillTask = @task for chunk in nodeFusions
        union!(operations.nodeFusions, Set(chunk))
    end
    schedule(fillTask)

    # one lock per compute node, since several fusions can touch the same compute node
    locks = Dict{ComputeTaskNode, SpinLock}()
    for n in graph.nodes
        n isa ComputeTaskNode && (locks[n] = SpinLock())
    end

    @threads for chunk in nodeFusions
        for op in chunk
            insert_operation!(op, locks)
        end
    end

    wait(fillTask)

    return nothing
end
|
||||
|
||||
# Insert all generated node splits (one vector per generating thread) into `operations`
# and store each of them on its input node.
function ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}})
    sizehint!(operations.nodeSplits, mapreduce(length, +, nodeSplits; init = 0))

    # fill the shared set in a separate task while the nodes are updated below
    fillTask = @task for chunk in nodeSplits
        union!(operations.nodeSplits, Set(chunk))
    end
    schedule(fillTask)

    @threads for chunk in nodeSplits
        foreach(insert_operation!, chunk)
    end

    wait(fillTask)

    return nothing
end
|
||||
|
||||
# Generate all possible operations (reductions, fusions, splits) on the graph from scratch.
#
# The three searches run one after another as threaded loops over all nodes; each one
# collects its results into per-thread buffers. After each search, a task is scheduled
# that inserts the found operations into `graph.possibleOperations` and onto the nodes,
# so the insertion of one kind can overlap with the search for the next kind.
# All three insertion tasks are awaited before returning. Returns nothing.
#
# NOTE(review): the per-thread buffers are indexed with `threadid()` inside `@threads`
# loops. Depending on the Julia version and scheduling, a task may not stay on one thread
# for a whole iteration and `threadid()` may exceed `nthreads()` -- confirm this is safe
# for the supported Julia versions.
function generate_options(graph::DAG)
    # one result buffer per thread for each kind of operation
    generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
    generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
    generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]

    # make sure the graph is fully generated through
    apply_all!(graph)

    # snapshot of the node set, so the threaded loops iterate over an indexable collection
    nodeArray = collect(graph.nodes)

    # sort all nodes
    @threads for node in nodeArray
        sort_node!(node)
    end

    # first elements of the reductions generated so far, shared between threads
    checkedNodes = Set{Node}()
    checkedNodesLock = SpinLock()
    # --- find possible node reductions ---
    @threads for node in nodeArray
        # we're looking for nodes with multiple parents, those parents can then potentially reduce with one another
        if (length(node.parents) <= 1)
            continue
        end

        candidates = node.parents

        # sort into equivalence classes
        trie = NodeTrie()

        for candidate in candidates
            # insert into trie
            insert!(trie, candidate)
        end

        nodeReductions = collect(trie)

        for nrVec in nodeReductions
            # parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is uniquely identifiable by its first element
            # this prevents duplicate nodeReductions being generated
            lock(checkedNodesLock)
            if (nrVec[1] in checkedNodes)
                # already generated from another child node: release the lock and skip
                unlock(checkedNodesLock)
                continue
            else
                push!(checkedNodes, nrVec[1])
            end
            unlock(checkedNodesLock)

            push!(generatedReductions[threadid()], NodeReduction(nrVec))
        end
    end


    # launch thread for node reduction insertion
    # remove duplicates
    nr_task = @task nr_insertion!(graph.possibleOperations, generatedReductions)
    schedule(nr_task)

    # --- find possible node fusions ---
    @threads for node in nodeArray
        if (typeof(node) <: DataTaskNode)
            if length(node.parents) != 1
                # data node can only have a single parent
                continue
            end
            parent_node = first(node.parents)

            if length(node.children) != 1
                # this node is an entry node or has multiple children which should not be possible
                continue
            end
            child_node = first(node.children)
            if (length(child_node.parents) != 1)
                # the child feeds into more than this data node, so it cannot be fused
                continue
            end

            push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
        end
    end

    # launch thread for node fusion insertion
    nf_task = @task nf_insertion!(graph, graph.possibleOperations, generatedFusions)
    schedule(nf_task)

    # find possible node splits
    @threads for node in nodeArray
        if (can_split(node))
            push!(generatedSplits[threadid()], NodeSplit(node))
        end
    end

    # launch thread for node split insertion
    ns_task = @task ns_insertion!(graph.possibleOperations, generatedSplits)
    schedule(ns_task)

    # everything was regenerated from scratch, so no node is dirty anymore
    empty!(graph.dirtyNodes)

    wait(nr_task)
    wait(nf_task)
    wait(ns_task)

    return nothing
end
|
@ -1,18 +0,0 @@
|
||||
# function to return the possible operations of a graph
|
||||
|
||||
using Base.Threads
|
||||
|
||||
# Return the possible operations of the graph, bringing them up to date first:
# pending graph changes are applied, the operations are generated from scratch if there
# are none yet, and every dirty node is cleaned.
function get_operations(graph::DAG)
    apply_all!(graph)

    isempty(graph.possibleOperations) && generate_options(graph)

    foreach(dirty -> clean_node!(graph, dirty), graph.dirtyNodes)
    empty!(graph.dirtyNodes)

    return graph.possibleOperations
end
|
@ -1,107 +0,0 @@
|
||||
|
||||
# True exactly if there are no node fusions, no node reductions and no node splits.
function isempty(operations::PossibleOperations)
    if !isempty(operations.nodeFusions)
        return false
    end
    if !isempty(operations.nodeReductions)
        return false
    end
    return isempty(operations.nodeSplits)
end
|
||||
|
||||
# Return the number of operations of each kind as a named tuple
# `(nodeFusions, nodeReductions, nodeSplits)`.
function length(operations::PossibleOperations)
    fusionCount = length(operations.nodeFusions)
    reductionCount = length(operations.nodeReductions)
    splitCount = length(operations.nodeSplits)
    return (nodeFusions = fusionCount, nodeReductions = reductionCount, nodeSplits = splitCount)
end
|
||||
|
||||
# Remove the given node fusion from the set of possible node fusions.
# Returns `operations` to allow chaining, like Base's `delete!`.
function delete!(operations::PossibleOperations, op::NodeFusion)
    delete!(operations.nodeFusions, op)
    return operations
end

# Remove the given node reduction from the set of possible node reductions.
# Returns `operations`.
function delete!(operations::PossibleOperations, op::NodeReduction)
    delete!(operations.nodeReductions, op)
    return operations
end

# Remove the given node split from the set of possible node splits.
# Returns `operations`.
function delete!(operations::PossibleOperations, op::NodeSplit)
    delete!(operations.nodeSplits, op)
    return operations
end
|
||||
|
||||
|
||||
# Whether the chain n1 -> n2 -> n3 can be fused: n1 must be a child of n2 and n2 a child
# of n3, n2 must have exactly one parent and one child, and n1 exactly one parent.
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    # the checks are redundant but maybe a good sanity check
    (is_child(n1, n2) && is_child(n2, n3)) || return false

    return length(n2.parents) == 1 && length(n2.children) == 1 && length(n1.parents) == 1
end
|
||||
|
||||
# Whether two nodes can be reduced into one: they must have equal tasks and the same
# children (in any order).
function can_reduce(n1::Node, n2::Node)
    if (n1.task != n2.task)
        return false
    end

    nChildren = length(n1.children)
    if (nChildren != length(n2.children))
        return false
    end

    # this seems to be the most common case so do this first
    # doing it manually is a lot faster than using the sets for a general solution
    if (nChildren == 2)
        if (n1.children[1] != n2.children[1])
            # the first children differ, so the children have to match crosswise
            return !(n1.children[1] != n2.children[2] || n1.children[2] != n2.children[1])
        end

        # the first children are equal, so the second ones have to be equal too
        return !(n1.children[2] != n2.children[2])
    end

    # this is simple
    if (nChildren == 1)
        return n1.children[1] == n2.children[1]
    end

    # general case, this takes a long time
    return Set(n1.children) == Set(n2.children)
end
|
||||
|
||||
# Whether the node can be split, i.e. whether it has more than one parent.
function can_split(n::Node)
    return length(parents(n)) > 1
end
|
||||
|
||||
# Fallback comparison: operations of different kinds are never equal.
# NOTE(review): operations are stored in Sets, which requires `hash` to be consistent
# with these `==` methods -- confirm matching `hash` methods are defined elsewhere.
function ==(op1::Operation, op2::Operation)
    return false
end

function ==(op1::NodeFusion, op2::NodeFusion)
    # there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
    return op1.input[2] == op2.input[2]
end

function ==(op1::NodeReduction, op2::NodeReduction)
    # node reductions are equal exactly if their first input is the same
    return op1.input[1].id == op2.input[1].id
end

# Node splits are equal exactly if they split the same node.
function ==(op1::NodeSplit, op2::NodeSplit)
    return op1.input == op2.input
end
|
||||
|
||||
# Copy a UUID by constructing a new one from the same underlying value.
# NOTE(review): neither `copy` nor `UUID` appears to be owned by this package (type piracy) -- confirm this is intended.
copy(id::UUID) = UUID(id.value)
|
@ -1,61 +0,0 @@
|
||||
# functions to throw assertion errors for inconsistent or wrong node operations
|
||||
# should be called with @assert
|
||||
# the functions throw their own errors though, to still have helpful error messages
|
||||
|
||||
# Check that (n1, n2, n3) is a valid input for a node fusion on the given graph.
# Returns true if it is, otherwise throws an AssertionError describing the first violation.
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    allInGraph = (n1 in graph) && (n2 in graph) && (n3 in graph)
    allInGraph || throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))

    # edges have to be registered in both directions
    connected = is_child(n1, n2) && is_child(n2, n3) && is_parent(n3, n2) && is_parent(n2, n1)
    connected ||
        throw(AssertionError("[Node Fusion] The given nodes are not connected by edges which is required for node fusion"))

    length(n2.parents) > 1 && throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
    length(n2.children) > 1 && throw(AssertionError("[Node Fusion] The given data node has more than one child"))
    length(n1.parents) > 1 && throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))

    return true
end
|
||||
|
||||
# Check that `nodes` is a valid input for a node reduction on the given graph:
# all nodes must be part of the graph, have tasks of the same type and have the same
# set of children.
# Returns true if the input is valid, otherwise throws an AssertionError describing the
# first violation found. An empty `nodes` vector is reported as an AssertionError as well.
function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
    # without this guard the `nodes[1]` accesses below would fail with an unhelpful BoundsError
    if isempty(nodes)
        throw(AssertionError("[Node Reduction] No nodes were given"))
    end

    for n in nodes
        if n ∉ graph
            throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
        end
    end

    t = typeof(nodes[1].task)
    for n in nodes
        if typeof(n.task) != t
            throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
        end
    end

    # build the reference set once instead of once per compared node
    referenceChildren = Set(nodes[1].children)
    for n in nodes
        if referenceChildren != Set(n.children)
            throw(AssertionError("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction"))
        end
    end

    return true
end
|
||||
|
||||
# Check that `n1` is a valid input for a node split on the given graph.
# Returns true if it is, otherwise throws an AssertionError describing the first violation.
function is_valid_node_split_input(graph::DAG, n1::Node)
    n1 ∉ graph && throw(AssertionError("[Node Split] The given node is not part of the given graph"))

    length(n1.parents) <= 1 &&
        throw(AssertionError("[Node Split] The given node does not have multiple parents which is required for node split"))

    return true
end
|
73
src/properties/create.jl
Normal file
73
src/properties/create.jl
Normal file
@ -0,0 +1,73 @@
|
||||
"""
    GraphProperties()

Create an empty [`GraphProperties`](@ref) object, i.e. one with all fields set to zero.
"""
function GraphProperties()
    emptyProperties =
        (data = 0.0, computeEffort = 0.0, computeIntensity = 0.0, cost = 0.0, noNodes = 0, noEdges = 0)
    return emptyProperties::GraphProperties
end
|
||||
|
||||
"""
    GraphProperties(graph::DAG)

Calculate the properties of `graph` and return them as a new [`GraphProperties`](@ref) object.

A node's data counts once per parent. The compute intensity is `0.0` if the total data is zero.
"""
function GraphProperties(graph::DAG)
    # make sure the graph is fully generated
    apply_all!(graph)

    totalData = 0.0
    totalEffort = 0.0
    edgeCount = 0
    for node in graph.nodes
        # every parent of a node corresponds to one outgoing edge
        fanOut = length(node.parents)
        totalData += data(node.task) * fanOut
        totalEffort += compute_effort(node.task)
        edgeCount += fanOut
    end

    intensity = (totalData == 0) ? 0.0 : totalEffort / totalData

    return (
        data = totalData,
        computeEffort = totalEffort,
        computeIntensity = intensity,
        cost = 0.0, # TODO
        noNodes = length(graph.nodes),
        noEdges = edgeCount,
    )::GraphProperties
end
|
||||
|
||||
"""
    GraphProperties(diff::Diff)

Create the difference in graph properties described by the given [`Diff`](@ref).

Applying the diff gives `get_properties(graph) + GraphProperties(diff)`,
reverting it gives `get_properties(graph) - GraphProperties(diff)`.
"""
function GraphProperties(diff::Diff)
    # summed compute effort of a collection of nodes / summed data of a collection of edges
    effort_of(nodes) = reduce(+, compute_effort(n.task) for n in nodes; init = 0.0)
    data_of(edges) = reduce(+, data(e) for e in edges; init = 0.0)

    effortDelta = effort_of(diff.addedNodes) - effort_of(diff.removedNodes)
    dataDelta = data_of(diff.addedEdges) - data_of(diff.removedEdges)
    costDelta = 0.0 # TODO

    return (
        data = dataDelta,
        computeEffort = effortDelta,
        computeIntensity = (dataDelta == 0) ? 0.0 : effortDelta / dataDelta,
        cost = costDelta,
        noNodes = length(diff.addedNodes) - length(diff.removedNodes),
        noEdges = length(diff.addedEdges) - length(diff.removedEdges),
    )::GraphProperties
end
|
17
src/properties/type.jl
Normal file
17
src/properties/type.jl
Normal file
@ -0,0 +1,17 @@
|
||||
"""
    GraphProperties

Representation of a [`DAG`](@ref)'s properties.

# Fields:
`.data`: The total data transfer.\\
`.computeEffort`: The total compute effort.\\
`.computeIntensity`: The compute intensity, equal to `.computeEffort / .data`, or `0.0` if `.data` is zero.\\
`.cost`: The estimated cost.\\
`.noNodes`: Number of [`Node`](@ref)s.\\
`.noEdges`: Number of [`Edge`](@ref)s.
"""
const GraphProperties = NamedTuple{
    (:data, :computeEffort, :computeIntensity, :cost, :noNodes, :noEdges),
    Tuple{Float64, Float64, Float64, Float64, Int, Int},
}
|
57
src/properties/utility.jl
Normal file
57
src/properties/utility.jl
Normal file
@ -0,0 +1,57 @@
|
||||
"""
    -(prop1::GraphProperties, prop2::GraphProperties)

Subtract `prop2` from `prop1` and return the result as a new [`GraphProperties`](@ref).
The compute intensity is recomputed from the resulting compute effort and data, and is `0.0`
if the resulting data is zero.
"""
function -(prop1::GraphProperties, prop2::GraphProperties)
    dataDiff = prop1.data - prop2.data
    effortDiff = prop1.computeEffort - prop2.computeEffort
    intensity = (dataDiff == 0) ? 0.0 : effortDiff / dataDiff

    return (
        data = dataDiff,
        computeEffort = effortDiff,
        computeIntensity = intensity,
        cost = prop1.cost - prop2.cost,
        noNodes = prop1.noNodes - prop2.noNodes,
        noEdges = prop1.noEdges - prop2.noEdges,
    )::GraphProperties
end
|
||||
|
||||
"""
    +(prop1::GraphProperties, prop2::GraphProperties)

Add `prop1` and `prop2` and return the result as a new [`GraphProperties`](@ref).
The compute intensity is recomputed from the resulting compute effort and data, and is `0.0`
if the resulting data is zero.
"""
function +(prop1::GraphProperties, prop2::GraphProperties)
    dataSum = prop1.data + prop2.data
    effortSum = prop1.computeEffort + prop2.computeEffort
    intensity = (dataSum == 0) ? 0.0 : effortSum / dataSum

    return (
        data = dataSum,
        computeEffort = effortSum,
        computeIntensity = intensity,
        cost = prop1.cost + prop2.cost,
        noNodes = prop1.noNodes + prop2.noNodes,
        noEdges = prop1.noEdges + prop2.noEdges,
    )::GraphProperties
end
|
||||
|
||||
"""
    -(prop::GraphProperties)

Unary negation of the graph properties. `.computeIntensity` is kept as it is, because it is the
ratio of `.computeEffort` and `.data`, which are both negated.
"""
function -(prop::GraphProperties)
    # merge keeps the field order of `prop` and only replaces the listed fields,
    # so computeIntensity is carried over unchanged
    negated = merge(
        prop,
        (
            data = -prop.data,
            computeEffort = -prop.computeEffort,
            cost = -prop.cost,
            noNodes = -prop.noNodes,
            noEdges = -prop.noEdges,
        ),
    )
    return negated::GraphProperties
end
|
50
src/scheduler/greedy.jl
Normal file
50
src/scheduler/greedy.jl
Normal file
@ -0,0 +1,50 @@
|
||||
|
||||
"""
    GreedyScheduler

A greedy implementation of a scheduler, creating a topological ordering of nodes and naively balancing them onto the different devices.

Used as the dispatch type for `schedule_dag`.
"""
# NOTE(review): this type is not declared as a subtype of the abstract `Scheduler` type,
# although it implements its `schedule_dag` interface -- confirm whether that is intended.
struct GreedyScheduler end
|
||||
|
||||
# Schedule the graph greedily onto the devices of the machine.
#
# Nodes are emitted in a topological order, starting from the entry nodes. Every compute
# node is assigned to the device with the lowest accumulated compute effort so far.
# Data nodes are not assigned to a device.
#
# Returns a Vector{Node} containing the scheduled nodes in execution order.
function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
    nodeQueue = PriorityQueue{Node, Int}()

    # use a priority equal to the number of unseen children -> 0 are nodes that can be added
    for node in get_entry_nodes(graph)
        enqueue!(nodeQueue, node => 0)
    end

    schedule = Vector{Node}()
    sizehint!(schedule, length(graph.nodes))

    # keep an accumulated cost of things scheduled to this device so far
    # (Float64, so that a fractional compute effort cannot fail to convert)
    deviceAccCost = PriorityQueue{AbstractDevice, Float64}()
    for device in machine.devices
        enqueue!(deviceAccCost, device => 0.0)
    end

    while !isempty(nodeQueue)
        # the lowest priority has to be 0, otherwise no node is ready to be scheduled
        @assert peek(nodeQueue)[2] == 0
        node = dequeue!(nodeQueue)

        # assign the device with lowest accumulated cost to the node (if it's a compute node)
        if (isa(node, ComputeTaskNode))
            lowestDevice = peek(deviceAccCost)[1]
            node.device = lowestDevice
            # add to the device's cost instead of overwriting it, otherwise the load of
            # previously scheduled nodes is forgotten and the balancing does not work
            deviceAccCost[lowestDevice] += compute_effort(node.task)
        end

        push!(schedule, node)
        for parent in node.parents
            # reduce the priority of all parents by one
            if (!haskey(nodeQueue, parent))
                # first time this parent is seen: one of its children was just scheduled
                enqueue!(nodeQueue, parent => length(parent.children) - 1)
            else
                nodeQueue[parent] = nodeQueue[parent] - 1
            end
        end
    end

    return schedule
end
|
18
src/scheduler/interface.jl
Normal file
18
src/scheduler/interface.jl
Normal file
@ -0,0 +1,18 @@
|
||||
|
||||
"""
    Scheduler

Abstract base type for scheduler implementations. A scheduler assigns each compute node to a device and creates a topological ordering of the tasks.

See also: [`schedule_dag`](@ref)
"""
abstract type Scheduler end
|
||||
|
||||
"""
    schedule_dag(::Scheduler, ::DAG, ::Machine)

Interface function that must be implemented for every implementation of [`Scheduler`](@ref).

The function assigns each [`ComputeTaskNode`](@ref) of the [`DAG`](@ref) to one of the devices in the given [`Machine`](@ref) and returns a `Vector{Node}` representing a topological ordering.

[`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`AbstractDevice`](@ref) of its child to all [`AbstractDevice`](@ref)s of its parents.
"""
function schedule_dag end
|
26
src/task/compare.jl
Normal file
26
src/task/compare.jl
Normal file
@ -0,0 +1,26 @@
|
||||
"""
    ==(t1::AbstractTask, t2::AbstractTask)

Fallback implementation of equality comparison between two abstract tasks. Always returns false. For two compute tasks or two data tasks, a more specific method is called instead, doing an actual comparison.
"""
function ==(t1::AbstractTask, t2::AbstractTask)
    return false
end

"""
    ==(t1::AbstractComputeTask, t2::AbstractComputeTask)

Equality comparison between two compute tasks. They are equal exactly if they have the same type.
"""
function ==(t1::AbstractComputeTask, t2::AbstractComputeTask)
    return typeof(t1) == typeof(t2)
end

"""
    ==(t1::AbstractDataTask, t2::AbstractDataTask)

Equality comparison between two data tasks. They are equal exactly if their `data` values are equal.
"""
function ==(t1::AbstractDataTask, t2::AbstractDataTask)
    return data(t1) == data(t2)
end
|
89
src/task/compute.jl
Normal file
89
src/task/compute.jl
Normal file
@ -0,0 +1,89 @@
|
||||
|
||||
"""
    compute(t::FusedComputeTask, data)

Compute a [`FusedComputeTask`](@ref). This always throws an `AssertionError` and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
"""
function compute(t::FusedComputeTask, data)
    # thrown explicitly instead of using `@assert false`, because assertions may be disabled,
    # in which case this function would silently return
    throw(AssertionError("This is not implemented and should never be called"))
end
|
||||
|
||||
"""
    get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)

Generate code evaluating a [`FusedComputeTask`](@ref), providing the output on `outExpr`.

The inputs of both fused tasks are taken from the symbols stored in the task (`t.t1_inputs`, `t.t2_inputs`),
the given `inExprs` are not read. The output of the first task is passed to the second task as its last input.
Returns a block expression containing the expressions of the first and the second task.
"""
function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
    firstInputs = Any[gen_access_expr(device, sym) for sym in t.t1_inputs]

    # the intermediate result connecting the two fused tasks
    intermediate = gen_access_expr(device, t.t1_output)

    secondInputs = Any[gen_access_expr(device, sym) for sym in t.t2_inputs]

    firstExpr = get_expression(t.first_task, device, firstInputs, intermediate)
    secondExpr = get_expression(t.second_task, device, [secondInputs..., intermediate], outExpr)

    return Expr(:block, firstExpr, secondExpr)
end
|
||||
|
||||
"""
    get_expression(node::ComputeTaskNode)

Generate and return code for a given [`ComputeTaskNode`](@ref).

The node must have been scheduled to a device and must not have more children than its task allows.
"""
function get_expression(node::ComputeTaskNode)
    @assert length(node.children) <= children(node.task) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
    @assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"

    # one access expression per child, named after the child's id
    inputs = Any[
        gen_access_expr(node.device, Symbol(to_var_name(childId))) for childId in getfield.(node.children, :id)
    ]
    output = gen_access_expr(node.device, Symbol(to_var_name(node.id)))

    return get_expression(node.task, node.device, inputs, output)
end
|
||||
|
||||
"""
    get_expression(node::DataTaskNode)

Generate and return code for a given [`DataTaskNode`](@ref).

The node must have exactly one child. The returned expression is an assignment from the child's variable to this node's variable.
"""
function get_expression(node::DataTaskNode)
    @assert length(node.children) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"

    # TODO: dispatch to device implementations generating the copy commands

    child = node.children[1]
    # both access expressions are generated for the child's device
    # NOTE(review): the results of gen_access_expr are eval'ed, interpolated into a string and
    # parsed again; presumably gen_access_expr returns a quoted expression -- confirm
    inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
    outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
    dataTransportExp = Meta.parse("$outExpr = $inExpr")

    return dataTransportExp
end
|
||||
|
||||
"""
    get_init_expression(node::DataTaskNode, device::AbstractDevice)

Generate and return code for the initial input reading expression for [`DataTaskNode`](@ref)s with 0 children, i.e., entry nodes.

The returned expression is an assignment to the node's variable from a variable with the same name and the suffix `_in`.

See also: [`get_entry_nodes`](@ref)
"""
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
    @assert isempty(node.children) "Trying to call get_init_expression on a data task node that is not an entry node."

    # NOTE(review): the results of gen_access_expr are eval'ed, interpolated into a string and
    # parsed again; presumably gen_access_expr returns a quoted expression -- confirm
    inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
    outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
    dataTransportExp = Meta.parse("$outExpr = $inExpr")

    return dataTransportExp
end
|
31
src/task/create.jl
Normal file
31
src/task/create.jl
Normal file
@ -0,0 +1,31 @@
|
||||
"""
    copy(t::AbstractDataTask)

Fallback implementation of the copy of an abstract data task, throwing an error.
Concrete data task types have to provide their own `copy` method.
"""
copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks!")

"""
    copy(t::AbstractComputeTask)

Return a copy of the given compute task.

The copy is created by calling the zero-argument constructor of the task's type.
"""
copy(t::AbstractComputeTask) = typeof(t)()
|
||||
|
||||
"""
    copy(t::FusedComputeTask)

Return a copy of the given [`FusedComputeTask`](@ref).

Both contained tasks and both input vectors are copied, the output of the first task is reused as it is.
"""
function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
    return FusedComputeTask{T1, T2}(
        copy(t.first_task),
        copy(t.second_task),
        copy(t.t1_inputs),
        t.t1_output,
        copy(t.t2_inputs),
    )
end
|
||||
|
||||
# Convenience constructor: create the two fused tasks with the zero-argument constructors
# of `T1` and `T2` and pass everything on to the five-argument constructor.
FusedComputeTask{T1, T2}(t1_inputs::Vector{String}, t1_output::String, t2_inputs::Vector{String}) where {T1, T2} =
    FusedComputeTask{T1, T2}(T1(), T2(), t1_inputs, t1_output, t2_inputs)
|
8
src/task/print.jl
Normal file
8
src/task/print.jl
Normal file
@ -0,0 +1,8 @@
|
||||
"""
    show(io::IO, t::FusedComputeTask)

Print a string representation of the fused compute task to `io`, naming both contained tasks.
"""
function show(io::IO, t::FusedComputeTask)
    representation = "ComputeFuse($(t.first_task), $(t.second_task))"
    return print(io, representation)
end
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user