Fix tests and operation cache
This commit is contained in:
parent
8a081ba93c
commit
ae07b4cf80
@ -30,4 +30,4 @@ jobs:
|
||||
run: julia --project -e 'import Pkg; Pkg.test()'
|
||||
|
||||
- name: Run examples
|
||||
run: julia --project -e 'import Pkg; include("examples/import_bench.jl")'
|
||||
run: julia --project=examples -e 'import Pkg; Pkg.develop("."); Pkg.instantiate(); include("examples/import_bench.jl")'
|
||||
|
@ -4,9 +4,7 @@ authors = ["Anton Reinhard <anton.reinhard@proton.me>"]
|
||||
version = "0.1.0"
|
||||
|
||||
[deps]
|
||||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
|
||||
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
||||
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
28
README.md
28
README.md
@ -2,7 +2,31 @@
|
||||
|
||||
Directed Acyclic Graph optimization for QED
|
||||
|
||||
## Generate Operations from chains
|
||||
## Usage
|
||||
|
||||
Instantiate the project first:
|
||||
|
||||
`julia --project -e 'import Pkg; Pkg.instantiate()'`
|
||||
|
||||
### Run Tests
|
||||
|
||||
To run all tests, run
|
||||
|
||||
`julia --project=. -e 'import Pkg; Pkg.test()'`
|
||||
|
||||
### Run Examples
|
||||
|
||||
Get the correct environment for the examples folder:
|
||||
|
||||
`julia --project=examples -e 'import Pkg; Pkg.develop("."); Pkg.instantiate()'`
|
||||
|
||||
Then execute a specific example:
|
||||
|
||||
`julia --project=examples examples/<file>.jl`
|
||||
|
||||
## Concepts
|
||||
|
||||
### Generate Operations from chains
|
||||
|
||||
We assume we have a (valid) graph given. We can generate all initially possible graph operations from it, and we can calculate the graph properties like compute effort and total data transfer.
|
||||
|
||||
@ -121,3 +145,5 @@ Graph:
|
||||
Graph size in memory: 225.0625 KiB
|
||||
286.583 μs (13996 allocations: 804.48 KiB)
|
||||
```
|
||||
|
||||
|
||||
|
4
examples/Project.toml
Normal file
4
examples/Project.toml
Normal file
@ -0,0 +1,4 @@
|
||||
[deps]
|
||||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
|
||||
MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4"
|
||||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
@ -6,7 +6,7 @@ export make_node, make_edge, insert_node, insert_edge, is_entry_node, is_exit_no
|
||||
export NodeFusion, NodeReduction, NodeSplit, push_operation!, pop_operation!, can_pop, reset_graph!, get_operations
|
||||
export import_txt
|
||||
|
||||
export ==, in, show
|
||||
export ==, in, show, isempty, delete!
|
||||
|
||||
export bytes_to_human_readable
|
||||
|
||||
@ -15,6 +15,8 @@ import Base.show
|
||||
import Base.==
|
||||
import Base.in
|
||||
import Base.copy
|
||||
import Base.isempty
|
||||
import Base.delete!
|
||||
|
||||
|
||||
include("tasks.jl")
|
||||
|
@ -3,6 +3,25 @@ using DataStructures
|
||||
in(node::Node, graph::DAG) = node in graph.nodes
|
||||
in(edge::Edge, graph::DAG) = edge in graph.edges
|
||||
|
||||
function isempty(operations::PossibleOperations)
|
||||
return isempty(operations.nodeFusions) &&
|
||||
isempty(operations.nodeReductions) &&
|
||||
isempty(operations.nodeSplits)
|
||||
end
|
||||
|
||||
function delete!(operations::PossibleOperations, op::NodeFusion)
|
||||
delete!(operations.nodeFusions, op)
|
||||
return operations
|
||||
end
|
||||
function delete!(operations::PossibleOperations, op::NodeReduction)
|
||||
delete!(operations.nodeReductions, op)
|
||||
return operations
|
||||
end
|
||||
function delete!(operations::PossibleOperations, op::NodeSplit)
|
||||
delete!(operations.nodeSplits, op)
|
||||
return operations
|
||||
end
|
||||
|
||||
function is_parent(potential_parent, node)
|
||||
return potential_parent in node.parents
|
||||
end
|
||||
@ -68,14 +87,25 @@ function invalidate_caches!(graph::DAG, operation::Operation)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# (we can iterate over single values, tuples and vectors just fine)
|
||||
# (we can iterate over tuples and vectors just fine)
|
||||
for node in operation.input
|
||||
delete!(node.operations, operation)
|
||||
filter!(!=(operation), node.operations)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to invalidate the operation caches for a given Node Split specifically
|
||||
function invalidate_caches!(graph::DAG, operation::NodeSplit)
|
||||
delete!(graph.possibleOperations, operation)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# for node split there is only one node
|
||||
filter!(!=(operation), operation.input.operations)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# for graph mutating functions we need to do a few things
|
||||
# 1: mute the graph (duh)
|
||||
# 2: keep track of what was changed for the diff (if track == true)
|
||||
@ -127,7 +157,7 @@ function remove_node!(graph::DAG, node::Node, track=true)
|
||||
if (track) push!(graph.diff.removedNodes, node) end
|
||||
|
||||
# 3: invalidate caches
|
||||
while !isempty(node)
|
||||
while !isempty(node.operations)
|
||||
invalidate_caches!(graph, first(node.operations))
|
||||
end
|
||||
delete!(graph.dirtyNodes, node)
|
||||
@ -197,7 +227,16 @@ function get_exit_node(graph::DAG)
|
||||
end
|
||||
|
||||
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
#Todo
|
||||
if !is_child(n1, n2) || !is_child(n2, n3)
|
||||
# the checks are redundant but maybe a good sanity check
|
||||
return false
|
||||
end
|
||||
|
||||
if length(parents(n2)) != 1 || length(children(n2)) != 1
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function can_reduce(n1::Node, n2::Node)
|
||||
|
@ -237,9 +237,113 @@ function node_split!(graph::DAG, n1::Node)
|
||||
return get_snapshot_diff(graph)
|
||||
end
|
||||
|
||||
# function to find node fusions involving the given node
|
||||
function find_fusions(graph::DAG, node::Node)
|
||||
|
||||
# function to find node fusions involving the given node if it's a data node
|
||||
# pushes the found fusion everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
if length(parents(node)) != 1 || length(children(node)) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
child_node = first(children(node))
|
||||
parent_node = first(parents(node))
|
||||
|
||||
nf = NodeFusion((child_node, node, parent_node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(child_node.operations, nf)
|
||||
push!(node.operations, nf)
|
||||
push!(parent_node.operations, nf)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function to find node fusions involving the given node if it's a compute node
|
||||
# pushes the found fusion(s) everywhere it needs to be and returns nothing
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
# for loop that always runs once for a scoped block we can break out of
|
||||
for _ in 1:1
|
||||
# assume this node as child of the chain
|
||||
if length(parents(node)) < 1
|
||||
break
|
||||
end
|
||||
node2 = first(parents(node))
|
||||
if length(parents(node2)) != 1 || length(children(node2)) != 1
|
||||
break
|
||||
end
|
||||
node3 = first(parents(node2))
|
||||
|
||||
nf = NodeFusion((node, node2, node3))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node.operations, nf)
|
||||
push!(node2.operations, nf)
|
||||
push!(node3.operations, nf)
|
||||
end
|
||||
|
||||
for _ in 1:1
|
||||
# assume this node as parent of the chain
|
||||
if length(children(node)) < 1
|
||||
break
|
||||
end
|
||||
node2 = first(children(node))
|
||||
if length(parents(node2)) != 1 || length(children(node2)) != 1
|
||||
break
|
||||
end
|
||||
node1 = first(children(node2))
|
||||
|
||||
nf = NodeFusion((node1, node2, node))
|
||||
push!(graph.possibleOperations.nodeFusions, nf)
|
||||
push!(node1.operations, nf)
|
||||
push!(node2.operations, nf)
|
||||
push!(node.operations, nf)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function find_reductions!(graph::DAG, node::Node)
|
||||
reductionVector = nothing
|
||||
# possible reductions are with nodes that are partners, i.e. parents of children
|
||||
for partner in partners(node)
|
||||
if can_reduce(node, partner)
|
||||
if reductionVector === nothing
|
||||
# only when there's at least one reduction partner, insert the vector
|
||||
reductionVector = Vector{Node}()
|
||||
push!(reductionVector, node)
|
||||
end
|
||||
|
||||
push!(reductionVector, partner)
|
||||
end
|
||||
end
|
||||
|
||||
if reductionVector !== nothing
|
||||
nr = NodeReduction(reductionVector)
|
||||
push!(graph.possibleOperations.nodeReductions, nr)
|
||||
for node in reductionVector
|
||||
push!(node.operations, nr)
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function find_splits!(graph::DAG, node::Node)
|
||||
for node in graph.nodes
|
||||
if (can_split(node))
|
||||
ns = NodeSplit(node)
|
||||
push!(graph.possibleOperations.nodeSplits, ns)
|
||||
push!(node.operations, ns)
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
# "clean" the operations on a dirty node
|
||||
function clean_node!(graph::DAG, node::Node)
|
||||
find_fusions!(graph, node)
|
||||
find_reductions!(graph, node)
|
||||
find_splits!(graph, node)
|
||||
|
||||
delete!(graph.dirtyNodes, node)
|
||||
end
|
||||
|
||||
# function to generate all possible optmizations on the graph
|
||||
@ -317,15 +421,20 @@ function generate_options(graph::DAG)
|
||||
end
|
||||
end
|
||||
|
||||
options.dirty = false
|
||||
|
||||
graph.possibleOperations = options
|
||||
empty!(graph.dirtyNodes)
|
||||
end
|
||||
|
||||
function get_operations(graph::DAG)
|
||||
if (graph.possibleOperations.dirty)
|
||||
apply_all!(graph)
|
||||
|
||||
if isempty(graph.possibleOperations)
|
||||
generate_options(graph)
|
||||
end
|
||||
|
||||
while !isempty(graph.dirtyNodes)
|
||||
clean_node!(graph, first(graph.dirtyNodes))
|
||||
end
|
||||
|
||||
return graph.possibleOperations
|
||||
end
|
||||
end
|
||||
|
3
test/Project.toml
Normal file
3
test/Project.toml
Normal file
@ -0,0 +1,3 @@
|
||||
[deps]
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
Loading…
x
Reference in New Issue
Block a user