Compare commits
24 Commits
optimizer
...
scheduling
Author | SHA1 | Date | |
---|---|---|---|
6a09ecf33d | |||
4dcb616606 | |||
9b28601f18 | |||
3267daadfd | |||
140a954d01 | |||
a86901e425 | |||
0f50b59933 | |||
cbfed20b82 | |||
f9e60a7b5e | |||
314330f00f | |||
dd01a5e691 | |||
37d645cb4e | |||
afb6af44ca | |||
bef017130b | |||
7dd9fedf2e | |||
a69dd6018e | |||
4b44eb5286 | |||
24ade323f0 | |||
95f92f080c | |||
cc05cae1cd | |||
c88898a502 | |||
0d8d824540 | |||
c428613c80 | |||
f8a591991c |
@ -8,7 +8,7 @@ env:
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: arch-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@ -65,7 +65,7 @@ jobs:
|
||||
|
||||
test:
|
||||
needs: prepare
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: arch-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@ -127,7 +127,7 @@ jobs:
|
||||
|
||||
docs:
|
||||
needs: prepare
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: arch-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
@ -1,21 +0,0 @@
|
||||
# Estimation
|
||||
|
||||
## Interface
|
||||
|
||||
The interface that has to be implemented for an estimator.
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["estimator/interafce.jl"]
|
||||
Order = [:type, :constant, :function]
|
||||
```
|
||||
|
||||
## Global Metric Estimator
|
||||
|
||||
Implementation of a global metric estimator. It uses the graph properties compute effort, data transfer, and compute intensity.
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["estimator/global_metric.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
@ -1,41 +0,0 @@
|
||||
# Optimization
|
||||
|
||||
## Interface
|
||||
|
||||
The interface that has to be implemented for an optimization algorithm.
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["optimization/interafce.jl"]
|
||||
Order = [:type, :constant, :function]
|
||||
```
|
||||
|
||||
## Random Walk Optimizer
|
||||
|
||||
Implementation of a random walk algorithm.
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["estimator/random_walk.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
||||
|
||||
## Reduction Optimizer
|
||||
|
||||
Implementation of a an optimizer that reduces as far as possible.
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["estimator/reduce.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
||||
|
||||
## Greedy Optimizer
|
||||
|
||||
Implementation of a greedy optimization algorithm.
|
||||
|
||||
```@autodocs
|
||||
Modules = [MetagraphOptimization]
|
||||
Pages = ["estimator/greedy.jl"]
|
||||
Order = [:type, :function]
|
||||
```
|
@ -1,33 +0,0 @@
|
||||
using MetagraphOptimization
|
||||
using BenchmarkTools
|
||||
|
||||
println("Getting machine info")
|
||||
@time machine = get_machine_info()
|
||||
|
||||
println("Making model")
|
||||
@time model = ABCModel()
|
||||
|
||||
println("Making process")
|
||||
process_str = "AB->ABBBBB"
|
||||
@time process = parse_process(process_str, model)
|
||||
|
||||
println("Parsing DAG")
|
||||
@time graph = parse_dag("input/$process_str.txt", model)
|
||||
|
||||
println("Generating input data")
|
||||
@time input_data = [gen_process_input(process) for _ in 1:1000]
|
||||
|
||||
println("Reducing graph")
|
||||
@time optimize_to_fixpoint!(ReductionOptimizer(), graph)
|
||||
|
||||
println("Generating compute function")
|
||||
@time compute_func = get_compute_function(graph, process, machine)
|
||||
|
||||
println("First run, single argument")
|
||||
@time compute_func(input_data[1])
|
||||
|
||||
println("\nBenchmarking function, 1 input")
|
||||
display(@benchmark compute_func($(input_data[1])))
|
||||
|
||||
println("\nBenchmarking function, 1000 inputs")
|
||||
display(@benchmark compute_func.($input_data))
|
@ -1,33 +0,0 @@
|
||||
using MetagraphOptimization
|
||||
using BenchmarkTools
|
||||
|
||||
println("Getting machine info")
|
||||
@time machine = get_machine_info()
|
||||
|
||||
println("Making model")
|
||||
@time model = ABCModel()
|
||||
|
||||
println("Making process")
|
||||
process_str = "AB->ABBBBBBB"
|
||||
@time process = parse_process(process_str, model)
|
||||
|
||||
println("Parsing DAG")
|
||||
@time graph = parse_dag("input/$process_str.txt", model)
|
||||
|
||||
println("Generating input data")
|
||||
@time input_data = [gen_process_input(process) for _ in 1:1000]
|
||||
|
||||
println("Reducing graph")
|
||||
@time optimize_to_fixpoint!(ReductionOptimizer(), graph)
|
||||
|
||||
println("Generating compute function")
|
||||
@time compute_func = get_compute_function(graph, process, machine)
|
||||
|
||||
println("First run, single argument")
|
||||
@time compute_func(input_data[1])
|
||||
|
||||
println("\nBenchmarking function, 1 input")
|
||||
display(@benchmark compute_func($(input_data[1])))
|
||||
|
||||
println("\nBenchmarking function, 1000 inputs")
|
||||
display(@benchmark compute_func.($input_data))
|
59
examples/profiling_utilities.jl
Normal file
59
examples/profiling_utilities.jl
Normal file
@ -0,0 +1,59 @@
|
||||
|
||||
function random_walk!(g::DAG, n::Int64)
|
||||
# the purpose here is to do "random" operations on the graph to simulate an optimizer
|
||||
reset_graph!(g)
|
||||
|
||||
properties = get_properties(g)
|
||||
|
||||
for i in 1:n
|
||||
# choose push or pop
|
||||
if rand(Bool)
|
||||
# push
|
||||
opt = get_operations(g)
|
||||
|
||||
# choose one of fuse/split/reduce
|
||||
option = rand(1:3)
|
||||
if option == 1 && !isempty(opt.nodeFusions)
|
||||
push_operation!(g, rand(collect(opt.nodeFusions)))
|
||||
elseif option == 2 && !isempty(opt.nodeReductions)
|
||||
push_operation!(g, rand(collect(opt.nodeReductions)))
|
||||
elseif option == 3 && !isempty(opt.nodeSplits)
|
||||
push_operation!(g, rand(collect(opt.nodeSplits)))
|
||||
else
|
||||
i = i - 1
|
||||
end
|
||||
else
|
||||
# pop
|
||||
if (can_pop(g))
|
||||
pop_operation!(g)
|
||||
else
|
||||
i = i - 1
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function reduce_all!(g::DAG)
|
||||
reset_graph!(g)
|
||||
|
||||
opt = get_operations(g)
|
||||
while (!isempty(opt.nodeReductions))
|
||||
push_operation!(g, pop!(opt.nodeReductions))
|
||||
|
||||
if (isempty(opt.nodeReductions))
|
||||
opt = get_operations(g)
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function reduce_one!(g::DAG)
|
||||
opt = get_operations(g)
|
||||
if !isempty(opt.nodeReductions)
|
||||
push_operation!(g, pop!(opt.nodeReductions))
|
||||
end
|
||||
opt = get_operations(g)
|
||||
return nothing
|
||||
end
|
File diff suppressed because it is too large
Load Diff
@ -211,8 +211,10 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"include(\"../examples/profiling_utilities.jl\")\n",
|
||||
"\n",
|
||||
"# We can also mute the graph by applying some operations to it\n",
|
||||
"optimize_to_fixpoint!(ReductionOptimizer(), graph)"
|
||||
"reduce_all!(graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -30,7 +30,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@ProfileView.profview optimize_to_fixpoint!(ReductionOptimizer(), graph)"
|
||||
"include(\"../examples/profiling_utilities.jl\")\n",
|
||||
"@ProfileView.profview reduce_all!(graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -5,7 +5,6 @@ A module containing tools to work on DAGs.
|
||||
"""
|
||||
module MetagraphOptimization
|
||||
|
||||
# graph types
|
||||
export DAG
|
||||
export Node
|
||||
export Edge
|
||||
@ -19,7 +18,6 @@ export FusedComputeTask
|
||||
export PossibleOperations
|
||||
export GraphProperties
|
||||
|
||||
# graph functions
|
||||
export make_node
|
||||
export make_edge
|
||||
export insert_node
|
||||
@ -29,15 +27,10 @@ export is_exit_node
|
||||
export parents
|
||||
export children
|
||||
export compute
|
||||
export data
|
||||
export compute_effort
|
||||
export task
|
||||
export get_properties
|
||||
export get_exit_node
|
||||
export operation_stack_length
|
||||
export is_valid, is_scheduled
|
||||
|
||||
# graph operation related
|
||||
export Operation
|
||||
export AppliedOperation
|
||||
export NodeFusion
|
||||
@ -49,10 +42,6 @@ export can_pop
|
||||
export reset_graph!
|
||||
export get_operations
|
||||
|
||||
# ABC model
|
||||
export ParticleValue
|
||||
export ParticleA, ParticleB, ParticleC
|
||||
export ABCProcessDescription, ABCProcessInput, ABCModel
|
||||
export ComputeTaskP
|
||||
export ComputeTaskS1
|
||||
export ComputeTaskS2
|
||||
@ -60,22 +49,14 @@ export ComputeTaskV
|
||||
export ComputeTaskU
|
||||
export ComputeTaskSum
|
||||
|
||||
# code generation related
|
||||
export execute
|
||||
export parse_dag, parse_process
|
||||
export gen_process_input
|
||||
export get_compute_function
|
||||
export ParticleValue
|
||||
export ParticleA, ParticleB, ParticleC
|
||||
export ABCProcessDescription, ABCProcessInput, ABCModel
|
||||
|
||||
# estimator
|
||||
export cost_type, graph_cost, operation_effect
|
||||
export GlobalMetricEstimator, CDCost
|
||||
|
||||
# optimization
|
||||
export AbstractOptimizer, GreedyOptimizer, ReductionOptimizer, RandomWalkOptimizer
|
||||
export optimize_step!, optimize!
|
||||
export fixpoint_reached, optimize_to_fixpoint!
|
||||
|
||||
# machine info
|
||||
export Machine
|
||||
export get_machine_info
|
||||
|
||||
@ -124,7 +105,6 @@ include("node/properties.jl")
|
||||
include("node/validate.jl")
|
||||
|
||||
include("operation/utility.jl")
|
||||
include("operation/iterate.jl")
|
||||
include("operation/apply.jl")
|
||||
include("operation/clean.jl")
|
||||
include("operation/find.jl")
|
||||
@ -141,14 +121,6 @@ include("task/compute.jl")
|
||||
include("task/print.jl")
|
||||
include("task/properties.jl")
|
||||
|
||||
include("estimator/interface.jl")
|
||||
include("estimator/global_metric.jl")
|
||||
|
||||
include("optimization/interface.jl")
|
||||
include("optimization/greedy.jl")
|
||||
include("optimization/random_walk.jl")
|
||||
include("optimization/reduce.jl")
|
||||
|
||||
include("models/interface.jl")
|
||||
include("models/print.jl")
|
||||
|
||||
|
@ -79,7 +79,7 @@ function gen_input_assignment_code(
|
||||
# TODO: how to get the "default" cpu device?
|
||||
device = entry_device(machine)
|
||||
evalExpr = eval(gen_access_expr(device, symbol))
|
||||
push!(assignInputs, Meta.parse("$(evalExpr)::ParticleValue{$type} = ParticleValue($p, 1.0)"))
|
||||
push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue($p, 1.0)"))
|
||||
end
|
||||
end
|
||||
|
||||
@ -102,7 +102,6 @@ function get_compute_function(graph::DAG, process::AbstractProcessDescription, m
|
||||
expr = Meta.parse(
|
||||
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
|
||||
)
|
||||
|
||||
func = eval(expr)
|
||||
|
||||
return func
|
||||
|
@ -1,77 +0,0 @@
|
||||
|
||||
"""
|
||||
CDCost
|
||||
|
||||
Representation of a [`DAG`](@ref)'s cost as estimated by the [`GlobalMetricEstimator`](@ref).
|
||||
|
||||
# Fields:
|
||||
`.data`: The total data transfer.\\
|
||||
`.computeEffort`: The total compute effort.\\
|
||||
`.computeIntensity`: The compute intensity, will always equal `.computeEffort / .data`.
|
||||
|
||||
|
||||
!!! note
|
||||
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
|
||||
For example, for node fusions this will always be 0, since the computeEffort is zero.
|
||||
It will still work as intended when adding/subtracting to/from a `graph_cost` estimate.
|
||||
"""
|
||||
const CDCost = NamedTuple{(:data, :computeEffort, :computeIntensity), Tuple{Float64, Float64, Float64}}
|
||||
|
||||
function +(cost1::CDCost, cost2::CDCost)::CDCost
|
||||
d = cost1.data + cost2.data
|
||||
ce = computeEffort = cost1.computeEffort + cost2.computeEffort
|
||||
return (data = d, computeEffort = ce, computeIntensity = ce / d)::CDCost
|
||||
end
|
||||
|
||||
function -(cost1::CDCost, cost2::CDCost)::CDCost
|
||||
d = cost1.data - cost2.data
|
||||
ce = computeEffort = cost1.computeEffort - cost2.computeEffort
|
||||
return (data = d, computeEffort = ce, computeIntensity = ce / d)::CDCost
|
||||
end
|
||||
|
||||
function isless(cost1::CDCost, cost2::CDCost)::Bool
|
||||
return cost1.data + cost1.computeEffort < cost2.data + cost2.computeEffort
|
||||
end
|
||||
|
||||
function zero(type::Type{CDCost})
|
||||
return (data = 0.0, computeEffort = 00.0, computeIntensity = 0.0)::CDCost
|
||||
end
|
||||
|
||||
function typemax(type::Type{CDCost})
|
||||
return (data = Inf, computeEffort = Inf, computeIntensity = 0.0)::CDCost
|
||||
end
|
||||
|
||||
struct GlobalMetricEstimator <: AbstractEstimator end
|
||||
|
||||
function cost_type(estimator::GlobalMetricEstimator)::Type{CDCost}
|
||||
return CDCost
|
||||
end
|
||||
|
||||
function graph_cost(estimator::GlobalMetricEstimator, graph::DAG)
|
||||
properties = get_properties(graph)
|
||||
return (
|
||||
data = properties.data,
|
||||
computeEffort = properties.computeEffort,
|
||||
computeIntensity = properties.computeIntensity,
|
||||
)::CDCost
|
||||
end
|
||||
|
||||
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeFusion)
|
||||
return (data = -data(operation.input[2].task), computeEffort = 0.0, computeIntensity = 0.0)::CDCost
|
||||
end
|
||||
|
||||
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeReduction)
|
||||
s = length(operation.input) - 1
|
||||
return (
|
||||
data = s * -data(task(operation.input[1])),
|
||||
computeEffort = s * -compute_effort(task(operation.input[1])),
|
||||
computeIntensity = typeof(operation.input) <: DataTaskNode ? 0.0 : Inf,
|
||||
)::CDCost
|
||||
end
|
||||
|
||||
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeSplit)
|
||||
s::Float64 = length(parents(operation.input)) - 1
|
||||
d::Float64 = s * data(task(operation.input))
|
||||
ce::Float64 = s * compute_effort(task(operation.input))
|
||||
return (data = d, computeEffort = ce, computeIntensity = ce / d)::CDCost
|
||||
end
|
@ -1,44 +0,0 @@
|
||||
|
||||
"""
|
||||
AbstractEstimator
|
||||
|
||||
Abstract base type for an estimator. An estimator estimates the cost of a graph or the difference an operation applied to a graph will make to its cost.
|
||||
|
||||
Interface functions are
|
||||
- [`graph_cost`](@ref)
|
||||
- [`operation_effect`](@ref)
|
||||
"""
|
||||
abstract type AbstractEstimator end
|
||||
|
||||
"""
|
||||
cost_type(estimator::AbstractEstimator)
|
||||
|
||||
Interface function returning a specific estimator's cost type, i.e., the type returned by its implementation of [`graph_cost`](@ref) and [`operation_effect`](@ref).
|
||||
"""
|
||||
function cost_type end
|
||||
|
||||
"""
|
||||
graph_cost(estimator::AbstractEstimator, graph::DAG)
|
||||
|
||||
Get the total estimated cost of the graph. The cost's data type can be chosen by the implementation, but must have a usable lessthan comparison operator (<), basic math operators (+, -) and an implementation of `zero()` and `typemax()`.
|
||||
"""
|
||||
function graph_cost end
|
||||
|
||||
"""
|
||||
operation_effect(estimator::AbstractEstimator, graph::DAG, operation::Operation)
|
||||
|
||||
Get the estimated effect on the cost of the graph, such that `graph_cost(estimator, graph) + operation_effect(estimator, graph, operation) ~= graph_cost(estimator, graph_with_operation_applied)`. There is no hard requirement for this, but the better the estimate, the better an optimization algorithm will be.
|
||||
|
||||
!!! note
|
||||
There is a default implementation of this function, applying the operation, calling [`graph_cost`](@ref), then popping the operation again.
|
||||
|
||||
It can be much faster to overload this function for a specific estimator and directly compute the effects from the operation if possible.
|
||||
"""
|
||||
function operation_effect(estimator::AbstractEstimator, graph::DAG, operation::Operation)
|
||||
# This is currently not stably working, see issue #16
|
||||
cost = graph_cost(estimator, graph)
|
||||
push_operation!(graph, operation)
|
||||
cost_after = graph_cost(estimator, graph)
|
||||
pop_operation!(graph)
|
||||
return cost_after - cost
|
||||
end
|
@ -17,5 +17,21 @@ function in(edge::Edge, graph::DAG)
|
||||
return false
|
||||
end
|
||||
|
||||
return n1 in children(n2)
|
||||
return n1 in n2.children
|
||||
end
|
||||
|
||||
"""
|
||||
==(n1::Node, n2::Node, g::DAG)
|
||||
|
||||
Check equality of two nodes in a graph.
|
||||
"""
|
||||
function ==(n1::Node, n2::Node, g::DAG)
|
||||
if typeof(n1) != typeof(n2)
|
||||
return false
|
||||
end
|
||||
if !(n1 in g) || !(n2 in g)
|
||||
return false
|
||||
end
|
||||
|
||||
return n1.task == n2.task && children(n1) == children(n2)
|
||||
end
|
||||
|
@ -46,7 +46,7 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
|
||||
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
|
||||
#@assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists"
|
||||
@assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
|
||||
|
||||
# 1: mute
|
||||
# edge points from child to parent
|
||||
@ -85,7 +85,7 @@ Remove the node from the graph.
|
||||
See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
|
||||
#@assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
@assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
|
||||
# 1: mute
|
||||
delete!(graph.nodes, node)
|
||||
@ -124,29 +124,18 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
|
||||
pre_length2 = length(node2.children)
|
||||
|
||||
#TODO: filter is very slow
|
||||
for i in eachindex(node1.parents)
|
||||
if (node1.parents[i] == node2)
|
||||
splice!(node1.parents, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
filter!(x -> x != node2, node1.parents)
|
||||
filter!(x -> x != node1, node2.children)
|
||||
|
||||
for i in eachindex(node2.children)
|
||||
if (node2.children[i] == node1)
|
||||
splice!(node2.children, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
#=@assert begin
|
||||
@assert begin
|
||||
removed = pre_length1 - length(node1.parents)
|
||||
removed <= 1
|
||||
end "removed more than one node from node1's parents"=#
|
||||
end "removed more than one node from node1's parents"
|
||||
|
||||
#=@assert begin
|
||||
removed = pre_length2 - length(children(node2))
|
||||
@assert begin
|
||||
removed = pre_length2 - length(node2.children)
|
||||
removed <= 1
|
||||
end "removed more than one node from node2's children"=#
|
||||
end "removed more than one node from node2's children"
|
||||
|
||||
# 2: keep track
|
||||
if (track)
|
||||
@ -174,7 +163,7 @@ function replace_children!(task::FusedComputeTask, before, after)
|
||||
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
|
||||
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
|
||||
|
||||
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
|
||||
replace!(task.t1_inputs, before => after)
|
||||
replace!(task.t2_inputs, before => after)
|
||||
@ -196,33 +185,33 @@ end
|
||||
|
||||
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
|
||||
# only need to update fused compute tasks
|
||||
if !(typeof(task(n)) <: FusedComputeTask)
|
||||
if !(typeof(n.task) <: FusedComputeTask)
|
||||
return nothing
|
||||
end
|
||||
|
||||
taskBefore = copy(task(n))
|
||||
taskBefore = copy(n.task)
|
||||
|
||||
#=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
|
||||
if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
|
||||
println("------------------ Nothing to replace!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in children(n)
|
||||
for child in n.children
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end=#
|
||||
end
|
||||
|
||||
replace_children!(task(n), child_before, child_after)
|
||||
replace_children!(n.task, child_before, child_after)
|
||||
|
||||
#=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
|
||||
if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
|
||||
println("------------------ Did not replace anything!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in children(n)
|
||||
for child in n.children
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end=#
|
||||
end
|
||||
|
||||
# keep track
|
||||
if (track)
|
||||
@ -253,14 +242,8 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# TODO: filter is very slow
|
||||
for n in [1, 3]
|
||||
for i in eachindex(operation.input[n].nodeFusions)
|
||||
if operation == operation.input[n].nodeFusions[i]
|
||||
splice!(operation.input[n].nodeFusions, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
filter!(!=(operation), operation.input[1].nodeFusions)
|
||||
filter!(!=(operation), operation.input[3].nodeFusions)
|
||||
|
||||
operation.input[2].nodeFusion = missing
|
||||
|
||||
|
@ -30,10 +30,10 @@ function show(io::IO, graph::DAG)
|
||||
nodeDict = Dict{Type, Int64}()
|
||||
noEdges = 0
|
||||
for node in graph.nodes
|
||||
if haskey(nodeDict, typeof(task(node)))
|
||||
nodeDict[typeof(task(node))] = nodeDict[typeof(task(node))] + 1
|
||||
if haskey(nodeDict, typeof(node.task))
|
||||
nodeDict[typeof(node.task)] = nodeDict[typeof(node.task)] + 1
|
||||
else
|
||||
nodeDict[typeof(task(node))] = 1
|
||||
nodeDict[typeof(node.task)] = 1
|
||||
end
|
||||
noEdges += length(parents(node))
|
||||
end
|
||||
|
@ -43,12 +43,3 @@ function get_entry_nodes(graph::DAG)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
"""
|
||||
operation_stack_length(graph::DAG)
|
||||
|
||||
Return the number of operations applied to the graph.
|
||||
"""
|
||||
function operation_stack_length(graph::DAG)
|
||||
return length(graph.appliedOperations) + length(graph.operationsToApply)
|
||||
end
|
||||
|
@ -24,7 +24,7 @@ To get the set of possible operations, use [`get_operations`](@ref).
|
||||
The members of the object should not be manually accessed, instead always use the provided interface functions.
|
||||
"""
|
||||
mutable struct DAG
|
||||
nodes::Set{Union{DataTaskNode, ComputeTaskNode}}
|
||||
nodes::Set{Node}
|
||||
|
||||
# The operations currently applied to the set of nodes
|
||||
appliedOperations::Stack{AppliedOperation}
|
||||
@ -36,7 +36,7 @@ mutable struct DAG
|
||||
possibleOperations::PossibleOperations
|
||||
|
||||
# The set of nodes whose possible operations need to be reevaluated
|
||||
dirtyNodes::Set{Union{DataTaskNode, ComputeTaskNode}}
|
||||
dirtyNodes::Set{Node}
|
||||
|
||||
# "snapshot" system: keep track of added/removed nodes/edges since last snapshot
|
||||
# these are muted in insert_node! etc.
|
||||
|
@ -7,7 +7,7 @@ Return the particle and value as is.
|
||||
|
||||
0 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskP, data::ParticleValue{P})::ParticleValue{P} where {P <: ABCParticle}
|
||||
function compute(::ComputeTaskP, data::ParticleValue)
|
||||
return data
|
||||
end
|
||||
|
||||
@ -18,7 +18,7 @@ Compute an outer edge. Return the particle value with the same particle and the
|
||||
|
||||
1 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskU, data::ParticleValue{P})::ParticleValue{P} where {P <: ABCParticle}
|
||||
function compute(::ComputeTaskU, data::ParticleValue)
|
||||
return ParticleValue(data.p, data.v * outer_edge(data.p))
|
||||
end
|
||||
|
||||
@ -29,11 +29,7 @@ Compute a vertex. Preserve momentum and particle types (AB->C etc.) to create re
|
||||
|
||||
6 FLOP.
|
||||
"""
|
||||
function compute(
|
||||
::ComputeTaskV,
|
||||
data1::ParticleValue{P1},
|
||||
data2::ParticleValue{P2},
|
||||
)::ParticleValue where {P1 <: ABCParticle, P2 <: ABCParticle}
|
||||
function compute(::ComputeTaskV, data1::ParticleValue, data2::ParticleValue)
|
||||
p3 = preserve_momentum(data1.p, data2.p)
|
||||
dataOut = ParticleValue(p3, data1.v * vertex() * data2.v)
|
||||
return dataOut
|
||||
@ -48,15 +44,14 @@ For valid inputs, both input particles should have the same momenta at this poin
|
||||
|
||||
12 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskS2, data1::ParticleValue{P}, data2::ParticleValue{P})::Float64 where {P <: ABCParticle}
|
||||
function compute(::ComputeTaskS2, data1::ParticleValue, data2::ParticleValue)
|
||||
#=
|
||||
@assert isapprox(abs(data1.p.momentum.E), abs(data2.p.momentum.E), rtol = 0.001, atol = sqrt(eps())) "E: $(data1.p.momentum.E) vs. $(data2.p.momentum.E)"
|
||||
@assert isapprox(data1.p.momentum.px, -data2.p.momentum.px, rtol = 0.001, atol = sqrt(eps())) "px: $(data1.p.momentum.px) vs. $(data2.p.momentum.px)"
|
||||
@assert isapprox(data1.p.momentum.py, -data2.p.momentum.py, rtol = 0.001, atol = sqrt(eps())) "py: $(data1.p.momentum.py) vs. $(data2.p.momentum.py)"
|
||||
@assert isapprox(data1.p.momentum.pz, -data2.p.momentum.pz, rtol = 0.001, atol = sqrt(eps())) "pz: $(data1.p.momentum.pz) vs. $(data2.p.momentum.pz)"
|
||||
=#
|
||||
inner = inner_edge(data1.p)
|
||||
return data1.v * inner * data2.v
|
||||
return data1.v * inner_edge(data1.p) * data2.v
|
||||
end
|
||||
|
||||
"""
|
||||
@ -66,7 +61,7 @@ Compute inner edge (1 input particle, 1 output particle).
|
||||
|
||||
11 FLOP.
|
||||
"""
|
||||
function compute(::ComputeTaskS1, data::ParticleValue{P})::ParticleValue{P} where {P <: ABCParticle}
|
||||
function compute(::ComputeTaskS1, data::ParticleValue)
|
||||
return ParticleValue(data.p, data.v * inner_edge(data.p))
|
||||
end
|
||||
|
||||
@ -77,7 +72,7 @@ Compute a sum over the vector. Use an algorithm that accounts for accumulated er
|
||||
|
||||
Linearly many FLOP with growing data.
|
||||
"""
|
||||
function compute(::ComputeTaskSum, data::Vector{Float64})::Float64
|
||||
function compute(::ComputeTaskSum, data::Vector{Float64})
|
||||
return sum_kbn(data)
|
||||
end
|
||||
|
||||
|
@ -181,7 +181,7 @@ function parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = fa
|
||||
insert_edge!(graph, compute_S2, data_out, track = false, invalidate_cache = false)
|
||||
|
||||
insert_edge!(graph, data_out, sum_node, track = false, invalidate_cache = false)
|
||||
add_child!(task(sum_node))
|
||||
add_child!(sum_node.task)
|
||||
elseif occursin(regex_plus, node)
|
||||
if (verbose)
|
||||
println("\rReading Nodes Complete ")
|
||||
|
@ -1,7 +1,5 @@
|
||||
using QEDbase
|
||||
|
||||
import QEDbase.mass
|
||||
|
||||
"""
|
||||
ABCModel <: AbstractPhysicsModel
|
||||
|
||||
@ -89,9 +87,9 @@ For 2 given (non-equal) particle types, return the third of ABC.
|
||||
"""
|
||||
function interaction_result(t1::Type{T1}, t2::Type{T2}) where {T1 <: ABCParticle, T2 <: ABCParticle}
|
||||
@assert t1 != t2
|
||||
if t1 != ParticleA && t2 != ParticleA
|
||||
if t1 != Type{ParticleA} && t2 != Type{ParticleA}
|
||||
return ParticleA
|
||||
elseif t1 != ParticleB && t2 != ParticleB
|
||||
elseif t1 != Type{ParticleB} && t2 != Type{ParticleB}
|
||||
return ParticleB
|
||||
else
|
||||
return ParticleC
|
||||
@ -163,6 +161,7 @@ Takes 4 effective FLOP.
|
||||
function preserve_momentum(p1::ABCParticle, p2::ABCParticle)
|
||||
t3 = interaction_result(typeof(p1), typeof(p2))
|
||||
p3 = t3(p1.momentum + p2.momentum)
|
||||
|
||||
return p3
|
||||
end
|
||||
|
||||
|
@ -3,35 +3,35 @@
|
||||
|
||||
Return the compute effort of an S1 task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskS1)::Float64 = 11.0
|
||||
compute_effort(t::ComputeTaskS1) = 11
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskS2)
|
||||
|
||||
Return the compute effort of an S2 task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskS2)::Float64 = 12.0
|
||||
compute_effort(t::ComputeTaskS2) = 12
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskU)
|
||||
|
||||
Return the compute effort of a U task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskU)::Float64 = 1.0
|
||||
compute_effort(t::ComputeTaskU) = 1
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskV)
|
||||
|
||||
Return the compute effort of a V task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskV)::Float64 = 6.0
|
||||
compute_effort(t::ComputeTaskV) = 6
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskP)
|
||||
|
||||
Return the compute effort of a P task.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskP)::Float64 = 0.0
|
||||
compute_effort(t::ComputeTaskP) = 0
|
||||
|
||||
"""
|
||||
compute_effort(t::ComputeTaskSum)
|
||||
@ -41,7 +41,7 @@ Return the compute effort of a Sum task.
|
||||
Note: This is a constant compute effort, even though sum scales with the number of its inputs. Since there is only ever a single sum node in a graph generated from the ABC-Model,
|
||||
this doesn't matter.
|
||||
"""
|
||||
compute_effort(t::ComputeTaskSum)::Float64 = 1.0
|
||||
compute_effort(t::ComputeTaskSum) = 1
|
||||
|
||||
"""
|
||||
show(io::IO, t::DataTask)
|
||||
|
@ -4,7 +4,7 @@
|
||||
Task representing a specific data transfer in the ABC Model.
|
||||
"""
|
||||
struct DataTask <: AbstractDataTask
|
||||
data::Float64
|
||||
data::UInt64
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -21,7 +21,7 @@ end
|
||||
|
||||
Equality comparison between two [`ComputeTaskNode`](@ref)s.
|
||||
"""
|
||||
function ==(n1::ComputeTaskNode{TaskType}, n2::ComputeTaskNode{TaskType}) where {TaskType <: AbstractComputeTask}
|
||||
function ==(n1::ComputeTaskNode, n2::ComputeTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
||||
@ -30,6 +30,6 @@ end
|
||||
|
||||
Equality comparison between two [`DataTaskNode`](@ref)s.
|
||||
"""
|
||||
function ==(n1::DataTaskNode{TaskType}, n2::DataTaskNode{TaskType}) where {TaskType <: AbstractDataTask}
|
||||
function ==(n1::DataTaskNode, n2::DataTaskNode)
|
||||
return n1.id == n2.id
|
||||
end
|
||||
|
@ -13,8 +13,8 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
)
|
||||
|
||||
copy(m::Missing) = missing
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(task(n)))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(task(n)), n.name)
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), n.name)
|
||||
|
||||
"""
|
||||
make_node(t::AbstractTask)
|
||||
|
@ -4,7 +4,7 @@
|
||||
Print a short string representation of the node to io.
|
||||
"""
|
||||
function show(io::IO, n::Node)
|
||||
return print(io, "Node(", task(n), ")")
|
||||
return print(io, "Node(", n.task, ")")
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -3,27 +3,25 @@
|
||||
|
||||
Return whether this node is an entry node in its graph, i.e., it has no children.
|
||||
"""
|
||||
is_entry_node(node::Node) = length(children(node)) == 0
|
||||
is_entry_node(node::Node) = length(node.children) == 0
|
||||
|
||||
"""
|
||||
is_exit_node(node::Node)
|
||||
|
||||
Return whether this node is an exit node of its graph, i.e., it has no parents.
|
||||
"""
|
||||
is_exit_node(node::Node)::Bool = length(parents(node)) == 0
|
||||
is_exit_node(node::Node) = length(node.parents) == 0
|
||||
|
||||
"""
|
||||
task(node::Node)
|
||||
data(edge::Edge)
|
||||
|
||||
Return the node's task.
|
||||
Return the data transfered by this edge, i.e., 0 if the child is a [`ComputeTaskNode`](@ref), otherwise the child's `data()`.
|
||||
"""
|
||||
function task(node::DataTaskNode{TaskType})::TaskType where {TaskType <: Union{AbstractDataTask, AbstractComputeTask}}
|
||||
return node.task
|
||||
end
|
||||
function task(
|
||||
node::ComputeTaskNode{TaskType},
|
||||
)::TaskType where {TaskType <: Union{AbstractDataTask, AbstractComputeTask}}
|
||||
return node.task
|
||||
function data(edge::Edge)
|
||||
if typeof(edge.edge[1]) <: DataTaskNode
|
||||
return data(edge.edge[1].task)
|
||||
end
|
||||
return 0.0
|
||||
end
|
||||
|
||||
"""
|
||||
@ -33,11 +31,8 @@ Return a copy of the node's children so it can safely be muted without changing
|
||||
|
||||
A node's children are its prerequisite nodes, nodes that need to execute before the task of this node.
|
||||
"""
|
||||
function children(node::DataTaskNode)::Vector{ComputeTaskNode}
|
||||
return node.children
|
||||
end
|
||||
function children(node::ComputeTaskNode)::Vector{DataTaskNode}
|
||||
return node.children
|
||||
function children(node::Node)
|
||||
return copy(node.children)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -47,11 +42,8 @@ Return a copy of the node's parents so it can safely be muted without changing t
|
||||
|
||||
A node's parents are its subsequent nodes, nodes that need this node to execute.
|
||||
"""
|
||||
function parents(node::DataTaskNode)::Vector{ComputeTaskNode}
|
||||
return node.parents
|
||||
end
|
||||
function parents(node::ComputeTaskNode)::Vector{DataTaskNode}
|
||||
return node.parents
|
||||
function parents(node::Node)
|
||||
return copy(node.parents)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -61,11 +53,11 @@ Return a vector of all siblings of this node.
|
||||
|
||||
A node's siblings are all children of any of its parents. The result contains no duplicates and includes the node itself.
|
||||
"""
|
||||
function siblings(node::Node)::Set{Node}
|
||||
function siblings(node::Node)
|
||||
result = Set{Node}()
|
||||
push!(result, node)
|
||||
for parent in parents(node)
|
||||
union!(result, children(parent))
|
||||
for parent in node.parents
|
||||
union!(result, parent.children)
|
||||
end
|
||||
|
||||
return result
|
||||
@ -81,11 +73,11 @@ A node's partners are all parents of any of its children. The result contains no
|
||||
Note: This is very slow when there are multiple children with many parents.
|
||||
This is less of a problem in [`siblings(node::Node)`](@ref) because (depending on the model) there are no nodes with a large number of children, or only a single one.
|
||||
"""
|
||||
function partners(node::Node)::Set{Node}
|
||||
function partners(node::Node)
|
||||
result = Set{Node}()
|
||||
push!(result, node)
|
||||
for child in children(node)
|
||||
union!(result, parents(child))
|
||||
for child in node.children
|
||||
union!(result, child.parents)
|
||||
end
|
||||
|
||||
return result
|
||||
@ -98,8 +90,8 @@ Alternative version to [`partners(node::Node)`](@ref), avoiding allocation of a
|
||||
"""
|
||||
function partners(node::Node, set::Set{Node})
|
||||
push!(set, node)
|
||||
for child in children(node)
|
||||
union!(set, parents(child))
|
||||
for child in node.children
|
||||
union!(set, child.parents)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
@ -109,8 +101,8 @@ end
|
||||
|
||||
Return whether the `potential_parent` is a parent of `node`.
|
||||
"""
|
||||
function is_parent(potential_parent::Node, node::Node)::Bool
|
||||
return potential_parent in parents(node)
|
||||
function is_parent(potential_parent::Node, node::Node)
|
||||
return potential_parent in node.parents
|
||||
end
|
||||
|
||||
"""
|
||||
@ -118,6 +110,6 @@ end
|
||||
|
||||
Return whether the `potential_child` is a child of `node`.
|
||||
"""
|
||||
function is_child(potential_child::Node, node::Node)::Bool
|
||||
return potential_child in children(node)
|
||||
function is_child(potential_child::Node, node::Node)
|
||||
return potential_child in node.children
|
||||
end
|
||||
|
@ -33,8 +33,8 @@ Any node that transfers data and does no computation.
|
||||
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
|
||||
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
|
||||
"""
|
||||
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
|
||||
task::TaskType
|
||||
mutable struct DataTaskNode <: Node
|
||||
task::AbstractDataTask
|
||||
|
||||
# use vectors as sets have way too much memory overhead
|
||||
parents::Vector{Node}
|
||||
@ -73,8 +73,8 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re
|
||||
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
|
||||
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
|
||||
"""
|
||||
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
|
||||
task::TaskType
|
||||
mutable struct ComputeTaskNode <: Node
|
||||
task::AbstractComputeTask
|
||||
parents::Vector{Node}
|
||||
children::Vector{Node}
|
||||
id::Base.UUID
|
||||
@ -83,7 +83,7 @@ mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
|
||||
nodeFusions::Vector{<:Operation}
|
||||
nodeFusions::Vector{Operation}
|
||||
|
||||
# the device this node is assigned to execute on
|
||||
device::Union{AbstractDevice, Missing}
|
||||
|
@ -29,7 +29,7 @@ function is_valid_node(graph::DAG, node::Node)
|
||||
@assert is_valid(graph, node.nodeSplit)
|
||||
end=#
|
||||
|
||||
if !(typeof(task(node)) <: FusedComputeTask)
|
||||
if !(typeof(node.task) <: FusedComputeTask)
|
||||
# the remaining checks are only necessary for fused compute tasks
|
||||
return true
|
||||
end
|
||||
@ -37,7 +37,7 @@ function is_valid_node(graph::DAG, node::Node)
|
||||
# every child must be in some input of the task
|
||||
for child in node.children
|
||||
str = Symbol(to_var_name(child.id))
|
||||
@assert (str in task(node).t1_inputs) || (str in task(node).t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(task(node).t1_inputs)\nt2_inputs: $(task(node).t2_inputs)"
|
||||
@assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
|
||||
end
|
||||
|
||||
return true
|
||||
|
@ -132,11 +132,11 @@ function revert_diff!(graph::DAG, diff::Diff)
|
||||
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
|
||||
end
|
||||
|
||||
for (node, t) in diff.updatedChildren
|
||||
for (node, task) in diff.updatedChildren
|
||||
# node must be fused compute task at this point
|
||||
@assert typeof(task(node)) <: FusedComputeTask
|
||||
@assert typeof(node.task) <: FusedComputeTask
|
||||
|
||||
node.task = t
|
||||
node.task = task
|
||||
end
|
||||
|
||||
graph.properties -= GraphProperties(diff)
|
||||
@ -158,11 +158,11 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
# save children and parents
|
||||
n1Children = copy(children(n1))
|
||||
n3Parents = copy(parents(n3))
|
||||
n1Children = children(n1)
|
||||
n3Parents = parents(n3)
|
||||
|
||||
n1Task = copy(task(n1))
|
||||
n3Task = copy(task(n3))
|
||||
n1Task = copy(n1.task)
|
||||
n3Task = copy(n3.task)
|
||||
|
||||
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
|
||||
n1Inputs = Vector{Symbol}()
|
||||
@ -177,7 +177,7 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
|
||||
remove_node!(graph, n2)
|
||||
|
||||
# get n3's children now so it automatically excludes n2
|
||||
n3Children = copy(children(n3))
|
||||
n3Children = children(n3)
|
||||
|
||||
n3Inputs = Vector{Symbol}()
|
||||
for child in n3Children
|
||||
@ -228,7 +228,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
n1 = nodes[1]
|
||||
n1Children = copy(children(n1))
|
||||
n1Children = children(n1)
|
||||
|
||||
n1Parents = Set(n1.parents)
|
||||
|
||||
@ -245,7 +245,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
remove_edge!(graph, child, n)
|
||||
end
|
||||
|
||||
for parent in copy(parents(n))
|
||||
for parent in parents(n)
|
||||
remove_edge!(graph, n, parent)
|
||||
|
||||
# collect all parents
|
||||
@ -278,17 +278,14 @@ Split the given node into one node per parent, return the applied difference to
|
||||
|
||||
For details see [`NodeSplit`](@ref).
|
||||
"""
|
||||
function node_split!(
|
||||
graph::DAG,
|
||||
n1::Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}},
|
||||
) where {TaskType <: AbstractTask}
|
||||
function node_split!(graph::DAG, n1::Node)
|
||||
@assert is_valid_node_split_input(graph, n1)
|
||||
|
||||
# clear snapshot
|
||||
get_snapshot_diff(graph)
|
||||
|
||||
n1Parents = copy(parents(n1))
|
||||
n1Children = copy(children(n1))
|
||||
n1Parents = parents(n1)
|
||||
n1Children = children(n1)
|
||||
|
||||
for parent in n1Parents
|
||||
remove_edge!(graph, n1, parent)
|
||||
|
@ -13,18 +13,18 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
|
||||
return nothing
|
||||
end
|
||||
|
||||
if length(parents(node)) != 1 || length(children(node)) != 1
|
||||
if length(node.parents) != 1 || length(node.children) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
child_node = first(children(node))
|
||||
parent_node = first(parents(node))
|
||||
child_node = first(node.children)
|
||||
parent_node = first(node.parents)
|
||||
|
||||
if !(child_node in graph) || !(parent_node in graph)
|
||||
error("Parents/Children that are not in the graph!!!")
|
||||
end
|
||||
|
||||
if length(parents(child_node)) != 1
|
||||
if length(child_node.parents) != 1
|
||||
return nothing
|
||||
end
|
||||
|
||||
@ -44,11 +44,11 @@ Find node fusions involving the given compute node. The function pushes the foun
|
||||
"""
|
||||
function find_fusions!(graph::DAG, node::ComputeTaskNode)
|
||||
# just find fusions in neighbouring DataTaskNodes
|
||||
for child in children(node)
|
||||
for child in node.children
|
||||
find_fusions!(graph, child)
|
||||
end
|
||||
|
||||
for parent in parents(node)
|
||||
for parent in node.parents
|
||||
find_fusions!(graph, parent)
|
||||
end
|
||||
|
||||
@ -123,10 +123,7 @@ end
|
||||
|
||||
Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way.
|
||||
"""
|
||||
function clean_node!(
|
||||
graph::DAG,
|
||||
node::Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}},
|
||||
) where {TaskType <: AbstractTask}
|
||||
function clean_node!(graph::DAG, node::Node)
|
||||
sort_node!(node)
|
||||
|
||||
find_fusions!(graph, node)
|
||||
|
@ -203,18 +203,18 @@ function generate_operations(graph::DAG)
|
||||
# --- find possible node fusions ---
|
||||
@threads for node in nodeArray
|
||||
if (typeof(node) <: DataTaskNode)
|
||||
if length(parents(node)) != 1
|
||||
if length(node.parents) != 1
|
||||
# data node can only have a single parent
|
||||
continue
|
||||
end
|
||||
parent_node = first(parents(node))
|
||||
parent_node = first(node.parents)
|
||||
|
||||
if length(children(node)) != 1
|
||||
if length(node.children) != 1
|
||||
# this node is an entry node or has multiple children which should not be possible
|
||||
continue
|
||||
end
|
||||
child_node = first(children(node))
|
||||
if (length(parents(child_node)) != 1)
|
||||
child_node = first(node.children)
|
||||
if (length(child_node.parents) != 1)
|
||||
continue
|
||||
end
|
||||
|
||||
|
@ -14,7 +14,9 @@ function get_operations(graph::DAG)
|
||||
generate_operations(graph)
|
||||
end
|
||||
|
||||
clean_node!.(Ref(graph), graph.dirtyNodes)
|
||||
for node in graph.dirtyNodes
|
||||
clean_node!(graph, node)
|
||||
end
|
||||
empty!(graph.dirtyNodes)
|
||||
|
||||
return graph.possibleOperations
|
||||
|
@ -1,39 +0,0 @@
|
||||
import Base.iterate
|
||||
|
||||
const _POSSIBLE_OPERATIONS_FIELDS = fieldnames(PossibleOperations)
|
||||
|
||||
_POIteratorStateType =
|
||||
NamedTuple{(:result, :state), Tuple{Union{NodeFusion, NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
|
||||
|
||||
@inline function iterate(possibleOperations::PossibleOperations)::Union{Nothing, _POIteratorStateType}
|
||||
for fieldname in _POSSIBLE_OPERATIONS_FIELDS
|
||||
iterator = iterate(getfield(possibleOperations, fieldname))
|
||||
if (!isnothing(iterator))
|
||||
return (result = iterator[1], state = (fieldname, iterator[2]))
|
||||
end
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
@inline function iterate(possibleOperations::PossibleOperations, state)::Union{Nothing, _POIteratorStateType}
|
||||
newStateSym = state[1]
|
||||
newStateIt = iterate(getfield(possibleOperations, newStateSym), state[2])
|
||||
if !isnothing(newStateIt)
|
||||
return (result = newStateIt[1], state = (newStateSym, newStateIt[2]))
|
||||
end
|
||||
|
||||
# cycle to next field
|
||||
index = findfirst(x -> x == newStateSym, _POSSIBLE_OPERATIONS_FIELDS) + 1
|
||||
|
||||
while index <= length(_POSSIBLE_OPERATIONS_FIELDS)
|
||||
newStateSym = _POSSIBLE_OPERATIONS_FIELDS[index]
|
||||
newStateIt = iterate(getfield(possibleOperations, newStateSym))
|
||||
if !isnothing(newStateIt)
|
||||
return (result = newStateIt[1], state = (newStateSym, newStateIt[2]))
|
||||
end
|
||||
index += 1
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
@ -30,7 +30,7 @@ function show(io::IO, op::NodeReduction)
|
||||
print(io, "NR: ")
|
||||
print(io, length(op.input))
|
||||
print(io, "x")
|
||||
return print(io, task(op.input[1]))
|
||||
return print(io, op.input[1].task)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -40,7 +40,7 @@ Print a string representation of the node split to io.
|
||||
"""
|
||||
function show(io::IO, op::NodeSplit)
|
||||
print(io, "NS: ")
|
||||
return print(io, task(op.input))
|
||||
return print(io, op.input.task)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -50,9 +50,9 @@ Print a string representation of the node fusion to io.
|
||||
"""
|
||||
function show(io::IO, op::NodeFusion)
|
||||
print(io, "NF: ")
|
||||
print(io, task(op.input[1]))
|
||||
print(io, op.input[1].task)
|
||||
print(io, "->")
|
||||
print(io, task(op.input[2]))
|
||||
print(io, op.input[2].task)
|
||||
print(io, "->")
|
||||
return print(io, task(op.input[3]))
|
||||
return print(io, op.input[3].task)
|
||||
end
|
||||
|
@ -40,9 +40,8 @@ A chain of (n1, n2, n3) can be fused if:
|
||||
|
||||
See also: [`can_fuse`](@ref)
|
||||
"""
|
||||
struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
|
||||
Operation
|
||||
input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
|
||||
struct NodeFusion <: Operation
|
||||
input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
|
||||
end
|
||||
|
||||
"""
|
||||
@ -50,12 +49,8 @@ end
|
||||
|
||||
The applied version of the [`NodeFusion`](@ref).
|
||||
"""
|
||||
struct AppliedNodeFusion{
|
||||
TaskType1 <: AbstractComputeTask,
|
||||
TaskType2 <: AbstractDataTask,
|
||||
TaskType3 <: AbstractComputeTask,
|
||||
} <: AppliedOperation
|
||||
operation::NodeFusion{TaskType1, TaskType2, TaskType3}
|
||||
struct AppliedNodeFusion <: AppliedOperation
|
||||
operation::NodeFusion
|
||||
diff::Diff
|
||||
end
|
||||
|
||||
@ -78,8 +73,8 @@ A vector of nodes can be reduced if:
|
||||
|
||||
See also: [`can_reduce`](@ref)
|
||||
"""
|
||||
struct NodeReduction{NodeType <: Node} <: Operation
|
||||
input::Vector{NodeType}
|
||||
struct NodeReduction <: Operation
|
||||
input::Vector{Node}
|
||||
end
|
||||
|
||||
"""
|
||||
@ -87,8 +82,8 @@ end
|
||||
|
||||
The applied version of the [`NodeReduction`](@ref).
|
||||
"""
|
||||
struct AppliedNodeReduction{NodeType <: Node} <: AppliedOperation
|
||||
operation::NodeReduction{NodeType}
|
||||
struct AppliedNodeReduction <: AppliedOperation
|
||||
operation::NodeReduction
|
||||
diff::Diff
|
||||
end
|
||||
|
||||
@ -107,8 +102,8 @@ A node can be split if:
|
||||
|
||||
See also: [`can_split`](@ref)
|
||||
"""
|
||||
struct NodeSplit{NodeType <: Node} <: Operation
|
||||
input::NodeType
|
||||
struct NodeSplit <: Operation
|
||||
input::Node
|
||||
end
|
||||
|
||||
"""
|
||||
@ -116,7 +111,7 @@ end
|
||||
|
||||
The applied version of the [`NodeSplit`](@ref).
|
||||
"""
|
||||
struct AppliedNodeSplit{NodeType <: Node} <: AppliedOperation
|
||||
operation::NodeSplit{NodeType}
|
||||
struct AppliedNodeSplit <: AppliedOperation
|
||||
operation::NodeSplit
|
||||
diff::Diff
|
||||
end
|
||||
|
@ -61,7 +61,7 @@ function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
|
||||
return false
|
||||
end
|
||||
|
||||
if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
|
||||
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
|
||||
return false
|
||||
end
|
||||
|
||||
@ -74,15 +74,12 @@ end
|
||||
Return whether the given two nodes can be reduced. See [`NodeReduction`](@ref) for the requirements.
|
||||
"""
|
||||
function can_reduce(n1::Node, n2::Node)
|
||||
return false
|
||||
end
|
||||
if (n1.task != n2.task)
|
||||
return false
|
||||
end
|
||||
|
||||
function can_reduce(
|
||||
n1::NodeType,
|
||||
n2::NodeType,
|
||||
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
|
||||
n1_length = length(children(n1))
|
||||
n2_length = length(children(n2))
|
||||
n1_length = length(n1.children)
|
||||
n2_length = length(n2.children)
|
||||
|
||||
if (n1_length != n2_length)
|
||||
return false
|
||||
@ -91,19 +88,19 @@ function can_reduce(
|
||||
# this seems to be the most common case so do this first
|
||||
# doing it manually is a lot faster than using the sets for a general solution
|
||||
if (n1_length == 2)
|
||||
if (children(n1)[1] != children(n2)[1])
|
||||
if (children(n1)[1] != children(n2)[2])
|
||||
if (n1.children[1] != n2.children[1])
|
||||
if (n1.children[1] != n2.children[2])
|
||||
return false
|
||||
end
|
||||
# 1_1 == 2_2
|
||||
if (children(n1)[2] != children(n2)[1])
|
||||
if (n1.children[2] != n2.children[1])
|
||||
return false
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
# 1_1 == 2_1
|
||||
if (children(n1)[2] != children(n2)[2])
|
||||
if (n1.children[2] != n2.children[2])
|
||||
return false
|
||||
end
|
||||
return true
|
||||
@ -111,11 +108,11 @@ function can_reduce(
|
||||
|
||||
# this is simple
|
||||
if (n1_length == 1)
|
||||
return children(n1)[1] == children(n2)[1]
|
||||
return n1.children[1] == n2.children[1]
|
||||
end
|
||||
|
||||
# this takes a long time
|
||||
return Set(children(n1)) == Set(children(n2))
|
||||
return Set(n1.children) == Set(n2.children)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -141,14 +138,7 @@ end
|
||||
|
||||
Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
|
||||
"""
|
||||
function ==(
|
||||
op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
|
||||
op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
|
||||
) where {
|
||||
ComputeTaskType1 <: AbstractComputeTask,
|
||||
DataTaskType <: AbstractDataTask,
|
||||
ComputeTaskType2 <: AbstractComputeTask,
|
||||
}
|
||||
function ==(op1::NodeFusion, op2::NodeFusion)
|
||||
# there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
|
||||
return op1.input[2] == op2.input[2]
|
||||
end
|
||||
|
@ -54,9 +54,9 @@ function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
|
||||
@assert is_valid(graph, n)
|
||||
end
|
||||
|
||||
t = typeof(task(nodes[1]))
|
||||
t = typeof(nodes[1].task)
|
||||
for n in nodes
|
||||
if typeof(task(n)) != t
|
||||
if typeof(n.task) != t
|
||||
throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
|
||||
end
|
||||
|
||||
@ -115,7 +115,7 @@ Intended for use with `@assert` or `@test`.
|
||||
"""
|
||||
function is_valid(graph::DAG, nr::NodeReduction)
|
||||
@assert is_valid_node_reduction_input(graph, nr.input)
|
||||
#@assert nr in graph.possibleOperations.nodeReductions "NodeReduction is not part of the graph's possible operations!"
|
||||
@assert nr in graph.possibleOperations.nodeReductions "NodeReduction is not part of the graph's possible operations!"
|
||||
return true
|
||||
end
|
||||
|
||||
@ -128,7 +128,7 @@ Intended for use with `@assert` or `@test`.
|
||||
"""
|
||||
function is_valid(graph::DAG, ns::NodeSplit)
|
||||
@assert is_valid_node_split_input(graph, ns.input)
|
||||
#@assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!"
|
||||
@assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!"
|
||||
return true
|
||||
end
|
||||
|
||||
@ -141,6 +141,6 @@ Intended for use with `@assert` or `@test`.
|
||||
"""
|
||||
function is_valid(graph::DAG, nf::NodeFusion)
|
||||
@assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
|
||||
#@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
|
||||
@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
|
||||
return true
|
||||
end
|
||||
|
@ -1,73 +0,0 @@
|
||||
"""
|
||||
GreedyOptimizer
|
||||
|
||||
An implementation of the greedy optimization algorithm, simply choosing the best next option evaluated with the given estimator.
|
||||
|
||||
The fixpoint is reached when any leftover operation would increase the graph's total cost according to the given estimator.
|
||||
"""
|
||||
struct GreedyOptimizer{EstimatorType <: AbstractEstimator} <: AbstractOptimizer
|
||||
estimator::EstimatorType
|
||||
end
|
||||
|
||||
function optimize_step!(optimizer::GreedyOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if isempty(operations)
|
||||
return false
|
||||
end
|
||||
|
||||
result = nothing
|
||||
|
||||
lowestCost = reduce(
|
||||
(acc, op) -> begin
|
||||
op_cost = operation_effect(optimizer.estimator, graph, op)
|
||||
if op_cost < acc
|
||||
result = op
|
||||
return op_cost
|
||||
end
|
||||
return acc
|
||||
end,
|
||||
operations;
|
||||
init = typemax(cost_type(optimizer.estimator)),
|
||||
)
|
||||
|
||||
if lowestCost > zero(cost_type(optimizer.estimator))
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, result)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::GreedyOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if isempty(operations)
|
||||
return true
|
||||
end
|
||||
|
||||
lowestCost = reduce(
|
||||
(acc, op) -> begin
|
||||
op_cost = operation_effect(optimizer.estimator, graph, op)
|
||||
if op_cost < acc
|
||||
return op_cost
|
||||
end
|
||||
return acc
|
||||
end,
|
||||
operations;
|
||||
init = typemax(cost_type(optimizer.estimator)),
|
||||
)
|
||||
|
||||
if lowestCost > zero(cost_type(optimizer.estimator))
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::GreedyOptimizer, graph::DAG)
|
||||
while optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
@ -1,60 +0,0 @@
|
||||
|
||||
"""
|
||||
AbstractOptimizer
|
||||
|
||||
Abstract base type for optimizer implementations.
|
||||
"""
|
||||
abstract type AbstractOptimizer end
|
||||
|
||||
"""
|
||||
optimize_step!(optimizer::AbstractOptimizer, graph::DAG)
|
||||
|
||||
Interface function that must be implemented by implementations of [`AbstractOptimizer`](@ref). Returns `true` if an operations has been applied, `false` if not, usually when a fixpoint of the algorithm has been reached.
|
||||
|
||||
It should do one smallest logical step on the given [`DAG`](@ref), muting the graph and, if necessary, the optimizer's state.
|
||||
"""
|
||||
function optimize_step! end
|
||||
|
||||
"""
|
||||
optimize!(optimizer::AbstractOptimizer, graph::DAG, n::Int)
|
||||
|
||||
Function calling the given optimizer `n` times, muting the graph. Returns `true` if the requested number of operations has been applied, `false` if not, usually when a fixpoint of the algorithm has been reached.
|
||||
|
||||
If a more efficient method exists, this can be overloaded for a specific optimizer.
|
||||
"""
|
||||
function optimize!(optimizer::AbstractOptimizer, graph::DAG, n::Int)
|
||||
for i in 1:n
|
||||
if !optimize_step!(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
fixpoint_reached(optimizer::AbstractOptimizer, graph::DAG)
|
||||
|
||||
Interface function that can be implemented by optimization algorithms that can reach a fixpoint, returning as a `Bool` whether it has been reached. The default implementation returns `false`.
|
||||
|
||||
See also: [`optimize_to_fixpoint!`](@ref)
|
||||
"""
|
||||
function fixpoint_reached(optimizer::AbstractOptimizer, graph::DAG)
|
||||
return false
|
||||
end
|
||||
|
||||
"""
|
||||
optimize_to_fixpoint!(optimizer::AbstractOptimizer, graph::DAG)
|
||||
|
||||
Interface function that can be implemented by optimization algorithms that can reach a fixpoint. The algorithm will be run until that fixpoint is reached, at which point [`fixpoint_reached`](@ref) should return true.
|
||||
|
||||
A usual implementation might look like this:
|
||||
```julia
|
||||
function optimize_to_fixpoint!(optimizer::MyOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
```
|
||||
"""
|
||||
function optimize_to_fixpoint! end
|
@ -1,49 +0,0 @@
|
||||
using Random
|
||||
|
||||
"""
|
||||
RandomWalkOptimizer
|
||||
|
||||
An optimizer that randomly pushes or pops operations. It doesn't optimize in any direction and is useful mainly for testing purposes.
|
||||
|
||||
This algorithm never reaches a fixpoint, so it does not implement [`optimize_to_fixpoint`](@ref).
|
||||
"""
|
||||
struct RandomWalkOptimizer <: AbstractOptimizer
|
||||
rng::AbstractRNG
|
||||
end
|
||||
|
||||
function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
|
||||
if sum(length(operations)) == 0 && length(graph.appliedOperations) + length(graph.operationsToApply) == 0
|
||||
# in case there are zero operations possible at all on the graph
|
||||
return false
|
||||
end
|
||||
|
||||
r = optimizer.rng
|
||||
# try until something was applied or popped
|
||||
while true
|
||||
# choose push or pop
|
||||
if rand(r, Bool)
|
||||
# push
|
||||
|
||||
# choose one of fuse/split/reduce
|
||||
option = rand(r, 1:3)
|
||||
if option == 1 && !isempty(operations.nodeFusions)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeFusions)))
|
||||
return true
|
||||
elseif option == 2 && !isempty(operations.nodeReductions)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeReductions)))
|
||||
return true
|
||||
elseif option == 3 && !isempty(operations.nodeSplits)
|
||||
push_operation!(graph, rand(r, collect(operations.nodeSplits)))
|
||||
return true
|
||||
end
|
||||
else
|
||||
# pop
|
||||
if (can_pop(graph))
|
||||
pop_operation!(graph)
|
||||
return true
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -1,30 +0,0 @@
|
||||
"""
|
||||
ReductionOptimizer
|
||||
|
||||
An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.
|
||||
"""
|
||||
struct ReductionOptimizer <: AbstractOptimizer end
|
||||
|
||||
function optimize_step!(optimizer::ReductionOptimizer, graph::DAG)
|
||||
# generate all options
|
||||
operations = get_operations(graph)
|
||||
if fixpoint_reached(optimizer, graph)
|
||||
return false
|
||||
end
|
||||
|
||||
push_operation!(graph, first(operations.nodeReductions))
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fixpoint_reached(optimizer::ReductionOptimizer, graph::DAG)
|
||||
operations = get_operations(graph)
|
||||
return isempty(operations.nodeReductions)
|
||||
end
|
||||
|
||||
function optimize_to_fixpoint!(optimizer::ReductionOptimizer, graph::DAG)
|
||||
while !fixpoint_reached(optimizer, graph)
|
||||
optimize_step!(optimizer, graph)
|
||||
end
|
||||
return nothing
|
||||
end
|
@ -4,18 +4,14 @@
|
||||
Create an empty [`GraphProperties`](@ref) object.
|
||||
"""
|
||||
function GraphProperties()
|
||||
return (data = 0.0, computeEffort = 0.0, computeIntensity = 0.0, noNodes = 0, noEdges = 0)::GraphProperties
|
||||
end
|
||||
|
||||
@inline function _props(
|
||||
node::DataTaskNode{TaskType},
|
||||
)::Tuple{Float64, Float64, Int64} where {TaskType <: AbstractDataTask}
|
||||
return (data(task(node)) * length(parents(node)), 0.0, length(parents(node)))
|
||||
end
|
||||
@inline function _props(
|
||||
node::ComputeTaskNode{TaskType},
|
||||
)::Tuple{Float64, Float64, Int64} where {TaskType <: AbstractComputeTask}
|
||||
return (0.0, compute_effort(task(node)), length(parents(node)))
|
||||
return (
|
||||
data = 0.0,
|
||||
computeEffort = 0.0,
|
||||
computeIntensity = 0.0,
|
||||
cost = 0.0,
|
||||
noNodes = 0,
|
||||
noEdges = 0,
|
||||
)::GraphProperties
|
||||
end
|
||||
|
||||
"""
|
||||
@ -31,16 +27,16 @@ function GraphProperties(graph::DAG)
|
||||
ce = 0.0
|
||||
ed = 0
|
||||
for node in graph.nodes
|
||||
props = _props(node)
|
||||
d += props[1]
|
||||
ce += props[2]
|
||||
ed += props[3]
|
||||
d += data(node.task) * length(node.parents)
|
||||
ce += compute_effort(node.task)
|
||||
ed += length(node.parents)
|
||||
end
|
||||
|
||||
return (
|
||||
data = d,
|
||||
computeEffort = ce,
|
||||
computeIntensity = (d == 0) ? 0.0 : ce / d,
|
||||
cost = 0.0, # TODO
|
||||
noNodes = length(graph.nodes),
|
||||
noEdges = ed,
|
||||
)::GraphProperties
|
||||
@ -54,18 +50,23 @@ The graph's properties after applying the [`Diff`](@ref) will be `get_properties
|
||||
For reverting a diff, it's `get_properties(graph) - GraphProperties(diff)`.
|
||||
"""
|
||||
function GraphProperties(diff::Diff)
|
||||
d = 0.0
|
||||
ce = 0.0
|
||||
c = 0.0 # TODO
|
||||
|
||||
ce =
|
||||
reduce(+, compute_effort(task(n)) for n in diff.addedNodes; init = 0.0) -
|
||||
reduce(+, compute_effort(task(n)) for n in diff.removedNodes; init = 0.0)
|
||||
reduce(+, compute_effort(n.task) for n in diff.addedNodes; init = 0.0) -
|
||||
reduce(+, compute_effort(n.task) for n in diff.removedNodes; init = 0.0)
|
||||
|
||||
d =
|
||||
reduce(+, data(task(n)) for n in diff.addedNodes; init = 0.0) -
|
||||
reduce(+, data(task(n)) for n in diff.removedNodes; init = 0.0)
|
||||
reduce(+, data(e) for e in diff.addedEdges; init = 0.0) -
|
||||
reduce(+, data(e) for e in diff.removedEdges; init = 0.0)
|
||||
|
||||
return (
|
||||
data = d,
|
||||
computeEffort = ce,
|
||||
computeIntensity = (d == 0) ? 0.0 : ce / d,
|
||||
cost = c,
|
||||
noNodes = length(diff.addedNodes) - length(diff.removedNodes),
|
||||
noEdges = length(diff.addedEdges) - length(diff.removedEdges),
|
||||
)::GraphProperties
|
||||
|
@ -7,10 +7,11 @@ Representation of a [`DAG`](@ref)'s properties.
|
||||
`.data`: The total data transfer.\\
|
||||
`.computeEffort`: The total compute effort.\\
|
||||
`.computeIntensity`: The compute intensity, will always equal `.computeEffort / .data`.\\
|
||||
`.cost`: The estimated cost.\\
|
||||
`.noNodes`: Number of [`Node`](@ref)s.\\
|
||||
`.noEdges`: Number of [`Edge`](@ref)s.
|
||||
"""
|
||||
const GraphProperties = NamedTuple{
|
||||
(:data, :computeEffort, :computeIntensity, :noNodes, :noEdges),
|
||||
Tuple{Float64, Float64, Float64, Int, Int},
|
||||
(:data, :computeEffort, :computeIntensity, :cost, :noNodes, :noEdges),
|
||||
Tuple{Float64, Float64, Float64, Float64, Int, Int},
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ function -(prop1::GraphProperties, prop2::GraphProperties)
|
||||
else
|
||||
(prop1.computeEffort - prop2.computeEffort) / (prop1.data - prop2.data)
|
||||
end,
|
||||
cost = prop1.cost - prop2.cost,
|
||||
noNodes = prop1.noNodes - prop2.noNodes,
|
||||
noEdges = prop1.noEdges - prop2.noEdges,
|
||||
)::GraphProperties
|
||||
@ -33,6 +34,7 @@ function +(prop1::GraphProperties, prop2::GraphProperties)
|
||||
else
|
||||
(prop1.computeEffort + prop2.computeEffort) / (prop1.data + prop2.data)
|
||||
end,
|
||||
cost = prop1.cost + prop2.cost,
|
||||
noNodes = prop1.noNodes + prop2.noNodes,
|
||||
noEdges = prop1.noEdges + prop2.noEdges,
|
||||
)::GraphProperties
|
||||
@ -48,6 +50,7 @@ function -(prop::GraphProperties)
|
||||
data = -prop.data,
|
||||
computeEffort = -prop.computeEffort,
|
||||
computeIntensity = prop.computeIntensity, # no negation here!
|
||||
cost = -prop.cost,
|
||||
noNodes = -prop.noNodes,
|
||||
noEdges = -prop.noEdges,
|
||||
)::GraphProperties
|
||||
|
@ -32,14 +32,14 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
|
||||
if (isa(node, ComputeTaskNode))
|
||||
lowestDevice = peek(deviceAccCost)[1]
|
||||
node.device = lowestDevice
|
||||
deviceAccCost[lowestDevice] = compute_effort(task(node))
|
||||
deviceAccCost[lowestDevice] = compute_effort(node.task)
|
||||
end
|
||||
|
||||
push!(schedule, node)
|
||||
for parent in parents(node)
|
||||
for parent in node.parents
|
||||
# reduce the priority of all parents by one
|
||||
if (!haskey(nodeQueue, parent))
|
||||
enqueue!(nodeQueue, parent => length(children(parent)) - 1)
|
||||
enqueue!(nodeQueue, parent => length(parent.children) - 1)
|
||||
else
|
||||
nodeQueue[parent] = nodeQueue[parent] - 1
|
||||
end
|
||||
|
@ -41,16 +41,16 @@ end
|
||||
Generate and return code for a given [`ComputeTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
@assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
|
||||
@assert length(node.children) <= children(node.task) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
|
||||
@assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
|
||||
|
||||
inExprs = Vector()
|
||||
for id in getfield.(children(node), :id)
|
||||
for id in getfield.(node.children, :id)
|
||||
push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
|
||||
end
|
||||
outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
|
||||
|
||||
return get_expression(task(node), node.device, inExprs, outExpr)
|
||||
return get_expression(node.task, node.device, inExprs, outExpr)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -59,11 +59,11 @@ end
|
||||
Generate and return code for a given [`DataTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::DataTaskNode)
|
||||
@assert length(children(node)) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
|
||||
@assert length(node.children) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
|
||||
|
||||
# TODO: dispatch to device implementations generating the copy commands
|
||||
|
||||
child = children(node)[1]
|
||||
child = node.children[1]
|
||||
inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
|
||||
outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
|
||||
dataTransportExp = Meta.parse("$outExpr = $inExpr")
|
||||
@ -79,7 +79,7 @@ Generate and return code for the initial input reading expression for [`DataTask
|
||||
See also: [`get_entry_nodes`](@ref)
|
||||
"""
|
||||
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
@assert isempty(children(node)) "Trying to call get_init_expression on a data task node that is not an entry node."
|
||||
@assert isempty(node.children) "Trying to call get_init_expression on a data task node that is not an entry node."
|
||||
|
||||
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
|
||||
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
|
||||
|
@ -17,16 +17,15 @@ copy(t::AbstractComputeTask) = typeof(t)()
|
||||
|
||||
Return a copy of th egiven [`FusedComputeTask`](@ref).
|
||||
"""
|
||||
function copy(t::FusedComputeTask)
|
||||
return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
|
||||
function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
|
||||
return FusedComputeTask{T1, T2}(
|
||||
copy(t.first_task),
|
||||
copy(t.second_task),
|
||||
copy(t.t1_inputs),
|
||||
t.t1_output,
|
||||
copy(t.t2_inputs),
|
||||
)
|
||||
end
|
||||
|
||||
function FusedComputeTask(
|
||||
T1::Type{<:AbstractComputeTask},
|
||||
T2::Type{<:AbstractComputeTask},
|
||||
t1_inputs::Vector{String},
|
||||
t1_output::String,
|
||||
t2_inputs::Vector{String},
|
||||
)
|
||||
return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
|
||||
end
|
||||
FusedComputeTask{T1, T2}(t1_inputs::Vector{String}, t1_output::String, t2_inputs::Vector{String}) where {T1, T2} =
|
||||
FusedComputeTask{T1, T2}(T1(), T2(), t1_inputs, t1_output, t2_inputs)
|
||||
|
@ -30,7 +30,7 @@ compute(t::AbstractDataTask; data...) = data
|
||||
|
||||
Fallback implementation of the compute effort of a task, throwing an error.
|
||||
"""
|
||||
function compute_effort(t::AbstractTask)::Float64
|
||||
function compute_effort(t::AbstractTask)
|
||||
# default implementation using compute
|
||||
return error("Need to implement compute_effort()")
|
||||
end
|
||||
@ -40,7 +40,7 @@ end
|
||||
|
||||
Fallback implementation of the data of a task, throwing an error.
|
||||
"""
|
||||
function data(t::AbstractTask)::Float64
|
||||
function data(t::AbstractTask)
|
||||
return error("Need to implement data()")
|
||||
end
|
||||
|
||||
@ -49,28 +49,28 @@ end
|
||||
|
||||
Return the compute effort of a data task, always zero, regardless of the specific task.
|
||||
"""
|
||||
compute_effort(t::AbstractDataTask)::Float64 = 0.0
|
||||
compute_effort(t::AbstractDataTask) = 0
|
||||
|
||||
"""
|
||||
data(t::AbstractDataTask)
|
||||
|
||||
Return the data of a data task. Given by the task's `.data` field.
|
||||
"""
|
||||
data(t::AbstractDataTask)::Float64 = getfield(t, :data)
|
||||
data(t::AbstractDataTask) = getfield(t, :data)
|
||||
|
||||
"""
|
||||
data(t::AbstractComputeTask)
|
||||
|
||||
Return the data of a compute task, always zero, regardless of the specific task.
|
||||
"""
|
||||
data(t::AbstractComputeTask)::Float64 = 0.0
|
||||
data(t::AbstractComputeTask) = 0
|
||||
|
||||
"""
|
||||
compute_effort(t::FusedComputeTask)
|
||||
|
||||
Return the compute effort of a fused compute task.
|
||||
"""
|
||||
function compute_effort(t::FusedComputeTask)::Float64
|
||||
function compute_effort(t::FusedComputeTask)
|
||||
return compute_effort(t.first_task) + compute_effort(t.second_task)
|
||||
end
|
||||
|
||||
@ -79,4 +79,4 @@ end
|
||||
|
||||
Return a tuple of a the fused compute task's components' types.
|
||||
"""
|
||||
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))
|
||||
get_types(::FusedComputeTask{T1, T2}) where {T1, T2} = (T1, T2)
|
||||
|
@ -26,9 +26,9 @@ A fused compute task made up of the computation of first `T1` and then `T2`.
|
||||
|
||||
Also see: [`get_types`](@ref).
|
||||
"""
|
||||
struct FusedComputeTask <: AbstractComputeTask
|
||||
first_task::AbstractComputeTask
|
||||
second_task::AbstractComputeTask
|
||||
struct FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
|
||||
first_task::T1
|
||||
second_task::T2
|
||||
# the names of the inputs for T1
|
||||
t1_inputs::Vector{Symbol}
|
||||
# output name of T1
|
||||
|
32
src/trie.jl
32
src/trie.jl
@ -3,9 +3,9 @@
|
||||
|
||||
Helper struct for [`NodeTrie`](@ref). After the Trie's first level, every Trie level contains the vector of nodes that had children up to that level, and the TrieNode's children by UUID of the node's children.
|
||||
"""
|
||||
mutable struct NodeIdTrie{NodeType <: Node}
|
||||
value::Vector{NodeType}
|
||||
children::Dict{UUID, NodeIdTrie{NodeType}}
|
||||
mutable struct NodeIdTrie
|
||||
value::Vector{Node}
|
||||
children::Dict{UUID, NodeIdTrie}
|
||||
end
|
||||
|
||||
"""
|
||||
@ -35,8 +35,8 @@ end
|
||||
|
||||
Constructor for an empty [`NodeIdTrie`](@ref).
|
||||
"""
|
||||
function NodeIdTrie{NodeType}() where {NodeType <: Node}
|
||||
return NodeIdTrie(Vector{NodeType}(), Dict{UUID, NodeIdTrie{NodeType}}())
|
||||
function NodeIdTrie()
|
||||
return NodeIdTrie(Vector{Node}(), Dict{UUID, NodeIdTrie}())
|
||||
end
|
||||
|
||||
"""
|
||||
@ -44,12 +44,8 @@ end
|
||||
|
||||
Insert the given node into the trie. The depth is used to iterate through the trie layers, while the function calls itself recursively until it ran through all children of the node.
|
||||
"""
|
||||
function insert_helper!(
|
||||
trie::NodeIdTrie{NodeType},
|
||||
node::NodeType,
|
||||
depth::Int,
|
||||
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
|
||||
if (length(children(node)) == depth)
|
||||
function insert_helper!(trie::NodeIdTrie, node::Node, depth::Int)
|
||||
if (length(node.children) == depth)
|
||||
push!(trie.value, node)
|
||||
return nothing
|
||||
end
|
||||
@ -58,7 +54,7 @@ function insert_helper!(
|
||||
id = node.children[depth].id
|
||||
|
||||
if (!haskey(trie.children, id))
|
||||
trie.children[id] = NodeIdTrie{NodeType}()
|
||||
trie.children[id] = NodeIdTrie()
|
||||
end
|
||||
return insert_helper!(trie.children[id], node, depth)
|
||||
end
|
||||
@ -68,14 +64,12 @@ end
|
||||
|
||||
Insert the given node into the trie. It's sorted by its type in the first layer, then by its children in the following layers.
|
||||
"""
|
||||
function insert!(
|
||||
trie::NodeTrie,
|
||||
node::NodeType,
|
||||
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
|
||||
if (!haskey(trie.children, NodeType))
|
||||
trie.children[NodeType] = NodeIdTrie{NodeType}()
|
||||
function insert!(trie::NodeTrie, node::Node)
|
||||
t = typeof(node.task)
|
||||
if (!haskey(trie.children, t))
|
||||
trie.children[t] = NodeIdTrie()
|
||||
end
|
||||
return insert_helper!(trie.children[NodeType], node, 0)
|
||||
return insert_helper!(trie.children[typeof(node.task)], node, 0)
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -36,8 +36,8 @@ Sort the nodes' parents and children vectors. The vectors are mostly very short
|
||||
Sorted nodes are required to make the finding of [`NodeReduction`](@ref)s a lot faster using the [`NodeTrie`](@ref) data structure.
|
||||
"""
|
||||
function sort_node!(node::Node)
|
||||
sort!(children(node), lt = lt_nodes)
|
||||
return sort!(parents(node), lt = lt_nodes)
|
||||
sort!(node.children, lt = lt_nodes)
|
||||
return sort!(node.parents, lt = lt_nodes)
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -1,5 +1,4 @@
|
||||
[deps]
|
||||
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
|
||||
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
@ -6,12 +6,9 @@ using Test
|
||||
include("unit_tests_tasks.jl")
|
||||
include("unit_tests_nodes.jl")
|
||||
include("unit_tests_properties.jl")
|
||||
include("unit_tests_estimator.jl")
|
||||
include("unit_tests_abcmodel.jl")
|
||||
include("node_reduction.jl")
|
||||
include("unit_tests_graph.jl")
|
||||
include("unit_tests_execution.jl")
|
||||
include("unit_tests_optimization.jl")
|
||||
|
||||
include("known_graphs.jl")
|
||||
end
|
||||
|
@ -1,26 +0,0 @@
|
||||
using MetagraphOptimization
|
||||
using QEDbase
|
||||
|
||||
import MetagraphOptimization.interaction_result
|
||||
|
||||
def_momentum = SFourMomentum(1.0, 0.0, 0.0, 0.0)
|
||||
|
||||
testparticleTypes = [ParticleA, ParticleB, ParticleC]
|
||||
testparticles = [ParticleA(def_momentum), ParticleB(def_momentum), ParticleC(def_momentum)]
|
||||
|
||||
@testset "Unit Tests ABC-Model" begin
|
||||
@testset "Interaction Result" begin
|
||||
for p1 in testparticleTypes, p2 in testparticleTypes
|
||||
if (p1 == p2)
|
||||
@test_throws AssertionError interaction_result(p1, p2)
|
||||
else
|
||||
@test interaction_result(p1, p2) == setdiff(testparticleTypes, [p1, p2])[1]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Vertex" begin
|
||||
@test isapprox(MetagraphOptimization.vertex(), 1 / 137.0)
|
||||
end
|
||||
end
|
||||
println("ABC-Model Unit Tests Complete!")
|
@ -1,92 +0,0 @@
|
||||
function test_op_specific(estimator, graph, nf::NodeFusion)
|
||||
estimate = operation_effect(estimator, graph, nf)
|
||||
data_reduce = data(nf.input[2].task)
|
||||
|
||||
@test isapprox(estimate.data, -data_reduce)
|
||||
@test isapprox(estimate.computeEffort, 0; atol = eps(Float64))
|
||||
@test isapprox(estimate.computeIntensity, 0; atol = eps(Float64))
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function test_op_specific(estimator, graph, nr::NodeReduction)
|
||||
estimate = operation_effect(estimator, graph, nr)
|
||||
|
||||
data_reduce = data(nr.input[1].task) * (length(nr.input) - 1)
|
||||
compute_effort_reduce = compute_effort(nr.input[1].task) * (length(nr.input) - 1)
|
||||
|
||||
@test isapprox(estimate.data, -data_reduce; atol = eps(Float64))
|
||||
@test isapprox(estimate.computeEffort, -compute_effort_reduce)
|
||||
@test isapprox(estimate.computeIntensity, compute_effort_reduce / data_reduce)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function test_op_specific(estimator, graph, ns::NodeSplit)
|
||||
estimate = operation_effect(estimator, graph, ns)
|
||||
|
||||
copies = length(ns.input.parents) - 1
|
||||
|
||||
data_increase = data(ns.input.task) * copies
|
||||
compute_effort_increase = compute_effort(ns.input.task) * copies
|
||||
|
||||
@test isapprox(estimate.data, data_increase; atol = eps(Float64))
|
||||
@test isapprox(estimate.computeEffort, compute_effort_increase)
|
||||
@test isapprox(estimate.computeIntensity, compute_effort_increase / data_increase)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function test_op(estimator, graph, op)
|
||||
estimate_before = graph_cost(estimator, graph)
|
||||
|
||||
estimate = operation_effect(estimator, graph, op)
|
||||
|
||||
push_operation!(graph, op)
|
||||
estimate_after_apply = graph_cost(estimator, graph)
|
||||
reset_graph!(graph)
|
||||
|
||||
@test isapprox((estimate_before + estimate).data, estimate_after_apply.data)
|
||||
@test isapprox((estimate_before + estimate).computeEffort, estimate_after_apply.computeEffort)
|
||||
@test isapprox((estimate_before + estimate).computeIntensity, estimate_after_apply.computeIntensity)
|
||||
|
||||
test_op_specific(estimator, graph, op)
|
||||
return nothing
|
||||
end
|
||||
|
||||
@testset "Unit Tests Estimator" begin
|
||||
@testset "Global Metric Estimator" for (graph_string, exp_data, exp_computeEffort) in
|
||||
zip(["AB->AB", "AB->ABBB"], [976, 10944], [53, 1075])
|
||||
estimator = GlobalMetricEstimator()
|
||||
|
||||
@test cost_type(estimator) == CDCost
|
||||
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "$(graph_string).txt"), ABCModel())
|
||||
|
||||
@testset "Graph Cost" begin
|
||||
estimate = graph_cost(estimator, graph)
|
||||
|
||||
@test estimate.data == exp_data
|
||||
@test estimate.computeEffort == exp_computeEffort
|
||||
@test isapprox(estimate.computeIntensity, exp_computeEffort / exp_data)
|
||||
end
|
||||
|
||||
@testset "Operation Cost" begin
|
||||
ops = get_operations(graph)
|
||||
nfs = copy(ops.nodeFusions)
|
||||
nrs = copy(ops.nodeReductions)
|
||||
nss = copy(ops.nodeSplits)
|
||||
|
||||
for nf in nfs
|
||||
test_op(estimator, graph, nf)
|
||||
end
|
||||
for nr in nrs
|
||||
test_op(estimator, graph, nr)
|
||||
end
|
||||
for ns in nss
|
||||
test_op(estimator, graph, ns)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
println("Estimator Unit Tests Complete!")
|
@ -1,50 +1,8 @@
|
||||
import MetagraphOptimization.ABCParticle
|
||||
import MetagraphOptimization.interaction_result
|
||||
|
||||
using QEDbase
|
||||
using AccurateArithmetic
|
||||
using Random
|
||||
|
||||
const RTOL = sqrt(eps(Float64))
|
||||
RNG = Random.default_rng()
|
||||
|
||||
function check_particle_reverse_moment(p1::SFourMomentum, p2::SFourMomentum)
|
||||
@test isapprox(abs(p1.E), abs(p2.E))
|
||||
@test isapprox(p1.px, -p2.px)
|
||||
@test isapprox(p1.py, -p2.py)
|
||||
@test isapprox(p1.pz, -p2.pz)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function ground_truth_graph_result(input::ABCProcessInput)
|
||||
# formula for one diagram:
|
||||
# u_Bp * iλ * u_Ap * S_C * u_B * iλ * u_A
|
||||
# for the second diagram:
|
||||
# u_B * iλ * u_Ap * S_C * u_Bp * iλ * u_Ap
|
||||
# the "u"s are all 1, we ignore the i, λ is 1/137.
|
||||
|
||||
constant = (1 / 137.0)^2
|
||||
|
||||
# calculate particle C in diagram 1
|
||||
diagram1_C = ParticleC(input.inParticles[1].momentum + input.inParticles[2].momentum)
|
||||
diagram2_C = ParticleC(input.inParticles[1].momentum + input.outParticles[2].momentum)
|
||||
|
||||
diagram1_Cp = ParticleC(input.outParticles[1].momentum + input.outParticles[2].momentum)
|
||||
diagram2_Cp = ParticleC(input.outParticles[1].momentum + input.inParticles[2].momentum)
|
||||
|
||||
check_particle_reverse_moment(diagram1_Cp.momentum, diagram1_C.momentum)
|
||||
check_particle_reverse_moment(diagram2_Cp.momentum, diagram2_C.momentum)
|
||||
@test isapprox(getMass2(diagram1_C.momentum), getMass2(diagram1_Cp.momentum))
|
||||
@test isapprox(getMass2(diagram2_C.momentum), getMass2(diagram2_Cp.momentum))
|
||||
|
||||
inner1 = MetagraphOptimization.inner_edge(diagram1_C)
|
||||
inner2 = MetagraphOptimization.inner_edge(diagram2_C)
|
||||
|
||||
diagram1_result = inner1 * constant
|
||||
diagram2_result = inner2 * constant
|
||||
|
||||
return sum_kbn([diagram1_result, diagram2_result])
|
||||
end
|
||||
include("../examples/profiling_utilities.jl")
|
||||
|
||||
@testset "Unit Tests Execution" begin
|
||||
machine = get_machine_info()
|
||||
@ -65,29 +23,29 @@ end
|
||||
ParticleB(SFourMomentum(0.823648, 0.835061, 0.474802, -0.277915)),
|
||||
],
|
||||
)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
expected_result = 0.00013916495566048735
|
||||
|
||||
@testset "AB->AB no optimization" begin
|
||||
for _ in 1:10 # test in a loop because graph layout should not change the result
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
|
||||
# graph should be fully scheduled after being executed
|
||||
@test is_scheduled(graph)
|
||||
|
||||
func = get_compute_function(graph, process_2_2, machine)
|
||||
@test isapprox(func(particles_2_2), expected_result; rtol = RTOL)
|
||||
@test isapprox(func(particles_2_2), expected_result; rtol = 0.001)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "AB->AB after random walk" begin
|
||||
for i in 1:200
|
||||
for i in 1:1000
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
optimize!(RandomWalkOptimizer(RNG), graph, 50)
|
||||
random_walk!(graph, 50)
|
||||
|
||||
@test is_valid(graph)
|
||||
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
|
||||
# graph should be fully scheduled after being executed
|
||||
@test is_scheduled(graph)
|
||||
@ -105,20 +63,20 @@ end
|
||||
@testset "AB->ABBB no optimization" begin
|
||||
for _ in 1:5 # test in a loop because graph layout should not change the result
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
|
||||
@test isapprox(execute(graph, process_2_4, machine, particles_2_4), expected_result; rtol = RTOL)
|
||||
@test isapprox(execute(graph, process_2_4, machine, particles_2_4), expected_result; rtol = 0.001)
|
||||
|
||||
func = get_compute_function(graph, process_2_4, machine)
|
||||
@test isapprox(func(particles_2_4), expected_result; rtol = RTOL)
|
||||
@test isapprox(func(particles_2_4), expected_result; rtol = 0.001)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "AB->ABBB after random walk" begin
|
||||
for i in 1:50
|
||||
for i in 1:200
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
|
||||
optimize!(RandomWalkOptimizer(RNG), graph, 100)
|
||||
random_walk!(graph, 100)
|
||||
@test is_valid(graph)
|
||||
|
||||
@test isapprox(execute(graph, process_2_4, machine, particles_2_4), expected_result; rtol = RTOL)
|
||||
@test isapprox(execute(graph, process_2_4, machine, particles_2_4), expected_result; rtol = 0.001)
|
||||
end
|
||||
end
|
||||
|
||||
@ -147,8 +105,8 @@ end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
expected_result = 0.00013916495566048735
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
end
|
||||
|
||||
|
||||
@ -177,8 +135,8 @@ end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
expected_result = 0.00013916495566048735
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
end
|
||||
|
||||
@testset "AB->AB fusion edge case" for _ in 1:20
|
||||
@ -211,8 +169,8 @@ end
|
||||
|
||||
# try execute
|
||||
@test is_valid(graph)
|
||||
expected_result = ground_truth_graph_result(particles_2_2)
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
|
||||
expected_result = 0.00013916495566048735
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
end
|
||||
|
||||
end
|
||||
|
@ -135,12 +135,6 @@ import MetagraphOptimization.partners
|
||||
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
|
||||
@test length(graph.dirtyNodes) == 0
|
||||
|
||||
i = 0
|
||||
for op in operations
|
||||
i += 1
|
||||
end
|
||||
@test i == 10
|
||||
|
||||
@test operations == get_operations(graph)
|
||||
nf = first(operations.nodeFusions)
|
||||
|
||||
|
@ -1,42 +0,0 @@
|
||||
using Random
|
||||
|
||||
RNG = Random.default_rng()
|
||||
|
||||
@testset "Unit Tests Optimization" begin
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())
|
||||
|
||||
# create the optimizers
|
||||
FIXPOINT_OPTIMIZERS = [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer()]
|
||||
NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]
|
||||
|
||||
@testset "Optimizer $optimizer" for optimizer in vcat(NO_FIXPOINT_OPTIMIZERS, FIXPOINT_OPTIMIZERS)
|
||||
@test operation_stack_length(graph) == 0
|
||||
@test optimize_step!(optimizer, graph)
|
||||
|
||||
@test !fixpoint_reached(optimizer, graph)
|
||||
@test operation_stack_length(graph) == 1
|
||||
|
||||
@test optimize!(optimizer, graph, 10)
|
||||
|
||||
@test !fixpoint_reached(optimizer, graph)
|
||||
|
||||
reset_graph!(graph)
|
||||
end
|
||||
|
||||
@testset "Fixpoint optimizer $optimizer" for optimizer in FIXPOINT_OPTIMIZERS
|
||||
@test operation_stack_length(graph) == 0
|
||||
|
||||
optimize_to_fixpoint!(optimizer, graph)
|
||||
|
||||
@test fixpoint_reached(optimizer, graph)
|
||||
@test !optimize_step!(optimizer, graph)
|
||||
@test !optimize!(optimizer, graph, 10)
|
||||
|
||||
reset_graph!(graph)
|
||||
end
|
||||
|
||||
@testset "No fixpoint optimizer $optimizer" for optimizer in NO_FIXPOINT_OPTIMIZERS
|
||||
@test_throws MethodError optimize_to_fixpoint!(optimizer, graph)
|
||||
end
|
||||
end
|
||||
println("Optimization Unit Tests Complete!")
|
@ -5,10 +5,18 @@
|
||||
@test prop.data == 0.0
|
||||
@test prop.computeEffort == 0.0
|
||||
@test prop.computeIntensity == 0.0
|
||||
@test prop.cost == 0.0
|
||||
@test prop.noNodes == 0.0
|
||||
@test prop.noEdges == 0.0
|
||||
|
||||
prop2 = (data = 5.0, computeEffort = 6.0, computeIntensity = 6.0 / 5.0, noNodes = 2, noEdges = 3)::GraphProperties
|
||||
prop2 = (
|
||||
data = 5.0,
|
||||
computeEffort = 6.0,
|
||||
computeIntensity = 6.0 / 5.0,
|
||||
cost = 0.0,
|
||||
noNodes = 2,
|
||||
noEdges = 3,
|
||||
)::GraphProperties
|
||||
|
||||
@test prop + prop2 == prop2
|
||||
@test prop2 - prop == prop2
|
||||
@ -17,18 +25,27 @@
|
||||
@test negProp.data == -5.0
|
||||
@test negProp.computeEffort == -6.0
|
||||
@test negProp.computeIntensity == 6.0 / 5.0
|
||||
@test negProp.cost == 0.0
|
||||
@test negProp.noNodes == -2
|
||||
@test negProp.noEdges == -3
|
||||
|
||||
@test negProp + prop2 == GraphProperties()
|
||||
|
||||
prop3 = (data = 7.0, computeEffort = 3.0, computeIntensity = 7.0 / 3.0, noNodes = -3, noEdges = 2)::GraphProperties
|
||||
prop3 = (
|
||||
data = 7.0,
|
||||
computeEffort = 3.0,
|
||||
computeIntensity = 7.0 / 3.0,
|
||||
cost = 0.0,
|
||||
noNodes = -3,
|
||||
noEdges = 2,
|
||||
)::GraphProperties
|
||||
|
||||
propSum = prop2 + prop3
|
||||
|
||||
@test propSum.data == 12.0
|
||||
@test propSum.computeEffort == 9.0
|
||||
@test propSum.computeIntensity == 9.0 / 12.0
|
||||
@test propSum.cost == 0.0
|
||||
@test propSum.noNodes == -1
|
||||
@test propSum.noEdges == 5
|
||||
end
|
||||
|
Reference in New Issue
Block a user