diff --git a/.gitea/workflows/julia-package-ci.yml b/.gitea/workflows/julia-package-ci.yml index 7bf1568..6564ee1 100644 --- a/.gitea/workflows/julia-package-ci.yml +++ b/.gitea/workflows/julia-package-ci.yml @@ -30,4 +30,4 @@ jobs: run: julia --project -e 'import Pkg; Pkg.test()' - name: Run examples - run: julia --project -e 'import Pkg; include("examples/import_bench.jl")' + run: julia --project=examples -e 'import Pkg; Pkg.develop("."); Pkg.instantiate(); include("examples/import_bench.jl")' diff --git a/Project.toml b/Project.toml index b0736e5..e7ec467 100644 --- a/Project.toml +++ b/Project.toml @@ -4,9 +4,7 @@ authors = ["Anton Reinhard "] version = "0.1.0" [deps] -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/README.md b/README.md index a66bec4..053b8fc 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,31 @@ Directed Acyclic Graph optimization for QED -## Generate Operations from chains +## Usage + +Instantiate the project first: + +`julia --project -e 'import Pkg; Pkg.instantiate()'` + +### Run Tests + +To run all tests, run + +`julia --project=. -e 'import Pkg; Pkg.test()'` + +### Run Examples + +Get the correct environment for the examples folder: + +`julia --project=examples -e 'import Pkg; Pkg.develop("."); Pkg.instantiate()'` + +Then execute a specific example: + +`julia --project=examples examples/.jl` + +## Concepts + +### Generate Operations from chains We assume we have a (valid) graph given. We can generate all initially possible graph operations from it, and we can calculate the graph properties like compute effort and total data transfer. @@ -121,3 +145,5 @@ Graph: Graph size in memory: 225.0625 KiB 286.583 μs (13996 allocations: 804.48 KiB) ``` + + diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 0000000..a1e4baf --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,4 @@ +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl index 594675b..f9a4b5d 100644 --- a/src/MetagraphOptimization.jl +++ b/src/MetagraphOptimization.jl @@ -6,7 +6,7 @@ export make_node, make_edge, insert_node, insert_edge, is_entry_node, is_exit_no export NodeFusion, NodeReduction, NodeSplit, push_operation!, pop_operation!, can_pop, reset_graph!, get_operations export import_txt -export ==, in, show +export ==, in, show, isempty, delete! export bytes_to_human_readable @@ -15,6 +15,8 @@ import Base.show import Base.== import Base.in import Base.copy +import Base.isempty +import Base.delete! include("tasks.jl") diff --git a/src/graph_functions.jl b/src/graph_functions.jl index 2225508..d7647ee 100644 --- a/src/graph_functions.jl +++ b/src/graph_functions.jl @@ -3,6 +3,25 @@ using DataStructures in(node::Node, graph::DAG) = node in graph.nodes in(edge::Edge, graph::DAG) = edge in graph.edges +function isempty(operations::PossibleOperations) + return isempty(operations.nodeFusions) && + isempty(operations.nodeReductions) && + isempty(operations.nodeSplits) +end + +function delete!(operations::PossibleOperations, op::NodeFusion) + delete!(operations.nodeFusions, op) + return operations +end +function delete!(operations::PossibleOperations, op::NodeReduction) + delete!(operations.nodeReductions, op) + return operations +end +function delete!(operations::PossibleOperations, op::NodeSplit) + delete!(operations.nodeSplits, op) + return operations +end + function is_parent(potential_parent, node) return potential_parent in node.parents end @@ -68,14 +87,25 @@ function invalidate_caches!(graph::DAG, operation::Operation) delete!(graph.possibleOperations, operation) # delete the operation from all caches of nodes involved in the operation - # (we can iterate over single values, tuples and vectors just fine) + # (we can iterate over tuples and vectors just fine) for node in operation.input - delete!(node.operations, operation) + filter!(!=(operation), node.operations) end return nothing end +# function to invalidate the operation caches for a given Node Split specifically +function invalidate_caches!(graph::DAG, operation::NodeSplit) + delete!(graph.possibleOperations, operation) + + # delete the operation from all caches of nodes involved in the operation + # for node split there is only one node + filter!(!=(operation), operation.input.operations) + + return nothing +end + # for graph mutating functions we need to do a few things # 1: mute the graph (duh) # 2: keep track of what was changed for the diff (if track == true) @@ -127,7 +157,7 @@ function remove_node!(graph::DAG, node::Node, track=true) if (track) push!(graph.diff.removedNodes, node) end # 3: invalidate caches - while !isempty(node) + while !isempty(node.operations) invalidate_caches!(graph, first(node.operations)) end delete!(graph.dirtyNodes, node) @@ -197,7 +227,16 @@ function get_exit_node(graph::DAG) end function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode) - #Todo + if !is_child(n1, n2) || !is_child(n2, n3) + # the checks are redundant but maybe a good sanity check + return false + end + + if length(parents(n2)) != 1 || length(children(n2)) != 1 + return false + end + + return true end function can_reduce(n1::Node, n2::Node) diff --git a/src/graph_operations.jl b/src/graph_operations.jl index 18a3414..63d41f5 100644 --- a/src/graph_operations.jl +++ b/src/graph_operations.jl @@ -237,9 +237,113 @@ function node_split!(graph::DAG, n1::Node) return get_snapshot_diff(graph) end -# function to find node fusions involving the given node -function find_fusions(graph::DAG, node::Node) - +# function to find node fusions involving the given node if it's a data node +# pushes the found fusion everywhere it needs to be and returns nothing +function find_fusions!(graph::DAG, node::DataTaskNode) + if length(parents(node)) != 1 || length(children(node)) != 1 + return nothing + end + + child_node = first(children(node)) + parent_node = first(parents(node)) + + nf = NodeFusion((child_node, node, parent_node)) + push!(graph.possibleOperations.nodeFusions, nf) + push!(child_node.operations, nf) + push!(node.operations, nf) + push!(parent_node.operations, nf) + + return nothing +end + +# function to find node fusions involving the given node if it's a compute node +# pushes the found fusion(s) everywhere it needs to be and returns nothing +function find_fusions!(graph::DAG, node::ComputeTaskNode) + # for loop that always runs once for a scoped block we can break out of + for _ in 1:1 + # assume this node as child of the chain + if length(parents(node)) < 1 + break + end + node2 = first(parents(node)) + if length(parents(node2)) != 1 || length(children(node2)) != 1 + break + end + node3 = first(parents(node2)) + + nf = NodeFusion((node, node2, node3)) + push!(graph.possibleOperations.nodeFusions, nf) + push!(node.operations, nf) + push!(node2.operations, nf) + push!(node3.operations, nf) + end + + for _ in 1:1 + # assume this node as parent of the chain + if length(children(node)) < 1 + break + end + node2 = first(children(node)) + if length(parents(node2)) != 1 || length(children(node2)) != 1 + break + end + node1 = first(children(node2)) + + nf = NodeFusion((node1, node2, node)) + push!(graph.possibleOperations.nodeFusions, nf) + push!(node1.operations, nf) + push!(node2.operations, nf) + push!(node.operations, nf) + end + + return nothing +end + +function find_reductions!(graph::DAG, node::Node) + reductionVector = nothing + # possible reductions are with nodes that are partners, i.e. parents of children + for partner in partners(node) + if can_reduce(node, partner) + if reductionVector === nothing + # only when there's at least one reduction partner, insert the vector + reductionVector = Vector{Node}() + push!(reductionVector, node) + end + + push!(reductionVector, partner) + end + end + + if reductionVector !== nothing + nr = NodeReduction(reductionVector) + push!(graph.possibleOperations.nodeReductions, nr) + for node in reductionVector + push!(node.operations, nr) + end + end + + return nothing +end + +function find_splits!(graph::DAG, node::Node) + for node in graph.nodes + if (can_split(node)) + ns = NodeSplit(node) + push!(graph.possibleOperations.nodeSplits, ns) + push!(node.operations, ns) + end + end + + return nothing +end + +# "clean" the operations on a dirty node +function clean_node!(graph::DAG, node::Node) + find_fusions!(graph, node) + find_reductions!(graph, node) + find_splits!(graph, node) + + delete!(graph.dirtyNodes, node) end # function to generate all possible optmizations on the graph @@ -317,15 +421,20 @@ function generate_options(graph::DAG) end end - options.dirty = false - graph.possibleOperations = options + empty!(graph.dirtyNodes) end function get_operations(graph::DAG) - if (graph.possibleOperations.dirty) + apply_all!(graph) + + if isempty(graph.possibleOperations) generate_options(graph) end + while !isempty(graph.dirtyNodes) + clean_node!(graph, first(graph.dirtyNodes)) + end + return graph.possibleOperations -end \ No newline at end of file +end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..7a21f89 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,3 @@ +[deps] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"