From b7560685d49b8d58d8f325f4137ece8567854b66 Mon Sep 17 00:00:00 2001
From: Anton Reinhard <anton.reinhard@proton.me>
Date: Wed, 22 Nov 2023 13:51:54 +0100
Subject: [PATCH] Optimizer interface and sample implementation (#19)

Reviewed-on: https://code.woubery.com/Rubydragon/MetagraphOptimization.jl/pulls/19
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
---
 docs/src/lib/internals/estimator.md    |  2 +-
 docs/src/lib/internals/optimization.md | 41 +++++++++++++++
 examples/ab5.jl                        |  3 +-
 examples/ab7.jl                        |  3 +-
 examples/profiling_utilities.jl        | 59 ----------------------
 notebooks/abc_model_large.ipynb        |  3 +-
 notebooks/abc_model_showcase.ipynb     |  4 +-
 notebooks/profiling.ipynb              |  3 +-
 src/MetagraphOptimization.jl           | 13 +++++
 src/estimator/global_metric.jl         | 24 ++++++---
 src/estimator/interface.jl             |  2 +-
 src/graph/compare.jl                   | 18 +------
 src/graph/mute.jl                      | 59 ++++++++++++++--------
 src/graph/print.jl                     |  6 +--
 src/graph/properties.jl                |  9 ++++
 src/graph/type.jl                      |  4 +-
 src/models/abc/parse.jl                |  2 +-
 src/models/abc/properties.jl           | 12 ++---
 src/node/compare.jl                    |  4 +-
 src/node/create.jl                     |  4 +-
 src/node/print.jl                      |  2 +-
 src/node/properties.jl                 | 58 ++++++++++++----------
 src/node/type.jl                       | 10 ++--
 src/node/validate.jl                   |  4 +-
 src/operation/apply.jl                 | 29 ++++++-----
 src/operation/clean.jl                 | 17 ++++---
 src/operation/find.jl                  | 10 ++--
 src/operation/get.jl                   |  4 +-
 src/operation/iterate.jl               | 39 +++++++++++++++
 src/operation/print.jl                 | 10 ++--
 src/operation/type.jl                  | 29 ++++++-----
 src/operation/utility.jl               | 36 +++++++++-----
 src/operation/validate.jl              | 10 ++--
 src/optimization/greedy.jl             | 69 +++++++++++++++++++++++++-
 src/optimization/interface.jl          | 60 ++++++++++++++++++++++
 src/optimization/random_walk.jl        | 49 ++++++++++++++++++
 src/optimization/reduce.jl             | 30 +++++++++++
 src/properties/create.jl               | 41 ++++++++-------
 src/properties/type.jl                 |  5 +-
 src/properties/utility.jl              |  3 --
 src/scheduler/greedy.jl                |  6 +--
 src/task/compute.jl                    | 12 ++---
 src/task/create.jl                     | 21 ++++----
 src/task/properties.jl                 | 14 +++---
 src/task/type.jl                       |  6 +--
 src/trie.jl                            | 32 +++++++-----
 src/utility.jl                         |  4 +-
 test/runtests.jl                       |  1 +
 test/unit_tests_estimator.jl           |  7 ---
 test/unit_tests_execution.jl           |  8 +--
 test/unit_tests_graph.jl               |  6 +++
 test/unit_tests_optimization.jl        | 42 ++++++++++++++++
 test/unit_tests_properties.jl          | 21 +-------
 53 files changed, 639 insertions(+), 331 deletions(-)
 create mode 100644 docs/src/lib/internals/optimization.md
 delete mode 100644 examples/profiling_utilities.jl
 create mode 100644 src/operation/iterate.jl
 create mode 100644 src/optimization/interface.jl
 create mode 100644 src/optimization/random_walk.jl
 create mode 100644 src/optimization/reduce.jl
 create mode 100644 test/unit_tests_optimization.jl

diff --git a/docs/src/lib/internals/estimator.md b/docs/src/lib/internals/estimator.md
index 5484a07..3cea858 100644
--- a/docs/src/lib/internals/estimator.md
+++ b/docs/src/lib/internals/estimator.md
@@ -1,4 +1,4 @@
-# Models
+# Estimation
 
 ## Interface
 
diff --git a/docs/src/lib/internals/optimization.md b/docs/src/lib/internals/optimization.md
new file mode 100644
index 0000000..8f563e8
--- /dev/null
+++ b/docs/src/lib/internals/optimization.md
@@ -0,0 +1,41 @@
+# Optimization
+
+## Interface
+
+The interface that has to be implemented for an optimization algorithm.
+
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["optimization/interafce.jl"]
+Order = [:type, :constant, :function]
+```
+
+## Random Walk Optimizer
+
+Implementation of a random walk algorithm.
+
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["estimator/random_walk.jl"]
+Order = [:type, :function]
+```
+
+## Reduction Optimizer
+
+Implementation of an optimizer that reduces the graph as far as possible.
+
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["estimator/reduce.jl"]
+Order = [:type, :function]
+```
+
+## Greedy Optimizer
+
+Implementation of a greedy optimization algorithm.
+
+```@autodocs
+Modules = [MetagraphOptimization]
+Pages = ["estimator/greedy.jl"]
+Order = [:type, :function]
+```
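+
+## Usage
+
+A minimal sketch of how the interface is used, assuming a parsed `graph::DAG`:
+
+```julia
+using MetagraphOptimization
+
+# apply node reductions until no more are possible
+optimize_to_fixpoint!(ReductionOptimizer(), graph)
+
+# or apply up to 10 greedily chosen operations
+optimize!(GreedyOptimizer(GlobalMetricEstimator()), graph, 10)
+```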
diff --git a/examples/ab5.jl b/examples/ab5.jl
index f84da02..feee2aa 100644
--- a/examples/ab5.jl
+++ b/examples/ab5.jl
@@ -17,9 +17,8 @@ println("Parsing DAG")
 println("Generating input data")
 @time input_data = [gen_process_input(process) for _ in 1:1000]
 
-include("profiling_utilities.jl")
 println("Reducing graph")
-@time reduce_all!(graph)
+@time optimize_to_fixpoint!(ReductionOptimizer(), graph)
 
 println("Generating compute function")
 @time compute_func = get_compute_function(graph, process, machine)
diff --git a/examples/ab7.jl b/examples/ab7.jl
index 506d9fb..0b08f79 100644
--- a/examples/ab7.jl
+++ b/examples/ab7.jl
@@ -17,9 +17,8 @@ println("Parsing DAG")
 println("Generating input data")
 @time input_data = [gen_process_input(process) for _ in 1:1000]
 
-include("profiling_utilities.jl")
 println("Reducing graph")
-@time reduce_all!(graph)
+@time optimize_to_fixpoint!(ReductionOptimizer(), graph)
 
 println("Generating compute function")
 @time compute_func = get_compute_function(graph, process, machine)
diff --git a/examples/profiling_utilities.jl b/examples/profiling_utilities.jl
deleted file mode 100644
index 2400567..0000000
--- a/examples/profiling_utilities.jl
+++ /dev/null
@@ -1,59 +0,0 @@
-
-function random_walk!(g::DAG, n::Int64)
-    # the purpose here is to do "random" operations on the graph to simulate an optimizer
-    reset_graph!(g)
-
-    properties = get_properties(g)
-
-    for i in 1:n
-        # choose push or pop
-        if rand(Bool)
-            # push
-            opt = get_operations(g)
-
-            # choose one of fuse/split/reduce
-            option = rand(1:3)
-            if option == 1 && !isempty(opt.nodeFusions)
-                push_operation!(g, rand(collect(opt.nodeFusions)))
-            elseif option == 2 && !isempty(opt.nodeReductions)
-                push_operation!(g, rand(collect(opt.nodeReductions)))
-            elseif option == 3 && !isempty(opt.nodeSplits)
-                push_operation!(g, rand(collect(opt.nodeSplits)))
-            else
-                i = i - 1
-            end
-        else
-            # pop
-            if (can_pop(g))
-                pop_operation!(g)
-            else
-                i = i - 1
-            end
-        end
-    end
-
-    return nothing
-end
-
-function reduce_all!(g::DAG)
-    reset_graph!(g)
-
-    opt = get_operations(g)
-    while (!isempty(opt.nodeReductions))
-        push_operation!(g, pop!(opt.nodeReductions))
-
-        if (isempty(opt.nodeReductions))
-            opt = get_operations(g)
-        end
-    end
-    return nothing
-end
-
-function reduce_one!(g::DAG)
-    opt = get_operations(g)
-    if !isempty(opt.nodeReductions)
-        push_operation!(g, pop!(opt.nodeReductions))
-    end
-    opt = get_operations(g)
-    return nothing
-end
diff --git a/notebooks/abc_model_large.ipynb b/notebooks/abc_model_large.ipynb
index 00734f4..8189d04 100644
--- a/notebooks/abc_model_large.ipynb
+++ b/notebooks/abc_model_large.ipynb
@@ -99,8 +99,7 @@
     }
    ],
    "source": [
-    "include(\"../examples/profiling_utilities.jl\")\n",
-    "@time reduce_all!(graph)\n",
+    "@time optimize_to_fixpoint!(ReductionOptimizer(), graph)\n",
     "print(graph)"
    ]
   },
diff --git a/notebooks/abc_model_showcase.ipynb b/notebooks/abc_model_showcase.ipynb
index cd72498..9cef189 100644
--- a/notebooks/abc_model_showcase.ipynb
+++ b/notebooks/abc_model_showcase.ipynb
@@ -211,10 +211,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "include(\"../examples/profiling_utilities.jl\")\n",
-    "\n",
     "# We can also mute the graph by applying some operations to it\n",
-    "reduce_all!(graph)"
+    "optimize_to_fixpoint!(ReductionOptimizer(), graph)"
    ]
   },
   {
diff --git a/notebooks/profiling.ipynb b/notebooks/profiling.ipynb
index f782184..f4032ee 100644
--- a/notebooks/profiling.ipynb
+++ b/notebooks/profiling.ipynb
@@ -30,8 +30,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "include(\"../examples/profiling_utilities.jl\")\n",
-    "@ProfileView.profview reduce_all!(graph)"
+    "@ProfileView.profview optimize_to_fixpoint!(ReductionOptimizer(), graph)"
    ]
   },
   {
diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl
index ba02314..385b8fb 100644
--- a/src/MetagraphOptimization.jl
+++ b/src/MetagraphOptimization.jl
@@ -31,8 +31,10 @@ export children
 export compute
 export data
 export compute_effort
+export task
 export get_properties
 export get_exit_node
+export operation_stack_length
 export is_valid, is_scheduled
 
 # graph operation related
@@ -68,6 +70,11 @@ export get_compute_function
 export cost_type, graph_cost, operation_effect
 export GlobalMetricEstimator, CDCost
 
+# optimization
+export AbstractOptimizer, GreedyOptimizer, ReductionOptimizer, RandomWalkOptimizer
+export optimize_step!, optimize!
+export fixpoint_reached, optimize_to_fixpoint!
+
 # machine info
 export Machine
 export get_machine_info
@@ -117,6 +124,7 @@ include("node/properties.jl")
 include("node/validate.jl")
 
 include("operation/utility.jl")
+include("operation/iterate.jl")
 include("operation/apply.jl")
 include("operation/clean.jl")
 include("operation/find.jl")
@@ -136,6 +144,11 @@ include("task/properties.jl")
 include("estimator/interface.jl")
 include("estimator/global_metric.jl")
 
+include("optimization/interface.jl")
+include("optimization/greedy.jl")
+include("optimization/random_walk.jl")
+include("optimization/reduce.jl")
+
 include("models/interface.jl")
 include("models/print.jl")
 
diff --git a/src/estimator/global_metric.jl b/src/estimator/global_metric.jl
index 521a67a..d83a450 100644
--- a/src/estimator/global_metric.jl
+++ b/src/estimator/global_metric.jl
@@ -29,9 +29,21 @@ function -(cost1::CDCost, cost2::CDCost)::CDCost
     return (data = d, computeEffort = ce, computeIntensity = ce / d)::CDCost
 end
 
+function isless(cost1::CDCost, cost2::CDCost)::Bool
+    return cost1.data + cost1.computeEffort < cost2.data + cost2.computeEffort
+end
+
+function zero(type::Type{CDCost})
+    return (data = 0.0, computeEffort = 0.0, computeIntensity = 0.0)::CDCost
+end
+
+function typemax(type::Type{CDCost})
+    return (data = Inf, computeEffort = Inf, computeIntensity = 0.0)::CDCost
+end
+
 struct GlobalMetricEstimator <: AbstractEstimator end
 
-function cost_type(estimator::GlobalMetricEstimator)
+function cost_type(estimator::GlobalMetricEstimator)::Type{CDCost}
     return CDCost
 end
 
@@ -51,15 +63,15 @@ end
 function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeReduction)
     s = length(operation.input) - 1
     return (
-        data = s * -data(operation.input[1].task),
-        computeEffort = s * -compute_effort(operation.input[1].task),
+        data = s * -data(task(operation.input[1])),
+        computeEffort = s * -compute_effort(task(operation.input[1])),
         computeIntensity = typeof(operation.input) <: DataTaskNode ? 0.0 : Inf,
     )::CDCost
 end
 
 function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeSplit)
-    s = length(operation.input.parents) - 1
-    d = s * data(operation.input.task)
-    ce = s * compute_effort(operation.input.task)
+    s::Float64 = length(parents(operation.input)) - 1
+    d::Float64 = s * data(task(operation.input))
+    ce::Float64 = s * compute_effort(task(operation.input))
     return (data = d, computeEffort = ce, computeIntensity = ce / d)::CDCost
 end
diff --git a/src/estimator/interface.jl b/src/estimator/interface.jl
index 52a3de1..e0cc5ff 100644
--- a/src/estimator/interface.jl
+++ b/src/estimator/interface.jl
@@ -20,7 +20,7 @@ function cost_type end
 """
     graph_cost(estimator::AbstractEstimator, graph::DAG)
 
-Get the total estimated cost of the graph. The cost's data type can be chosen by the implementation, but should have usable comparison operators (<, <=, >, >=, ==) and basic math operators (+, -, *, /).
+Get the total estimated cost of the graph. The cost's data type can be chosen by the implementation, but must have a usable less-than comparison operator (`<`), basic math operators (`+`, `-`), and implementations of `zero()` and `typemax()`.
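+
+As a sketch: plain `Float64` already satisfies all of these requirements, so a minimal custom cost type could simply be (illustrative, not part of the package):
+
+```julia
+const MyCost = Float64
+```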
 """
 function graph_cost end
 
diff --git a/src/graph/compare.jl b/src/graph/compare.jl
index 0845de3..7b4f206 100644
--- a/src/graph/compare.jl
+++ b/src/graph/compare.jl
@@ -17,21 +17,5 @@ function in(edge::Edge, graph::DAG)
         return false
     end
 
-    return n1 in n2.children
-end
-
-"""
-    ==(n1::Node, n2::Node, g::DAG)
-
-Check equality of two nodes in a graph.
-"""
-function ==(n1::Node, n2::Node, g::DAG)
-    if typeof(n1) != typeof(n2)
-        return false
-    end
-    if !(n1 in g) || !(n2 in g)
-        return false
-    end
-
-    return n1.task == n2.task && children(n1) == children(n2)
+    return n1 in children(n2)
 end
diff --git a/src/graph/mute.jl b/src/graph/mute.jl
index d23611f..4c6e0af 100644
--- a/src/graph/mute.jl
+++ b/src/graph/mute.jl
@@ -46,7 +46,7 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
 See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
 """
 function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
-    @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
+    #@assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists"
 
     # 1: mute
     # edge points from child to parent
@@ -85,7 +85,7 @@ Remove the node from the graph.
 See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
 """
 function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
-    @assert node in graph.nodes "Trying to remove a node that's not in the graph"
+    #@assert node in graph.nodes "Trying to remove a node that's not in the graph"
 
     # 1: mute
     delete!(graph.nodes, node)
@@ -124,18 +124,29 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
     pre_length2 = length(node2.children)
 
     #TODO: filter is very slow
-    filter!(x -> x != node2, node1.parents)
-    filter!(x -> x != node1, node2.children)
+    for i in eachindex(node1.parents)
+        if (node1.parents[i] == node2)
+            splice!(node1.parents, i)
+            break
+        end
+    end
 
-    @assert begin
+    for i in eachindex(node2.children)
+        if (node2.children[i] == node1)
+            splice!(node2.children, i)
+            break
+        end
+    end
+
+    #=@assert begin
         removed = pre_length1 - length(node1.parents)
         removed <= 1
-    end "removed more than one node from node1's parents"
+    end "removed more than one node from node1's parents"=#
 
-    @assert begin
-        removed = pre_length2 - length(node2.children)
+    #=@assert begin
+        removed = pre_length2 - length(children(node2))
         removed <= 1
-    end "removed more than one node from node2's children"
+    end "removed more than one node from node2's children"=#
 
     # 2: keep track
     if (track)
@@ -163,7 +174,7 @@ function replace_children!(task::FusedComputeTask, before, after)
     replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
     replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
 
-    @assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
+    #@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
 
     replace!(task.t1_inputs, before => after)
     replace!(task.t2_inputs, before => after)
@@ -185,33 +196,33 @@ end
 
 function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
     # only need to update fused compute tasks
-    if !(typeof(n.task) <: FusedComputeTask)
+    if !(typeof(task(n)) <: FusedComputeTask)
         return nothing
     end
 
-    taskBefore = copy(n.task)
+    taskBefore = copy(task(n))
 
-    if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
+    #=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
         println("------------------ Nothing to replace!! ------------------")
         child_ids = Vector{String}()
-        for child in n.children
+        for child in children(n)
             push!(child_ids, "$(child.id)")
         end
         println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
         @assert false
-    end
+    end=#
 
-    replace_children!(n.task, child_before, child_after)
+    replace_children!(task(n), child_before, child_after)
 
-    if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
+    #=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
         println("------------------ Did not replace anything!! ------------------")
         child_ids = Vector{String}()
-        for child in n.children
+        for child in children(n)
             push!(child_ids, "$(child.id)")
         end
         println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
         @assert false
-    end
+    end=#
 
     # keep track
     if (track)
@@ -242,8 +253,14 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
 
     # delete the operation from all caches of nodes involved in the operation
     # TODO: filter is very slow
-    filter!(!=(operation), operation.input[1].nodeFusions)
-    filter!(!=(operation), operation.input[3].nodeFusions)
+    for n in [1, 3]
+        for i in eachindex(operation.input[n].nodeFusions)
+            if operation == operation.input[n].nodeFusions[i]
+                splice!(operation.input[n].nodeFusions, i)
+                break
+            end
+        end
+    end
 
     operation.input[2].nodeFusion = missing
 
diff --git a/src/graph/print.jl b/src/graph/print.jl
index 5b130e7..e452749 100644
--- a/src/graph/print.jl
+++ b/src/graph/print.jl
@@ -30,10 +30,10 @@ function show(io::IO, graph::DAG)
     nodeDict = Dict{Type, Int64}()
     noEdges = 0
     for node in graph.nodes
-        if haskey(nodeDict, typeof(node.task))
-            nodeDict[typeof(node.task)] = nodeDict[typeof(node.task)] + 1
+        if haskey(nodeDict, typeof(task(node)))
+            nodeDict[typeof(task(node))] = nodeDict[typeof(task(node))] + 1
         else
-            nodeDict[typeof(node.task)] = 1
+            nodeDict[typeof(task(node))] = 1
         end
         noEdges += length(parents(node))
     end
diff --git a/src/graph/properties.jl b/src/graph/properties.jl
index 7458c13..394ddec 100644
--- a/src/graph/properties.jl
+++ b/src/graph/properties.jl
@@ -43,3 +43,12 @@ function get_entry_nodes(graph::DAG)
     end
     return result
 end
+
+"""
+    operation_stack_length(graph::DAG)
+
+Return the number of operations applied to the graph.
+"""
+function operation_stack_length(graph::DAG)
+    return length(graph.appliedOperations) + length(graph.operationsToApply)
+end
diff --git a/src/graph/type.jl b/src/graph/type.jl
index 6aa8585..e895b36 100644
--- a/src/graph/type.jl
+++ b/src/graph/type.jl
@@ -24,7 +24,7 @@ To get the set of possible operations, use [`get_operations`](@ref).
 The members of the object should not be manually accessed, instead always use the provided interface functions.
 """
 mutable struct DAG
-    nodes::Set{Node}
+    nodes::Set{Union{DataTaskNode, ComputeTaskNode}}
 
     # The operations currently applied to the set of nodes
     appliedOperations::Stack{AppliedOperation}
@@ -36,7 +36,7 @@ mutable struct DAG
     possibleOperations::PossibleOperations
 
     # The set of nodes whose possible operations need to be reevaluated
-    dirtyNodes::Set{Node}
+    dirtyNodes::Set{Union{DataTaskNode, ComputeTaskNode}}
 
     # "snapshot" system: keep track of added/removed nodes/edges since last snapshot
     # these are muted in insert_node! etc.
diff --git a/src/models/abc/parse.jl b/src/models/abc/parse.jl
index 5ef890b..05d2855 100644
--- a/src/models/abc/parse.jl
+++ b/src/models/abc/parse.jl
@@ -181,7 +181,7 @@ function parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = fa
             insert_edge!(graph, compute_S2, data_out, track = false, invalidate_cache = false)
 
             insert_edge!(graph, data_out, sum_node, track = false, invalidate_cache = false)
-            add_child!(sum_node.task)
+            add_child!(task(sum_node))
         elseif occursin(regex_plus, node)
             if (verbose)
                 println("\rReading Nodes Complete    ")
diff --git a/src/models/abc/properties.jl b/src/models/abc/properties.jl
index 7e321d2..ca9bdda 100644
--- a/src/models/abc/properties.jl
+++ b/src/models/abc/properties.jl
@@ -3,35 +3,35 @@
 
 Return the compute effort of an S1 task.
 """
-compute_effort(t::ComputeTaskS1) = 11.0
+compute_effort(t::ComputeTaskS1)::Float64 = 11.0
 
 """
     compute_effort(t::ComputeTaskS2)
 
 Return the compute effort of an S2 task.
 """
-compute_effort(t::ComputeTaskS2) = 12.0
+compute_effort(t::ComputeTaskS2)::Float64 = 12.0
 
 """
     compute_effort(t::ComputeTaskU)
 
 Return the compute effort of a U task.
 """
-compute_effort(t::ComputeTaskU) = 1.0
+compute_effort(t::ComputeTaskU)::Float64 = 1.0
 
 """
     compute_effort(t::ComputeTaskV)
 
 Return the compute effort of a V task.
 """
-compute_effort(t::ComputeTaskV) = 6.0
+compute_effort(t::ComputeTaskV)::Float64 = 6.0
 
 """
     compute_effort(t::ComputeTaskP)
 
 Return the compute effort of a P task.
 """
-compute_effort(t::ComputeTaskP) = 0.0
+compute_effort(t::ComputeTaskP)::Float64 = 0.0
 
 """
     compute_effort(t::ComputeTaskSum)
@@ -41,7 +41,7 @@ Return the compute effort of a Sum task.
 Note: This is a constant compute effort, even though sum scales with the number of its inputs. Since there is only ever a single sum node in a graph generated from the ABC-Model,
 this doesn't matter.
 """
-compute_effort(t::ComputeTaskSum) = 1.0
+compute_effort(t::ComputeTaskSum)::Float64 = 1.0
 
 """
     show(io::IO, t::DataTask)
diff --git a/src/node/compare.jl b/src/node/compare.jl
index d2e2c04..c84e0f4 100644
--- a/src/node/compare.jl
+++ b/src/node/compare.jl
@@ -21,7 +21,7 @@ end
 
 Equality comparison between two [`ComputeTaskNode`](@ref)s.
 """
-function ==(n1::ComputeTaskNode, n2::ComputeTaskNode)
+function ==(n1::ComputeTaskNode{TaskType}, n2::ComputeTaskNode{TaskType}) where {TaskType <: AbstractComputeTask}
     return n1.id == n2.id
 end
 
@@ -30,6 +30,6 @@ end
 
 Equality comparison between two [`DataTaskNode`](@ref)s.
 """
-function ==(n1::DataTaskNode, n2::DataTaskNode)
+function ==(n1::DataTaskNode{TaskType}, n2::DataTaskNode{TaskType}) where {TaskType <: AbstractDataTask}
     return n1.id == n2.id
 end
diff --git a/src/node/create.jl b/src/node/create.jl
index 84a8501..0b69885 100644
--- a/src/node/create.jl
+++ b/src/node/create.jl
@@ -13,8 +13,8 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
 )
 
 copy(m::Missing) = missing
-copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task))
-copy(n::DataTaskNode) = DataTaskNode(copy(n.task), n.name)
+copy(n::ComputeTaskNode) = ComputeTaskNode(copy(task(n)))
+copy(n::DataTaskNode) = DataTaskNode(copy(task(n)), n.name)
 
 """
     make_node(t::AbstractTask)
diff --git a/src/node/print.jl b/src/node/print.jl
index 61200a9..01b6d05 100644
--- a/src/node/print.jl
+++ b/src/node/print.jl
@@ -4,7 +4,7 @@
 Print a short string representation of the node to io.
 """
 function show(io::IO, n::Node)
-    return print(io, "Node(", n.task, ")")
+    return print(io, "Node(", task(n), ")")
 end
 
 """
diff --git a/src/node/properties.jl b/src/node/properties.jl
index b28a234..e2de923 100644
--- a/src/node/properties.jl
+++ b/src/node/properties.jl
@@ -3,25 +3,27 @@
 
 Return whether this node is an entry node in its graph, i.e., it has no children.
 """
-is_entry_node(node::Node) = length(node.children) == 0
+is_entry_node(node::Node) = length(children(node)) == 0
 
 """
     is_exit_node(node::Node)
 
 Return whether this node is an exit node of its graph, i.e., it has no parents.
 """
-is_exit_node(node::Node) = length(node.parents) == 0
+is_exit_node(node::Node)::Bool = length(parents(node)) == 0
 
 """
-    data(edge::Edge)
+    task(node::Node)
 
-Return the data transfered by this edge, i.e., 0 if the child is a [`ComputeTaskNode`](@ref), otherwise the child's `data()`.
+Return the node's task.
 """
-function data(edge::Edge)
-    if typeof(edge.edge[1]) <: DataTaskNode
-        return data(edge.edge[1].task)
-    end
-    return 0.0
+function task(node::DataTaskNode{TaskType})::TaskType where {TaskType <: Union{AbstractDataTask, AbstractComputeTask}}
+    return node.task
+end
+function task(
+    node::ComputeTaskNode{TaskType},
+)::TaskType where {TaskType <: Union{AbstractDataTask, AbstractComputeTask}}
+    return node.task
 end
 
 """
@@ -31,8 +33,11 @@ Return a copy of the node's children so it can safely be muted without changing
 
 A node's children are its prerequisite nodes, nodes that need to execute before the task of this node.
 """
-function children(node::Node)
-    return copy(node.children)
+function children(node::DataTaskNode)::Vector{ComputeTaskNode}
+    return node.children
+end
+function children(node::ComputeTaskNode)::Vector{DataTaskNode}
+    return node.children
 end
 
 """
@@ -42,8 +47,11 @@ Return a copy of the node's parents so it can safely be muted without changing t
 
 A node's parents are its subsequent nodes, nodes that need this node to execute.
 """
-function parents(node::Node)
-    return copy(node.parents)
+function parents(node::DataTaskNode)::Vector{ComputeTaskNode}
+    return node.parents
+end
+function parents(node::ComputeTaskNode)::Vector{DataTaskNode}
+    return node.parents
 end
 
 """
@@ -53,11 +61,11 @@ Return a vector of all siblings of this node.
 
 A node's siblings are all children of any of its parents. The result contains no duplicates and includes the node itself.
 """
-function siblings(node::Node)
+function siblings(node::Node)::Set{Node}
     result = Set{Node}()
     push!(result, node)
-    for parent in node.parents
-        union!(result, parent.children)
+    for parent in parents(node)
+        union!(result, children(parent))
     end
 
     return result
@@ -73,11 +81,11 @@ A node's partners are all parents of any of its children. The result contains no
 Note: This is very slow when there are multiple children with many parents. 
 This is less of a problem in [`siblings(node::Node)`](@ref) because (depending on the model) there are no nodes with a large number of children, or only a single one.
 """
-function partners(node::Node)
+function partners(node::Node)::Set{Node}
     result = Set{Node}()
     push!(result, node)
-    for child in node.children
-        union!(result, child.parents)
+    for child in children(node)
+        union!(result, parents(child))
     end
 
     return result
@@ -90,8 +98,8 @@ Alternative version to [`partners(node::Node)`](@ref), avoiding allocation of a
 """
 function partners(node::Node, set::Set{Node})
     push!(set, node)
-    for child in node.children
-        union!(set, child.parents)
+    for child in children(node)
+        union!(set, parents(child))
     end
     return nothing
 end
@@ -101,8 +109,8 @@ end
 
 Return whether the `potential_parent` is a parent of `node`.
 """
-function is_parent(potential_parent::Node, node::Node)
-    return potential_parent in node.parents
+function is_parent(potential_parent::Node, node::Node)::Bool
+    return potential_parent in parents(node)
 end
 
 """
@@ -110,6 +118,6 @@ end
 
 Return whether the `potential_child` is a child of `node`.
 """
-function is_child(potential_child::Node, node::Node)
-    return potential_child in node.children
+function is_child(potential_child::Node, node::Node)::Bool
+    return potential_child in children(node)
 end
diff --git a/src/node/type.jl b/src/node/type.jl
index 980962a..9105424 100644
--- a/src/node/type.jl
+++ b/src/node/type.jl
@@ -33,8 +33,8 @@ Any node that transfers data and does no computation.
 `.nodeFusion`:      Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
 `.name`:            The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
 """
-mutable struct DataTaskNode <: Node
-    task::AbstractDataTask
+mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
+    task::TaskType
 
     # use vectors as sets have way too much memory overhead
     parents::Vector{Node}
@@ -73,8 +73,8 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re
 `.nodeFusions`:     A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
 `.device`:          The Device this node has been scheduled on by a [`Scheduler`](@ref).
 """
-mutable struct ComputeTaskNode <: Node
-    task::AbstractComputeTask
+mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
+    task::TaskType
     parents::Vector{Node}
     children::Vector{Node}
     id::Base.UUID
@@ -83,7 +83,7 @@ mutable struct ComputeTaskNode <: Node
     nodeSplit::Union{Operation, Missing}
 
     # for ComputeTasks there can be multiple fusions, unlike the DataTasks
-    nodeFusions::Vector{Operation}
+    nodeFusions::Vector{<:Operation}
 
     # the device this node is assigned to execute on
     device::Union{AbstractDevice, Missing}
diff --git a/src/node/validate.jl b/src/node/validate.jl
index d7ad4dd..6b4fb87 100644
--- a/src/node/validate.jl
+++ b/src/node/validate.jl
@@ -29,7 +29,7 @@ function is_valid_node(graph::DAG, node::Node)
         @assert is_valid(graph, node.nodeSplit)
     end=#
 
-    if !(typeof(node.task) <: FusedComputeTask)
+    if !(typeof(task(node)) <: FusedComputeTask)
         # the remaining checks are only necessary for fused compute tasks
         return true
     end
@@ -37,7 +37,7 @@ function is_valid_node(graph::DAG, node::Node)
     # every child must be in some input of the task
     for child in node.children
         str = Symbol(to_var_name(child.id))
-        @assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
+        @assert (str in task(node).t1_inputs) || (str in task(node).t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(task(node).t1_inputs)\nt2_inputs: $(task(node).t2_inputs)"
     end
 
     return true
diff --git a/src/operation/apply.jl b/src/operation/apply.jl
index dfe9a0b..164b67d 100644
--- a/src/operation/apply.jl
+++ b/src/operation/apply.jl
@@ -132,11 +132,11 @@ function revert_diff!(graph::DAG, diff::Diff)
         insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
     end
 
-    for (node, task) in diff.updatedChildren
+    for (node, t) in diff.updatedChildren
         # node must be fused compute task at this point
-        @assert typeof(node.task) <: FusedComputeTask
+        @assert typeof(task(node)) <: FusedComputeTask
 
-        node.task = task
+        node.task = t
     end
 
     graph.properties -= GraphProperties(diff)
@@ -158,11 +158,11 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
     get_snapshot_diff(graph)
 
     # save children and parents
-    n1Children = children(n1)
-    n3Parents = parents(n3)
+    n1Children = copy(children(n1))
+    n3Parents = copy(parents(n3))
 
-    n1Task = copy(n1.task)
-    n3Task = copy(n3.task)
+    n1Task = copy(task(n1))
+    n3Task = copy(task(n3))
 
     # assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
     n1Inputs = Vector{Symbol}()
@@ -177,7 +177,7 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
     remove_node!(graph, n2)
 
     # get n3's children now so it automatically excludes n2
-    n3Children = children(n3)
+    n3Children = copy(children(n3))
 
     n3Inputs = Vector{Symbol}()
     for child in n3Children
@@ -228,7 +228,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
     get_snapshot_diff(graph)
 
     n1 = nodes[1]
-    n1Children = children(n1)
+    n1Children = copy(children(n1))
 
     n1Parents = Set(n1.parents)
 
@@ -245,7 +245,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
             remove_edge!(graph, child, n)
         end
 
-        for parent in parents(n)
+        for parent in copy(parents(n))
             remove_edge!(graph, n, parent)
 
             # collect all parents
@@ -278,14 +278,17 @@ Split the given node into one node per parent, return the applied difference to
 
 For details see [`NodeSplit`](@ref).
 """
-function node_split!(graph::DAG, n1::Node)
+function node_split!(
+    graph::DAG,
+    n1::Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}},
+) where {TaskType <: AbstractTask}
     @assert is_valid_node_split_input(graph, n1)
 
     # clear snapshot
     get_snapshot_diff(graph)
 
-    n1Parents = parents(n1)
-    n1Children = children(n1)
+    n1Parents = copy(parents(n1))
+    n1Children = copy(children(n1))
 
     for parent in n1Parents
         remove_edge!(graph, n1, parent)
diff --git a/src/operation/clean.jl b/src/operation/clean.jl
index 70ebbd3..2a7110d 100644
--- a/src/operation/clean.jl
+++ b/src/operation/clean.jl
@@ -13,18 +13,18 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
         return nothing
     end
 
-    if length(node.parents) != 1 || length(node.children) != 1
+    if length(parents(node)) != 1 || length(children(node)) != 1
         return nothing
     end
 
-    child_node = first(node.children)
-    parent_node = first(node.parents)
+    child_node = first(children(node))
+    parent_node = first(parents(node))
 
     if !(child_node in graph) || !(parent_node in graph)
         error("Parents/Children that are not in the graph!!!")
     end
 
-    if length(child_node.parents) != 1
+    if length(parents(child_node)) != 1
         return nothing
     end
 
@@ -44,11 +44,11 @@ Find node fusions involving the given compute node. The function pushes the foun
 """
 function find_fusions!(graph::DAG, node::ComputeTaskNode)
     # just find fusions in neighbouring DataTaskNodes
-    for child in node.children
+    for child in children(node)
         find_fusions!(graph, child)
     end
 
-    for parent in node.parents
+    for parent in parents(node)
         find_fusions!(graph, parent)
     end
 
@@ -123,7 +123,10 @@ end
 
 Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way.
 """
-function clean_node!(graph::DAG, node::Node)
+function clean_node!(
+    graph::DAG,
+    node::Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}},
+) where {TaskType <: AbstractTask}
     sort_node!(node)
 
     find_fusions!(graph, node)
diff --git a/src/operation/find.jl b/src/operation/find.jl
index f6d6218..141443b 100644
--- a/src/operation/find.jl
+++ b/src/operation/find.jl
@@ -203,18 +203,18 @@ function generate_operations(graph::DAG)
     # --- find possible node fusions ---
     @threads for node in nodeArray
         if (typeof(node) <: DataTaskNode)
-            if length(node.parents) != 1
+            if length(parents(node)) != 1
                 # data node can only have a single parent
                 continue
             end
-            parent_node = first(node.parents)
+            parent_node = first(parents(node))
 
-            if length(node.children) != 1
+            if length(children(node)) != 1
                 # this node is an entry node or has multiple children which should not be possible
                 continue
             end
-            child_node = first(node.children)
-            if (length(child_node.parents) != 1)
+            child_node = first(children(node))
+            if (length(parents(child_node)) != 1)
                 continue
             end
 
diff --git a/src/operation/get.jl b/src/operation/get.jl
index bb8653a..3294459 100644
--- a/src/operation/get.jl
+++ b/src/operation/get.jl
@@ -14,9 +14,7 @@ function get_operations(graph::DAG)
         generate_operations(graph)
     end
 
-    for node in graph.dirtyNodes
-        clean_node!(graph, node)
-    end
+    clean_node!.(Ref(graph), graph.dirtyNodes)
     empty!(graph.dirtyNodes)
 
     return graph.possibleOperations
diff --git a/src/operation/iterate.jl b/src/operation/iterate.jl
new file mode 100644
index 0000000..a52ea06
--- /dev/null
+++ b/src/operation/iterate.jl
@@ -0,0 +1,39 @@
+import Base.iterate
+
+const _POSSIBLE_OPERATIONS_FIELDS = fieldnames(PossibleOperations)
+
+const _POIteratorStateType =
+    NamedTuple{(:result, :state), Tuple{Union{NodeFusion, NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
+
+@inline function iterate(possibleOperations::PossibleOperations)::Union{Nothing, _POIteratorStateType}
+    for fieldname in _POSSIBLE_OPERATIONS_FIELDS
+        iterator = iterate(getfield(possibleOperations, fieldname))
+        if (!isnothing(iterator))
+            return (result = iterator[1], state = (fieldname, iterator[2]))
+        end
+    end
+
+    return nothing
+end
+
+@inline function iterate(possibleOperations::PossibleOperations, state)::Union{Nothing, _POIteratorStateType}
+    newStateSym = state[1]
+    newStateIt = iterate(getfield(possibleOperations, newStateSym), state[2])
+    if !isnothing(newStateIt)
+        return (result = newStateIt[1], state = (newStateSym, newStateIt[2]))
+    end
+
+    # cycle to next field
+    index = findfirst(x -> x == newStateSym, _POSSIBLE_OPERATIONS_FIELDS) + 1
+
+    while index <= length(_POSSIBLE_OPERATIONS_FIELDS)
+        newStateSym = _POSSIBLE_OPERATIONS_FIELDS[index]
+        newStateIt = iterate(getfield(possibleOperations, newStateSym))
+        if !isnothing(newStateIt)
+            return (result = newStateIt[1], state = (newStateSym, newStateIt[2]))
+        end
+        index += 1
+    end
+
+    return nothing
+end
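+
+# Usage sketch (assuming `ops = get_operations(graph)` for some `graph::DAG`):
+# these methods let all possible operations be traversed in a single loop,
+# field by field in the order of `PossibleOperations`' fields:
+#
+#   for op in ops
+#       # op is a NodeFusion, NodeReduction or NodeSplit
+#   end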
diff --git a/src/operation/print.jl b/src/operation/print.jl
index 61239be..d4a1acd 100644
--- a/src/operation/print.jl
+++ b/src/operation/print.jl
@@ -30,7 +30,7 @@ function show(io::IO, op::NodeReduction)
     print(io, "NR: ")
     print(io, length(op.input))
     print(io, "x")
-    return print(io, op.input[1].task)
+    return print(io, task(op.input[1]))
 end
 
 """
@@ -40,7 +40,7 @@ Print a string representation of the node split to io.
 """
 function show(io::IO, op::NodeSplit)
     print(io, "NS: ")
-    return print(io, op.input.task)
+    return print(io, task(op.input))
 end
 
 """
@@ -50,9 +50,9 @@ Print a string representation of the node fusion to io.
 """
 function show(io::IO, op::NodeFusion)
     print(io, "NF: ")
-    print(io, op.input[1].task)
+    print(io, task(op.input[1]))
     print(io, "->")
-    print(io, op.input[2].task)
+    print(io, task(op.input[2]))
     print(io, "->")
-    return print(io, op.input[3].task)
+    return print(io, task(op.input[3]))
 end
diff --git a/src/operation/type.jl b/src/operation/type.jl
index edc67c3..606b101 100644
--- a/src/operation/type.jl
+++ b/src/operation/type.jl
@@ -40,8 +40,9 @@ A chain of (n1, n2, n3) can be fused if:
 
 See also: [`can_fuse`](@ref)
 """
-struct NodeFusion <: Operation
-    input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
+struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
+       Operation
+    input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
 end
 
 """
@@ -49,8 +50,12 @@ end
 
 The applied version of the [`NodeFusion`](@ref).
 """
-struct AppliedNodeFusion <: AppliedOperation
-    operation::NodeFusion
+struct AppliedNodeFusion{
+    TaskType1 <: AbstractComputeTask,
+    TaskType2 <: AbstractDataTask,
+    TaskType3 <: AbstractComputeTask,
+} <: AppliedOperation
+    operation::NodeFusion{TaskType1, TaskType2, TaskType3}
     diff::Diff
 end
 
@@ -73,8 +78,8 @@ A vector of nodes can be reduced if:
 
 See also: [`can_reduce`](@ref)
 """
-struct NodeReduction <: Operation
-    input::Vector{Node}
+struct NodeReduction{NodeType <: Node} <: Operation
+    input::Vector{NodeType}
 end
 
 """
@@ -82,8 +87,8 @@ end
 
 The applied version of the [`NodeReduction`](@ref).
 """
-struct AppliedNodeReduction <: AppliedOperation
-    operation::NodeReduction
+struct AppliedNodeReduction{NodeType <: Node} <: AppliedOperation
+    operation::NodeReduction{NodeType}
     diff::Diff
 end
 
@@ -102,8 +107,8 @@ A node can be split if:
 
 See also: [`can_split`](@ref)
 """
-struct NodeSplit <: Operation
-    input::Node
+struct NodeSplit{NodeType <: Node} <: Operation
+    input::NodeType
 end
 
 """
@@ -111,7 +116,7 @@ end
 
 The applied version of the [`NodeSplit`](@ref).
 """
-struct AppliedNodeSplit <: AppliedOperation
-    operation::NodeSplit
+struct AppliedNodeSplit{NodeType <: Node} <: AppliedOperation
+    operation::NodeSplit{NodeType}
     diff::Diff
 end
diff --git a/src/operation/utility.jl b/src/operation/utility.jl
index b7f874a..0ccafb3 100644
--- a/src/operation/utility.jl
+++ b/src/operation/utility.jl
@@ -61,7 +61,7 @@ function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
         return false
     end
 
-    if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
+    if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
         return false
     end
 
@@ -74,12 +74,15 @@ end
 Return whether the given two nodes can be reduced. See [`NodeReduction`](@ref) for the requirements.
 """
 function can_reduce(n1::Node, n2::Node)
-    if (n1.task != n2.task)
-        return false
-    end
+    return false
+end
 
-    n1_length = length(n1.children)
-    n2_length = length(n2.children)
+function can_reduce(
+    n1::NodeType,
+    n2::NodeType,
+) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
+    n1_length = length(children(n1))
+    n2_length = length(children(n2))
 
     if (n1_length != n2_length)
         return false
@@ -88,19 +91,19 @@ function can_reduce(n1::Node, n2::Node)
     # this seems to be the most common case so do this first
     # doing it manually is a lot faster than using the sets for a general solution
     if (n1_length == 2)
-        if (n1.children[1] != n2.children[1])
-            if (n1.children[1] != n2.children[2])
+        if (children(n1)[1] != children(n2)[1])
+            if (children(n1)[1] != children(n2)[2])
                 return false
             end
             # 1_1 == 2_2
-            if (n1.children[2] != n2.children[1])
+            if (children(n1)[2] != children(n2)[1])
                 return false
             end
             return true
         end
 
         # 1_1 == 2_1
-        if (n1.children[2] != n2.children[2])
+        if (children(n1)[2] != children(n2)[2])
             return false
         end
         return true
@@ -108,11 +111,11 @@ function can_reduce(n1::Node, n2::Node)
 
     # this is simple
     if (n1_length == 1)
-        return n1.children[1] == n2.children[1]
+        return children(n1)[1] == children(n2)[1]
     end
 
     # this takes a long time
-    return Set(n1.children) == Set(n2.children)
+    return Set(children(n1)) == Set(children(n2))
 end
 
 """
@@ -138,7 +141,14 @@ end
 
 Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
 """
-function ==(op1::NodeFusion, op2::NodeFusion)
+function ==(
+    op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
+    op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
+) where {
+    ComputeTaskType1 <: AbstractComputeTask,
+    DataTaskType <: AbstractDataTask,
+    ComputeTaskType2 <: AbstractComputeTask,
+}
     # there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
     return op1.input[2] == op2.input[2]
 end
diff --git a/src/operation/validate.jl b/src/operation/validate.jl
index 0fe3218..ede35f3 100644
--- a/src/operation/validate.jl
+++ b/src/operation/validate.jl
@@ -54,9 +54,9 @@ function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
         @assert is_valid(graph, n)
     end
 
-    t = typeof(nodes[1].task)
+    t = typeof(task(nodes[1]))
     for n in nodes
-        if typeof(n.task) != t
+        if typeof(task(n)) != t
             throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
         end
 
@@ -115,7 +115,7 @@ Intended for use with `@assert` or `@test`.
 """
 function is_valid(graph::DAG, nr::NodeReduction)
     @assert is_valid_node_reduction_input(graph, nr.input)
-    @assert nr in graph.possibleOperations.nodeReductions "NodeReduction is not part of the graph's possible operations!"
+    #@assert nr in graph.possibleOperations.nodeReductions "NodeReduction is not part of the graph's possible operations!"
     return true
 end
 
@@ -128,7 +128,7 @@ Intended for use with `@assert` or `@test`.
 """
 function is_valid(graph::DAG, ns::NodeSplit)
     @assert is_valid_node_split_input(graph, ns.input)
-    @assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!"
+    #@assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!"
     return true
 end
 
@@ -141,6 +141,6 @@ Intended for use with `@assert` or `@test`.
 """
 function is_valid(graph::DAG, nf::NodeFusion)
     @assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
-    @assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
+    #@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
     return true
 end
diff --git a/src/optimization/greedy.jl b/src/optimization/greedy.jl
index 07808fa..20ef336 100644
--- a/src/optimization/greedy.jl
+++ b/src/optimization/greedy.jl
@@ -2,7 +2,72 @@
     GreedyOptimizer
 
 An implementation of the greedy optimization algorithm, simply choosing the best next option evaluated with the given estimator.
+
+The fixpoint is reached when any leftover operation would increase the graph's total cost according to the given estimator.
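+
+A sketch of typical usage, assuming a parsed `graph::DAG`:
+
+```julia
+optimize_to_fixpoint!(GreedyOptimizer(GlobalMetricEstimator()), graph)
+```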
 """
-struct GreedyOptimizer
-    estimator::AbstractEstimator
+struct GreedyOptimizer{EstimatorType <: AbstractEstimator} <: AbstractOptimizer
+    estimator::EstimatorType
+end
+
+function optimize_step!(optimizer::GreedyOptimizer, graph::DAG)
+    # generate all options
+    operations = get_operations(graph)
+    if isempty(operations)
+        return false
+    end
+
+    result = nothing
+
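+    # find the cheapest (most cost-reducing) operation; typemax serves as the initial accumulator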
+    lowestCost = reduce(
+        (acc, op) -> begin
+            op_cost = operation_effect(optimizer.estimator, graph, op)
+            if op_cost < acc
+                result = op
+                return op_cost
+            end
+            return acc
+        end,
+        operations;
+        init = typemax(cost_type(optimizer.estimator)),
+    )
+
+    if lowestCost > zero(cost_type(optimizer.estimator))
+        return false
+    end
+
+    push_operation!(graph, result)
+
+    return true
+end
+
+function fixpoint_reached(optimizer::GreedyOptimizer, graph::DAG)
+    # generate all options
+    operations = get_operations(graph)
+    if isempty(operations)
+        return true
+    end
+
+    lowestCost = reduce(
+        (acc, op) -> begin
+            op_cost = operation_effect(optimizer.estimator, graph, op)
+            if op_cost < acc
+                return op_cost
+            end
+            return acc
+        end,
+        operations;
+        init = typemax(cost_type(optimizer.estimator)),
+    )
+
+    if lowestCost > zero(cost_type(optimizer.estimator))
+        return true
+    end
+
+    return false
+end
+
+function optimize_to_fixpoint!(optimizer::GreedyOptimizer, graph::DAG)
+    while optimize_step!(optimizer, graph)
+    end
+    return nothing
 end
diff --git a/src/optimization/interface.jl b/src/optimization/interface.jl
new file mode 100644
index 0000000..0d5f87c
--- /dev/null
+++ b/src/optimization/interface.jl
@@ -0,0 +1,60 @@
+
+"""
+    AbstractOptimizer
+
+Abstract base type for optimizer implementations.
+"""
+abstract type AbstractOptimizer end
+
+"""
+    optimize_step!(optimizer::AbstractOptimizer, graph::DAG)
+
+Interface function that must be implemented by any implementation of [`AbstractOptimizer`](@ref). Returns `true` if an operation has been applied, `false` if not, usually because a fixpoint of the algorithm has been reached.
+
+It should perform one minimal logical step on the given [`DAG`](@ref), muting the graph and, if necessary, the optimizer's state.
+"""
+function optimize_step! end
+
+"""
+    optimize!(optimizer::AbstractOptimizer, graph::DAG, n::Int)
+
+Call the given optimizer `n` times, muting the graph. Returns `true` if the requested number of operations was applied, `false` if not, usually because a fixpoint of the algorithm has been reached.
+
+If a more efficient method exists, this can be overloaded for a specific optimizer.
+"""
+function optimize!(optimizer::AbstractOptimizer, graph::DAG, n::Int)
+    for i in 1:n
+        if !optimize_step!(optimizer, graph)
+            return false
+        end
+    end
+    return true
+end
+
+"""
+    fixpoint_reached(optimizer::AbstractOptimizer, graph::DAG)
+
+Interface function that can be implemented by optimization algorithms that can reach a fixpoint, returning as a `Bool` whether it has been reached. The default implementation returns `false`.
+
+See also: [`optimize_to_fixpoint!`](@ref)
+"""
+function fixpoint_reached(optimizer::AbstractOptimizer, graph::DAG)
+    return false
+end
+
+"""
+    optimize_to_fixpoint!(optimizer::AbstractOptimizer, graph::DAG)
+
+Interface function that can be implemented by optimization algorithms that can reach a fixpoint. The algorithm will be run until that fixpoint is reached, at which point [`fixpoint_reached`](@ref) should return `true`.
+
+A usual implementation might look like this:
+```julia
+    function optimize_to_fixpoint!(optimizer::MyOptimizer, graph::DAG)
+        while !fixpoint_reached(optimizer, graph)
+            optimize_step!(optimizer, graph)
+        end
+        return nothing
+    end
+```
+"""
+function optimize_to_fixpoint! end
diff --git a/src/optimization/random_walk.jl b/src/optimization/random_walk.jl
new file mode 100644
index 0000000..e43507e
--- /dev/null
+++ b/src/optimization/random_walk.jl
@@ -0,0 +1,49 @@
+using Random
+
+"""
+    RandomWalkOptimizer
+
+An optimizer that randomly pushes or pops operations. It doesn't optimize in any direction and is useful mainly for testing purposes.
+
+This algorithm never reaches a fixpoint, so it does not implement [`optimize_to_fixpoint!`](@ref).
+"""
+struct RandomWalkOptimizer <: AbstractOptimizer
+    rng::AbstractRNG
+end
+
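+# Usage sketch, assuming an RNG from the Random standard library:
+#
+#   walker = RandomWalkOptimizer(MersenneTwister(42))
+#   optimize!(walker, graph, 100)
+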
+function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG)
+    operations = get_operations(graph)
+
+    if sum(length(operations)) == 0 && length(graph.appliedOperations) + length(graph.operationsToApply) == 0
+        # in case there are zero operations possible at all on the graph
+        return false
+    end
+
+    r = optimizer.rng
+    # try until something was applied or popped
+    while true
+        # choose push or pop
+        if rand(r, Bool)
+            # push
+
+            # choose one of fuse/reduce/split
+            option = rand(r, 1:3)
+            if option == 1 && !isempty(operations.nodeFusions)
+                push_operation!(graph, rand(r, collect(operations.nodeFusions)))
+                return true
+            elseif option == 2 && !isempty(operations.nodeReductions)
+                push_operation!(graph, rand(r, collect(operations.nodeReductions)))
+                return true
+            elseif option == 3 && !isempty(operations.nodeSplits)
+                push_operation!(graph, rand(r, collect(operations.nodeSplits)))
+                return true
+            end
+        else
+            # pop
+            if (can_pop(graph))
+                pop_operation!(graph)
+                return true
+            end
+        end
+    end
+end
diff --git a/src/optimization/reduce.jl b/src/optimization/reduce.jl
new file mode 100644
index 0000000..625874a
--- /dev/null
+++ b/src/optimization/reduce.jl
@@ -0,0 +1,30 @@
+"""
+    ReductionOptimizer
+
+An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint!`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
+"""
+struct ReductionOptimizer <: AbstractOptimizer end
+
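+# Usage sketch: `optimize_to_fixpoint!(ReductionOptimizer(), graph)` reduces
+# the graph until no NodeReduction remains (as used in examples/ab5.jl).
+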
+function optimize_step!(optimizer::ReductionOptimizer, graph::DAG)
+    # generate all options
+    operations = get_operations(graph)
+    if fixpoint_reached(optimizer, graph)
+        return false
+    end
+
+    push_operation!(graph, first(operations.nodeReductions))
+
+    return true
+end
+
+function fixpoint_reached(optimizer::ReductionOptimizer, graph::DAG)
+    operations = get_operations(graph)
+    return isempty(operations.nodeReductions)
+end
+
+function optimize_to_fixpoint!(optimizer::ReductionOptimizer, graph::DAG)
+    while !fixpoint_reached(optimizer, graph)
+        optimize_step!(optimizer, graph)
+    end
+    return nothing
+end
diff --git a/src/properties/create.jl b/src/properties/create.jl
index 6db46f4..218723d 100644
--- a/src/properties/create.jl
+++ b/src/properties/create.jl
@@ -4,14 +4,18 @@
 Create an empty [`GraphProperties`](@ref) object.
 """
 function GraphProperties()
-    return (
-        data = 0.0,
-        computeEffort = 0.0,
-        computeIntensity = 0.0,
-        cost = 0.0,
-        noNodes = 0,
-        noEdges = 0,
-    )::GraphProperties
+    return (data = 0.0, computeEffort = 0.0, computeIntensity = 0.0, noNodes = 0, noEdges = 0)::GraphProperties
+end
+
+@inline function _props(
+    node::DataTaskNode{TaskType},
+)::Tuple{Float64, Float64, Int64} where {TaskType <: AbstractDataTask}
+    return (data(task(node)) * length(parents(node)), 0.0, length(parents(node)))
+end
+@inline function _props(
+    node::ComputeTaskNode{TaskType},
+)::Tuple{Float64, Float64, Int64} where {TaskType <: AbstractComputeTask}
+    return (0.0, compute_effort(task(node)), length(parents(node)))
 end
 
 """
@@ -27,16 +31,16 @@ function GraphProperties(graph::DAG)
     ce = 0.0
     ed = 0
     for node in graph.nodes
-        d += data(node.task) * length(node.parents)
-        ce += compute_effort(node.task)
-        ed += length(node.parents)
+        props = _props(node)
+        d += props[1]
+        ce += props[2]
+        ed += props[3]
     end
 
     return (
         data = d,
         computeEffort = ce,
         computeIntensity = (d == 0) ? 0.0 : ce / d,
-        cost = 0.0, # TODO
         noNodes = length(graph.nodes),
         noEdges = ed,
     )::GraphProperties
@@ -50,23 +54,18 @@ The graph's properties after applying the [`Diff`](@ref) will be `get_properties
 For reverting a diff, it's `get_properties(graph) - GraphProperties(diff)`.
 """
 function GraphProperties(diff::Diff)
-    d = 0.0
-    ce = 0.0
-    c = 0.0 # TODO
-
     ce =
-        reduce(+, compute_effort(n.task) for n in diff.addedNodes; init = 0.0) -
-        reduce(+, compute_effort(n.task) for n in diff.removedNodes; init = 0.0)
+        reduce(+, compute_effort(task(n)) for n in diff.addedNodes; init = 0.0) -
+        reduce(+, compute_effort(task(n)) for n in diff.removedNodes; init = 0.0)
 
     d =
-        reduce(+, data(e) for e in diff.addedEdges; init = 0.0) -
-        reduce(+, data(e) for e in diff.removedEdges; init = 0.0)
+        reduce(+, data(task(n)) for n in diff.addedNodes; init = 0.0) -
+        reduce(+, data(task(n)) for n in diff.removedNodes; init = 0.0)
 
     return (
         data = d,
         computeEffort = ce,
         computeIntensity = (d == 0) ? 0.0 : ce / d,
-        cost = c,
         noNodes = length(diff.addedNodes) - length(diff.removedNodes),
         noEdges = length(diff.addedEdges) - length(diff.removedEdges),
     )::GraphProperties
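
The `_props` helpers above split the per-node contribution by node type through dispatch, keeping the accumulation loop branch-free. A self-contained sketch of the same pattern with toy types (not the package's own):

```julia
abstract type MyNode end
struct MyDataNode <: MyNode
    bytes::Float64
    nparents::Int
end
struct MyComputeNode <: MyNode
    effort::Float64
    nparents::Int
end

# (data, computeEffort, edge count) contributed by a single node
_props(n::MyDataNode) = (n.bytes * n.nparents, 0.0, n.nparents)
_props(n::MyComputeNode) = (0.0, n.effort, n.nparents)

function totals(nodes::Vector{MyNode})
    d = 0.0
    ce = 0.0
    ed = 0
    for n in nodes
        p = _props(n)
        d += p[1]
        ce += p[2]
        ed += p[3]
    end
    return (data = d, computeEffort = ce, noEdges = ed)
end

totals(MyNode[MyDataNode(10.0, 2), MyComputeNode(4.0, 1)])  # (data = 20.0, computeEffort = 4.0, noEdges = 3)
```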
diff --git a/src/properties/type.jl b/src/properties/type.jl
index 084486c..3c8c6b8 100644
--- a/src/properties/type.jl
+++ b/src/properties/type.jl
@@ -7,11 +7,10 @@ Representation of a [`DAG`](@ref)'s properties.
 `.data`: The total data transfer.\\
 `.computeEffort`: The total compute effort.\\
 `.computeIntensity`: The compute intensity; always equals `.computeEffort / .data`, or `0.0` if `.data` is zero.\\
-`.cost`: The estimated cost.\\
 `.noNodes`: Number of [`Node`](@ref)s.\\
 `.noEdges`: Number of [`Edge`](@ref)s.
 """
 const GraphProperties = NamedTuple{
-    (:data, :computeEffort, :computeIntensity, :cost, :noNodes, :noEdges),
-    Tuple{Float64, Float64, Float64, Float64, Int, Int},
+    (:data, :computeEffort, :computeIntensity, :noNodes, :noEdges),
+    Tuple{Float64, Float64, Float64, Int, Int},
 }
diff --git a/src/properties/utility.jl b/src/properties/utility.jl
index 3aa9def..8a08ee9 100644
--- a/src/properties/utility.jl
+++ b/src/properties/utility.jl
@@ -13,7 +13,6 @@ function -(prop1::GraphProperties, prop2::GraphProperties)
         else
             (prop1.computeEffort - prop2.computeEffort) / (prop1.data - prop2.data)
         end,
-        cost = prop1.cost - prop2.cost,
         noNodes = prop1.noNodes - prop2.noNodes,
         noEdges = prop1.noEdges - prop2.noEdges,
     )::GraphProperties
@@ -34,7 +33,6 @@ function +(prop1::GraphProperties, prop2::GraphProperties)
         else
             (prop1.computeEffort + prop2.computeEffort) / (prop1.data + prop2.data)
         end,
-        cost = prop1.cost + prop2.cost,
         noNodes = prop1.noNodes + prop2.noNodes,
         noEdges = prop1.noEdges + prop2.noEdges,
     )::GraphProperties
@@ -50,7 +48,6 @@ function -(prop::GraphProperties)
         data = -prop.data,
         computeEffort = -prop.computeEffort,
         computeIntensity = prop.computeIntensity,   # no negation here!
-        cost = -prop.cost,
         noNodes = -prop.noNodes,
         noEdges = -prop.noEdges,
     )::GraphProperties
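
With `cost` removed, the operators combine the remaining five fields; `computeIntensity` is recomputed from the combined `computeEffort` and `data` rather than added, and unary minus leaves it untouched. A quick check using the same literals as the property unit tests:

```julia
a = (data = 5.0, computeEffort = 6.0, computeIntensity = 6.0 / 5.0, noNodes = 2, noEdges = 3)::GraphProperties
b = (data = 7.0, computeEffort = 3.0, computeIntensity = 3.0 / 7.0, noNodes = -3, noEdges = 2)::GraphProperties

s = a + b
@assert s.data == 12.0 && s.computeEffort == 9.0
@assert s.computeIntensity == 9.0 / 12.0             # recomputed, not summed
@assert (-a).computeIntensity == a.computeIntensity  # "no negation here!"
```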
diff --git a/src/scheduler/greedy.jl b/src/scheduler/greedy.jl
index 7ab77e9..a43a933 100644
--- a/src/scheduler/greedy.jl
+++ b/src/scheduler/greedy.jl
@@ -32,14 +32,14 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
         if (isa(node, ComputeTaskNode))
             lowestDevice = peek(deviceAccCost)[1]
             node.device = lowestDevice
-            deviceAccCost[lowestDevice] = compute_effort(node.task)
+            deviceAccCost[lowestDevice] += compute_effort(task(node))
         end
 
         push!(schedule, node)
-        for parent in node.parents
+        for parent in parents(node)
             # reduce the priority of all parents by one
             if (!haskey(nodeQueue, parent))
-                enqueue!(nodeQueue, parent => length(parent.children) - 1)
+                enqueue!(nodeQueue, parent => length(children(parent)) - 1)
             else
                 nodeQueue[parent] = nodeQueue[parent] - 1
             end
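
The scheduler's device choice is a standard greedy pattern over a `DataStructures.PriorityQueue`: `peek` yields the entry with the lowest priority, i.e. the device with the lowest accumulated cost, which then absorbs the node's compute effort. A toy, self-contained version of that pattern (the device names are placeholders):

```julia
using DataStructures

deviceAccCost = PriorityQueue{Symbol, Float64}()
enqueue!(deviceAccCost, :cpu => 0.0)
enqueue!(deviceAccCost, :gpu => 0.0)

for effort in (4.0, 2.0, 3.0)
    dev = peek(deviceAccCost)[1]   # device with the lowest accumulated cost
    deviceAccCost[dev] += effort   # accumulate the newly assigned effort
end

peek(deviceAccCost)  # cheapest device after the assignments
```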
diff --git a/src/task/compute.jl b/src/task/compute.jl
index beb4e52..274cc3e 100644
--- a/src/task/compute.jl
+++ b/src/task/compute.jl
@@ -41,16 +41,16 @@ end
 Generate and return code for a given [`ComputeTaskNode`](@ref).
 """
 function get_expression(node::ComputeTaskNode)
-    @assert length(node.children) <= children(node.task) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
+    @assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(children(node))) versus task has $(children(task(node)))\nNode's children: $(getfield.(children(node), :children))"
     @assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
 
     inExprs = Vector()
-    for id in getfield.(node.children, :id)
+    for id in getfield.(children(node), :id)
         push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
     end
     outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
 
-    return get_expression(node.task, node.device, inExprs, outExpr)
+    return get_expression(task(node), node.device, inExprs, outExpr)
 end
 
 """
@@ -59,11 +59,11 @@ end
 Generate and return code for a given [`DataTaskNode`](@ref).
 """
 function get_expression(node::DataTaskNode)
-    @assert length(node.children) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
+    @assert length(children(node)) == 1 "Trying to call get_expression on a data task node that has $(length(children(node))) children instead of 1"
 
     # TODO: dispatch to device implementations generating the copy commands
 
-    child = node.children[1]
+    child = children(node)[1]
     inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
     outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
     dataTransportExp = Meta.parse("$outExpr = $inExpr")
@@ -79,7 +79,7 @@ Generate and return code for the initial input reading expression for [`DataTask
 See also: [`get_entry_nodes`](@ref)
 """
 function get_init_expression(node::DataTaskNode, device::AbstractDevice)
-    @assert isempty(node.children) "Trying to call get_init_expression on a data task node that is not an entry node."
+    @assert isempty(children(node)) "Trying to call get_init_expression on a data task node that is not an entry node."
 
     inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
     outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
diff --git a/src/task/create.jl b/src/task/create.jl
index 81dc564..147bfc1 100644
--- a/src/task/create.jl
+++ b/src/task/create.jl
@@ -17,15 +17,16 @@ copy(t::AbstractComputeTask) = typeof(t)()
 
 Return a copy of the given [`FusedComputeTask`](@ref).
 """
-function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
-    return FusedComputeTask{T1, T2}(
-        copy(t.first_task),
-        copy(t.second_task),
-        copy(t.t1_inputs),
-        t.t1_output,
-        copy(t.t2_inputs),
-    )
+function copy(t::FusedComputeTask)
+    return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
 end
 
-FusedComputeTask{T1, T2}(t1_inputs::Vector{String}, t1_output::String, t2_inputs::Vector{String}) where {T1, T2} =
-    FusedComputeTask{T1, T2}(T1(), T2(), t1_inputs, t1_output, t2_inputs)
+function FusedComputeTask(
+    T1::Type{<:AbstractComputeTask},
+    T2::Type{<:AbstractComputeTask},
+    t1_inputs::Vector{String},
+    t1_output::String,
+    t2_inputs::Vector{String},
+)
+    return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
+end
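
Without the type parameters, the component types live in the field values; `get_types` (see `src/task/properties.jl` below) recovers them at runtime. A sketch, assuming ABC-model task names such as `ComputeTaskS1` and `ComputeTaskV` and the `Symbol`-typed name fields from `src/task/type.jl`:

```julia
ft = FusedComputeTask(ComputeTaskS1(), ComputeTaskV(), Symbol[], Symbol("out"), Symbol[])

get_types(ft)  # (ComputeTaskS1, ComputeTaskV)
copy(ft)       # copies the component tasks and the input-name vectors
```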
diff --git a/src/task/properties.jl b/src/task/properties.jl
index c39d4ca..68f7cd6 100644
--- a/src/task/properties.jl
+++ b/src/task/properties.jl
@@ -30,7 +30,7 @@ compute(t::AbstractDataTask; data...) = data
 
 Fallback implementation of the compute effort of a task, throwing an error.
 """
-function compute_effort(t::AbstractTask)
+function compute_effort(t::AbstractTask)::Float64
     # no default; concrete task types must implement this
     return error("Need to implement compute_effort()")
 end
@@ -40,7 +40,7 @@ end
 
 Fallback implementation of the data of a task, throwing an error.
 """
-function data(t::AbstractTask)
+function data(t::AbstractTask)::Float64
     return error("Need to implement data()")
 end
 
@@ -49,28 +49,28 @@ end
 
 Return the compute effort of a data task, always zero, regardless of the specific task.
 """
-compute_effort(t::AbstractDataTask) = 0.0
+compute_effort(t::AbstractDataTask)::Float64 = 0.0
 
 """
     data(t::AbstractDataTask)
 
 Return the data of a data task. Given by the task's `.data` field.
 """
-data(t::AbstractDataTask) = getfield(t, :data)
+data(t::AbstractDataTask)::Float64 = getfield(t, :data)
 
 """
     data(t::AbstractComputeTask)
 
 Return the data of a compute task, always zero, regardless of the specific task.
 """
-data(t::AbstractComputeTask) = 0.0
+data(t::AbstractComputeTask)::Float64 = 0.0
 
 """
     compute_effort(t::FusedComputeTask)
 
 Return the compute effort of a fused compute task.
 """
-function compute_effort(t::FusedComputeTask)
+function compute_effort(t::FusedComputeTask)::Float64
     return compute_effort(t.first_task) + compute_effort(t.second_task)
 end
 
@@ -79,4 +79,4 @@ end
 
 Return a tuple of the fused compute task's components' types.
 """
-get_types(::FusedComputeTask{T1, T2}) where {T1, T2} = (T1, T2)
+get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))
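
The new `::Float64` annotations pin down the task interface: every concrete task type answers `data` and `compute_effort` with a `Float64`. A minimal custom compute task honoring that contract (a toy type, not part of the package; the explicit `import` assumes the names are extended from outside the module):

```julia
import MetagraphOptimization: AbstractComputeTask, compute_effort, data

struct MyMulTask <: AbstractComputeTask end

compute_effort(::MyMulTask)::Float64 = 1.0  # overrides the erroring fallback
# data(::AbstractComputeTask) already returns 0.0 for any compute task

compute_effort(MyMulTask())  # 1.0
data(MyMulTask())            # 0.0
```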
diff --git a/src/task/type.jl b/src/task/type.jl
index 8f9dfe1..0f5bf22 100644
--- a/src/task/type.jl
+++ b/src/task/type.jl
@@ -26,9 +26,9 @@ A fused compute task made up of the computation of first `T1` and then `T2`.
 
 Also see: [`get_types`](@ref).
 """
-struct FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
-    first_task::T1
-    second_task::T2
+struct FusedComputeTask <: AbstractComputeTask
+    first_task::AbstractComputeTask
+    second_task::AbstractComputeTask
     # the names of the inputs for T1
     t1_inputs::Vector{Symbol}
     # output name of T1
diff --git a/src/trie.jl b/src/trie.jl
index 7637e19..b3babca 100644
--- a/src/trie.jl
+++ b/src/trie.jl
@@ -3,9 +3,9 @@
 
 Helper struct for [`NodeTrie`](@ref). After the trie's first level, every level stores in `value` the nodes whose children are exhausted at that depth, and in `children` the following trie levels, keyed by the UUID of the nodes' child at that depth.
 """
-mutable struct NodeIdTrie
-    value::Vector{Node}
-    children::Dict{UUID, NodeIdTrie}
+mutable struct NodeIdTrie{NodeType <: Node}
+    value::Vector{NodeType}
+    children::Dict{UUID, NodeIdTrie{NodeType}}
 end
 
 """
@@ -35,8 +35,8 @@ end
 
 Constructor for an empty [`NodeIdTrie`](@ref).
 """
-function NodeIdTrie()
-    return NodeIdTrie(Vector{Node}(), Dict{UUID, NodeIdTrie}())
+function NodeIdTrie{NodeType}() where {NodeType <: Node}
+    return NodeIdTrie(Vector{NodeType}(), Dict{UUID, NodeIdTrie{NodeType}}())
 end
 
 """
@@ -44,8 +44,12 @@ end
 
 Insert the given node into the trie. The depth is used to iterate through the trie layers, while the function calls itself recursively until it has run through all children of the node.
 """
-function insert_helper!(trie::NodeIdTrie, node::Node, depth::Int)
-    if (length(node.children) == depth)
+function insert_helper!(
+    trie::NodeIdTrie{NodeType},
+    node::NodeType,
+    depth::Int,
+) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
+    if (length(children(node)) == depth)
         push!(trie.value, node)
         return nothing
     end
@@ -54,7 +58,7 @@ function insert_helper!(trie::NodeIdTrie, node::Node, depth::Int)
     id = node.children[depth].id
 
     if (!haskey(trie.children, id))
-        trie.children[id] = NodeIdTrie()
+        trie.children[id] = NodeIdTrie{NodeType}()
     end
     return insert_helper!(trie.children[id], node, depth)
 end
@@ -64,12 +68,14 @@ end
 
 Insert the given node into the trie. It's sorted by its type in the first layer, then by its children in the following layers.
 """
-function insert!(trie::NodeTrie, node::Node)
-    t = typeof(node.task)
-    if (!haskey(trie.children, t))
-        trie.children[t] = NodeIdTrie()
+function insert!(
+    trie::NodeTrie,
+    node::NodeType,
+) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
+    if (!haskey(trie.children, NodeType))
+        trie.children[NodeType] = NodeIdTrie{NodeType}()
     end
-    return insert_helper!(trie.children[typeof(node.task)], node, 0)
+    return insert_helper!(trie.children[NodeType], node, 0)
 end
 
 """
diff --git a/src/utility.jl b/src/utility.jl
index 3760690..38328d1 100644
--- a/src/utility.jl
+++ b/src/utility.jl
@@ -36,8 +36,8 @@ Sort the nodes' parents and children vectors. The vectors are mostly very short
 Sorted nodes are required to make the finding of [`NodeReduction`](@ref)s a lot faster using the [`NodeTrie`](@ref) data structure.
 """
 function sort_node!(node::Node)
-    sort!(node.children, lt = lt_nodes)
-    return sort!(node.parents, lt = lt_nodes)
+    sort!(children(node), lt = lt_nodes)
+    return sort!(parents(node), lt = lt_nodes)
 end
 
 """
diff --git a/test/runtests.jl b/test/runtests.jl
index 7244983..3b8d226 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,6 +11,7 @@ using Test
     include("node_reduction.jl")
     include("unit_tests_graph.jl")
     include("unit_tests_execution.jl")
+    include("unit_tests_optimization.jl")
 
     include("known_graphs.jl")
 end
diff --git a/test/unit_tests_estimator.jl b/test/unit_tests_estimator.jl
index d1b50a8..a911c7b 100644
--- a/test/unit_tests_estimator.jl
+++ b/test/unit_tests_estimator.jl
@@ -38,9 +38,6 @@ function test_op_specific(estimator, graph, ns::NodeSplit)
 end
 
 function test_op(estimator, graph, op)
-    #=
-    See issue #16
-
     estimate_before = graph_cost(estimator, graph)
 
     estimate = operation_effect(estimator, graph, op)
@@ -52,7 +49,6 @@ function test_op(estimator, graph, op)
     @test isapprox((estimate_before + estimate).data, estimate_after_apply.data)
     @test isapprox((estimate_before + estimate).computeEffort, estimate_after_apply.computeEffort)
     @test isapprox((estimate_before + estimate).computeIntensity, estimate_after_apply.computeIntensity)
-    =#
 
     test_op_specific(estimator, graph, op)
     return nothing
@@ -81,9 +77,6 @@ end
             nrs = copy(ops.nodeReductions)
             nss = copy(ops.nodeSplits)
 
-            println(
-                "Testing $(length(ops.nodeFusions))xNF, $(length(ops.nodeReductions))xNR, $(length(ops.nodeSplits))xNS",
-            )
             for nf in nfs
                 test_op(estimator, graph, nf)
             end
diff --git a/test/unit_tests_execution.jl b/test/unit_tests_execution.jl
index 00c3243..bce27db 100644
--- a/test/unit_tests_execution.jl
+++ b/test/unit_tests_execution.jl
@@ -3,10 +3,10 @@ import MetagraphOptimization.interaction_result
 
 using QEDbase
 using AccurateArithmetic
-
-include("../examples/profiling_utilities.jl")
+using Random
 
 const RTOL = sqrt(eps(Float64))
+RNG = Random.default_rng()
 
 function check_particle_reverse_moment(p1::SFourMomentum, p2::SFourMomentum)
     @test isapprox(abs(p1.E), abs(p2.E))
@@ -83,7 +83,7 @@ end
     @testset "AB->AB after random walk" begin
         for i in 1:200
             graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
-            random_walk!(graph, 50)
+            optimize!(RandomWalkOptimizer(RNG), graph, 50)
 
             @test is_valid(graph)
 
@@ -115,7 +115,7 @@ end
     @testset "AB->ABBB after random walk" begin
         for i in 1:50
             graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
-            random_walk!(graph, 100)
+            optimize!(RandomWalkOptimizer(RNG), graph, 100)
             @test is_valid(graph)
 
             @test isapprox(execute(graph, process_2_4, machine, particles_2_4), expected_result; rtol = RTOL)
diff --git a/test/unit_tests_graph.jl b/test/unit_tests_graph.jl
index bab3155..bad26b7 100644
--- a/test/unit_tests_graph.jl
+++ b/test/unit_tests_graph.jl
@@ -135,6 +135,12 @@ import MetagraphOptimization.partners
     @test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
     @test length(graph.dirtyNodes) == 0
 
+    i = 0
+    for op in operations
+        i += 1
+    end
+    @test i == 10
+
     @test operations == get_operations(graph)
     nf = first(operations.nodeFusions)
 
diff --git a/test/unit_tests_optimization.jl b/test/unit_tests_optimization.jl
new file mode 100644
index 0000000..fa571df
--- /dev/null
+++ b/test/unit_tests_optimization.jl
@@ -0,0 +1,42 @@
+using Random
+
+RNG = Random.default_rng()
+
+@testset "Unit Tests Optimization" begin
+    graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
+
+    # create the optimizers
+    FIXPOINT_OPTIMIZERS = [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer()]
+    NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
+
+    @testset "Optimizer $optimizer" for optimizer in vcat(NO_FIXPOINT_OPTIMIZERS, FIXPOINT_OPTIMIZERS)
+        @test operation_stack_length(graph) == 0
+        @test optimize_step!(optimizer, graph)
+
+        @test !fixpoint_reached(optimizer, graph)
+        @test operation_stack_length(graph) == 1
+
+        @test optimize!(optimizer, graph, 10)
+
+        @test !fixpoint_reached(optimizer, graph)
+
+        reset_graph!(graph)
+    end
+
+    @testset "Fixpoint optimizer $optimizer" for optimizer in FIXPOINT_OPTIMIZERS
+        @test operation_stack_length(graph) == 0
+
+        optimize_to_fixpoint!(optimizer, graph)
+
+        @test fixpoint_reached(optimizer, graph)
+        @test !optimize_step!(optimizer, graph)
+        @test !optimize!(optimizer, graph, 10)
+
+        reset_graph!(graph)
+    end
+
+    @testset "No fixpoint optimizer $optimizer" for optimizer in NO_FIXPOINT_OPTIMIZERS
+        @test_throws MethodError optimize_to_fixpoint!(optimizer, graph)
+    end
+end
+println("Optimization Unit Tests Complete!")
diff --git a/test/unit_tests_properties.jl b/test/unit_tests_properties.jl
index 97a53ad..db60583 100644
--- a/test/unit_tests_properties.jl
+++ b/test/unit_tests_properties.jl
@@ -5,18 +5,10 @@
     @test prop.data == 0.0
     @test prop.computeEffort == 0.0
     @test prop.computeIntensity == 0.0
-    @test prop.cost == 0.0
     @test prop.noNodes == 0.0
     @test prop.noEdges == 0.0
 
-    prop2 = (
-        data = 5.0,
-        computeEffort = 6.0,
-        computeIntensity = 6.0 / 5.0,
-        cost = 0.0,
-        noNodes = 2,
-        noEdges = 3,
-    )::GraphProperties
+    prop2 = (data = 5.0, computeEffort = 6.0, computeIntensity = 6.0 / 5.0, noNodes = 2, noEdges = 3)::GraphProperties
 
     @test prop + prop2 == prop2
     @test prop2 - prop == prop2
@@ -25,27 +17,18 @@
     @test negProp.data == -5.0
     @test negProp.computeEffort == -6.0
     @test negProp.computeIntensity == 6.0 / 5.0
-    @test negProp.cost == 0.0
     @test negProp.noNodes == -2
     @test negProp.noEdges == -3
 
     @test negProp + prop2 == GraphProperties()
 
-    prop3 = (
-        data = 7.0,
-        computeEffort = 3.0,
-        computeIntensity = 7.0 / 3.0,
-        cost = 0.0,
-        noNodes = -3,
-        noEdges = 2,
-    )::GraphProperties
+    prop3 = (data = 7.0, computeEffort = 3.0, computeIntensity = 3.0 / 7.0, noNodes = -3, noEdges = 2)::GraphProperties
 
     propSum = prop2 + prop3
 
     @test propSum.data == 12.0
     @test propSum.computeEffort == 9.0
     @test propSum.computeIntensity == 9.0 / 12.0
-    @test propSum.cost == 0.0
     @test propSum.noNodes == -1
     @test propSum.noEdges == 5
 end