Remove occurrences of Fusion/Fuse

parent 8a5e49429b
commit 97ccb3f3fb
@@ -5,7 +5,7 @@
## Package Features
- Read a DAG from a file
- Analyze its properties
- Mutate the graph using the operations NodeFusion, NodeReduction and NodeSplit
- Mutate the graph using the operations NodeReduction and NodeSplit

## Coming Soon:
- Add Code Generation from finished DAG
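For orientation, a minimal usage sketch of the workflow listed above, using only the API that appears elsewhere in this commit (`parse_dag`, `ABCModel`, `get_properties`, `get_operations`, `push_operation!`, `pop_operation!`); the input path is just an example:

```julia
using MetagraphOptimization

# Read a DAG from a file (ABC model input; path is hypothetical)
graph = parse_dag("input/AB->ABBB.txt", ABCModel())

# Analyze its properties
props = get_properties(graph)
println("data: ", props.data, ", compute effort: ", props.computeEffort)

# Mutate the graph: apply one NodeReduction and one NodeSplit, if available
ops = get_operations(graph)
if !isempty(ops.nodeReductions)
    push_operation!(graph, first(ops.nodeReductions))
end
ops = get_operations(graph)
if !isempty(ops.nodeSplits)
    push_operation!(graph, first(ops.nodeSplits))
end

# Applied operations can be reverted again
pop_operation!(graph)
```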
@@ -1,60 +0,0 @@
using MetagraphOptimization
using Plots
using Random

function gen_plot(filepath)
    name = basename(filepath)
    name, _ = splitext(name)

    filepath = joinpath(@__DIR__, "../input/", filepath)
    if !isfile(filepath)
        println("File ", filepath, " does not exist, skipping")
        return
    end

    g = parse_dag(filepath, ABCModel())

    Random.seed!(1)

    println("Random Walking... ")

    x = Vector{Float64}()
    y = Vector{Float64}()

    for i in 1:30
        print("\r", i)
        # push
        opt = get_operations(g)

        # choose one of fuse/split/reduce
        option = rand(1:3)
        if option == 1 && !isempty(opt.nodeFusions)
            push_operation!(g, rand(collect(opt.nodeFusions)))
            println("NF")
        elseif option == 2 && !isempty(opt.nodeReductions)
            push_operation!(g, rand(collect(opt.nodeReductions)))
            println("NR")
        elseif option == 3 && !isempty(opt.nodeSplits)
            push_operation!(g, rand(collect(opt.nodeSplits)))
            println("NS")
        else
            i = i - 1
        end

        props = get_properties(g)
        push!(x, props.data)
        push!(y, props.computeEffort)
    end

    println("\rDone.")

    plot([x[1], x[2]], [y[1], y[2]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
    # Create lines connecting the reference point to each data point
    for i in 3:length(x)
        plot!([x[i - 1], x[i]], [y[i - 1], y[i]], linestyle = :solid, linewidth = 1, color = :red)
    end

    return gui()
end

gen_plot("AB->ABBB.txt")
@@ -1,96 +0,0 @@
using MetagraphOptimization
using Plots
using Random

function gen_plot(filepath)
    name = basename(filepath)
    name, _ = splitext(name)

    filepath = joinpath(@__DIR__, "../input/", filepath)
    if !isfile(filepath)
        println("File ", filepath, " does not exist, skipping")
        return
    end

    g = parse_dag(filepath, ABCModel())

    Random.seed!(1)

    println("Random Walking... ")

    for i in 1:30
        print("\r", i)
        # push
        opt = get_operations(g)

        # choose one of fuse/split/reduce
        option = rand(1:3)
        if option == 1 && !isempty(opt.nodeFusions)
            push_operation!(g, rand(collect(opt.nodeFusions)))
            println("NF")
        elseif option == 2 && !isempty(opt.nodeReductions)
            push_operation!(g, rand(collect(opt.nodeReductions)))
            println("NR")
        elseif option == 3 && !isempty(opt.nodeSplits)
            push_operation!(g, rand(collect(opt.nodeSplits)))
            println("NS")
        else
            i = i - 1
        end
    end

    println("\rDone.")


    props = get_properties(g)
    x0 = props.data
    y0 = props.computeEffort

    x = Vector{Float64}()
    y = Vector{Float64}()
    names = Vector{String}()

    opt = get_operations(g)
    for op in opt.nodeFusions
        push_operation!(g, op)
        props = get_properties(g)
        push!(x, props.data)
        push!(y, props.computeEffort)
        pop_operation!(g)

        push!(names, "NF: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
    end
    for op in opt.nodeReductions
        push_operation!(g, op)
        props = get_properties(g)
        push!(x, props.data)
        push!(y, props.computeEffort)
        pop_operation!(g)

        push!(names, "NR: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
    end
    for op in opt.nodeSplits
        push_operation!(g, op)
        props = get_properties(g)
        push!(x, props.data)
        push!(y, props.computeEffort)
        pop_operation!(g)

        push!(names, "NS: (" * string(props.data) * ", " * string(props.computeEffort) * ")")
    end

    plot([x0, x[1]], [y0, y[1]], linestyle = :solid, linewidth = 1, color = :red, legend = false)
    # Create lines connecting the reference point to each data point
    for i in 2:length(x)
        plot!([x0, x[i]], [y0, y[i]], linestyle = :solid, linewidth = 1, color = :red)
    end
    #scatter!(x, y, label=names)

    print(names)

    return gui()
end

gen_plot("AB->ABBB.txt")
@@ -1,5 +1,5 @@
# Optimizer Plots

Plots of FusionOptimizer, ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.
Plots of FusionOptimizer (deprecated), ReductionOptimizer, SplitOptimizer, RandomWalkOptimizer, and GreedyOptimizer, executed on a system with 32 threads and an A30 GPU.

Benchmarked using `notebooks/optimizers.ipynb`.
@@ -54,8 +54,6 @@
"cu_inputs = CuArray(inputs)\n",
"optimizer = RandomWalkOptimizer(MersenneTwister(0))# GreedyOptimizer(GlobalMetricEstimator())\n",
"\n",
"#done: split, reduce, fuse, greedy\n",
"\n",
"process_str_short = \"qed_k3\"\n",
"optim_str = \"Random Walk Optimization\"\n",
"optim_str_short=\"random\"\n",
@@ -240,31 +240,6 @@ function get_snapshot_diff(graph::DAG)
    return swapfield!(graph, :diff, Diff())
end

"""
    invalidate_caches!(graph::DAG, operation::NodeFusion)

Invalidate the operation caches for a given [`NodeFusion`](@ref).

This deletes the operation from the graph's possible operations and from the involved nodes' own operation caches.
"""
function invalidate_caches!(graph::DAG, operation::NodeFusion)
    delete!(graph.possibleOperations, operation)

    # delete the operation from all caches of nodes involved in the operation
    for n in [1, 3]
        for i in eachindex(operation.input[n].nodeFusions)
            if operation == operation.input[n].nodeFusions[i]
                splice!(operation.input[n].nodeFusions, i)
                break
            end
        end
    end

    operation.input[2].nodeFusion = missing

    return nothing
end

"""
    invalidate_caches!(graph::DAG, operation::NodeReduction)

@@ -311,9 +286,6 @@ function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode)
    if !ismissing(node.nodeSplit)
        invalidate_caches!(graph, node.nodeSplit)
    end
    while !isempty(node.nodeFusions)
        invalidate_caches!(graph, pop!(node.nodeFusions))
    end
    return nothing
end

@@ -329,8 +301,5 @@ function invalidate_operation_caches!(graph::DAG, node::DataTaskNode)
    if !ismissing(node.nodeSplit)
        invalidate_caches!(graph, node.nodeSplit)
    end
    if !ismissing(node.nodeFusion)
        invalidate_caches!(graph, node.nodeFusion)
    end
    return nothing
end
@@ -7,7 +7,6 @@ A struct storing all possible operations on a [`DAG`](@ref).
To get the [`PossibleOperations`](@ref) on a [`DAG`](@ref), use [`get_operations`](@ref).
"""
mutable struct PossibleOperations
    nodeFusions::Set{NodeFusion}
    nodeReductions::Set{NodeReduction}
    nodeSplits::Set{NodeSplit}
end
@@ -52,7 +51,7 @@ end
Construct and return an empty [`PossibleOperations`](@ref) object.
"""
function PossibleOperations()
    return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
    return PossibleOperations(Set{NodeReduction}(), Set{NodeSplit}())
end

"""
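For illustration only (not part of the diff): with the changed constructor, an empty `PossibleOperations` now tracks just reductions and splits, which matches the updated `isempty` and `length` further down in this commit.

```julia
ops = PossibleOperations()   # holds only nodeReductions and nodeSplits after this change
@assert isempty(ops)
@assert length(ops) == (nodeReductions = 0, nodeSplits = 0)
```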
@@ -1,6 +1,6 @@

DataTaskNode(t::AbstractDataTask, name = "") =
    DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
    DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, name)
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
    t, # task
    Vector{Node}(), # parents
@@ -8,7 +8,6 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
    UUIDs.uuid1(rng[threadid()]), # id
    missing, # node reduction
    missing, # node split
    Vector{NodeFusion}(), # node fusions
    missing, # device
)

@@ -30,7 +30,6 @@ Any node that transfers data and does no computation.
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
"""
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
@@ -51,9 +50,6 @@ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
    # the NodeSplit involving this node, if it exists
    nodeSplit::Union{Operation, Missing}

    # the node fusion involving this node, if it exists
    nodeFusion::Union{Operation, Missing}

    # for input nodes we need a name for the node to distinguish between them
    name::String
end
@@ -70,7 +66,6 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
"""
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
@@ -82,9 +77,6 @@ mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
    nodeReduction::Union{Operation, Missing}
    nodeSplit::Union{Operation, Missing}

    # for ComputeTasks there can be multiple fusions, unlike the DataTasks
    nodeFusions::Vector{<:Operation}

    # the device this node is assigned to execute on
    device::Union{AbstractDevice, Missing}
end
@@ -29,11 +29,6 @@ function is_valid_node(graph::DAG, node::Node)
        @assert is_valid(graph, node.nodeSplit)
    end=#

    if !(typeof(task(node)) <: FusedComputeTask)
        # the remaining checks are only necessary for fused compute tasks
        return true
    end

    # every child must be in some input of the task
    for child in node.children
        str = Symbol(to_var_name(child.id))
@@ -53,9 +48,6 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::ComputeTaskNode)
    @assert is_valid_node(graph, node)

    #=for nf in node.nodeFusions
        @assert is_valid(graph, nf)
    end=#
    return true
end

@@ -69,8 +61,5 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::DataTaskNode)
    @assert is_valid_node(graph, node)

    #=if !ismissing(node.nodeFusion)
        @assert is_valid(graph, node.nodeFusion)
    end=#
    return true
end
@@ -26,21 +26,6 @@ function apply_operation!(graph::DAG, operation::Operation)
    return error("Unknown operation type!")
end

"""
    apply_operation!(graph::DAG, operation::NodeFusion)

Apply the given [`NodeFusion`](@ref) to the graph. Generic wrapper around [`node_fusion!`](@ref).

Return an [`AppliedNodeFusion`](@ref) object generated from the graph's [`Diff`](@ref).
"""
function apply_operation!(graph::DAG, operation::NodeFusion)
    diff = node_fusion!(graph, operation.input[1], operation.input[2], operation.input[3])

    graph.properties += GraphProperties(diff)

    return AppliedNodeFusion(operation, diff)
end

"""
    apply_operation!(graph::DAG, operation::NodeReduction)

@@ -80,20 +65,10 @@ function revert_operation!(graph::DAG, operation::AppliedOperation)
    return error("Unknown operation type!")
end

"""
    revert_operation!(graph::DAG, operation::AppliedNodeFusion)

Revert the applied node fusion on the graph. Return the original [`NodeFusion`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeFusion)
    revert_diff!(graph, operation.diff)
    return operation.operation
end

"""
    revert_operation!(graph::DAG, operation::AppliedNodeReduction)

Revert the applied node fusion on the graph. Return the original [`NodeReduction`](@ref) operation.
Revert the applied node reduction on the graph. Return the original [`NodeReduction`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeReduction)
    revert_diff!(graph, operation.diff)
@@ -103,7 +78,7 @@ end
"""
    revert_operation!(graph::DAG, operation::AppliedNodeSplit)

Revert the applied node fusion on the graph. Return the original [`NodeSplit`](@ref) operation.
Revert the applied node split on the graph. Return the original [`NodeSplit`](@ref) operation.
"""
function revert_operation!(graph::DAG, operation::AppliedNodeSplit)
    revert_diff!(graph, operation.diff)
@@ -132,88 +107,11 @@ function revert_diff!(graph::DAG, diff::Diff)
        insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
    end

    for (node, t) in diff.updatedChildren
        # node must be fused compute task at this point
        @assert typeof(task(node)) <: FusedComputeTask

        node.task = t
    end

    graph.properties -= GraphProperties(diff)

    return nothing
end

"""
    node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Fuse nodes n1 -> n2 -> n3 together into one node, return the applied difference to the graph.

For details see [`NodeFusion`](@ref).
"""
function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    @assert is_valid_node_fusion_input(graph, n1, n2, n3)

    # clear snapshot
    get_snapshot_diff(graph)

    # save children and parents
    n1Children = copy(children(n1))
    n3Parents = copy(parents(n3))

    n1Task = copy(task(n1))
    n3Task = copy(task(n3))

    # assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
    n1Inputs = Vector{Symbol}()
    for child in n1Children
        push!(n1Inputs, Symbol(to_var_name(child.id)))
    end

    # remove the edges and nodes that will be replaced by the fused node
    remove_edge!(graph, n1, n2)
    remove_edge!(graph, n2, n3)
    remove_node!(graph, n1)
    remove_node!(graph, n2)

    # get n3's children now so it automatically excludes n2
    n3Children = copy(children(n3))

    n3Inputs = Vector{Symbol}()
    for child in n3Children
        push!(n3Inputs, Symbol(to_var_name(child.id)))
    end

    remove_node!(graph, n3)

    # create new node with the fused compute task
    newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
    insert_node!(graph, newNode)

    for child in n1Children
        remove_edge!(graph, child, n1)
        insert_edge!(graph, child, newNode)
    end

    for child in n3Children
        remove_edge!(graph, child, n3)
        if !(child in n1Children)
            insert_edge!(graph, child, newNode)
        end
    end

    for parent in n3Parents
        remove_edge!(graph, n3, parent)
        insert_edge!(graph, newNode, parent)

        # important! update the parent node's child names in case they are fused compute tasks
        # needed for compute generation so the fused compute task can correctly match inputs to its component tasks
        update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(newNode.id)))
    end

    return get_snapshot_diff(graph)
end

"""
    node_reduction!(graph::DAG, nodes::Vector{Node})

@@ -265,7 +163,6 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})

        # this has to be done for all parents, even the ones of n1 because they can be duplicate
        prevChild = newParentsChildNames[parent]
        update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
    end

    return get_snapshot_diff(graph)
@@ -307,8 +204,6 @@ function node_split!(
        for child in n1Children
            insert_edge!(graph, child, nCopy)
        end

        update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(nCopy.id)))
    end

    return get_snapshot_diff(graph)
@@ -1,60 +1,5 @@
# These are functions for "cleaning" nodes, i.e. regenerating the possible operations for a node

"""
    find_fusions!(graph::DAG, node::DataTaskNode)

Find node fusions involving the given data node. The function pushes the found [`NodeFusion`](@ref) (if any) everywhere it needs to be and returns nothing.

Does nothing if the node already has a node fusion set. Since it's a data node, only one node fusion can be possible with it.
"""
function find_fusions!(graph::DAG, node::DataTaskNode)
    # if there is already a fusion here, skip to avoid duplicates
    if !ismissing(node.nodeFusion)
        return nothing
    end

    if length(parents(node)) != 1 || length(children(node)) != 1
        return nothing
    end

    child_node = first(children(node))
    parent_node = first(parents(node))

    if !(child_node in graph) || !(parent_node in graph)
        error("Parents/Children that are not in the graph!!!")
    end

    if length(parents(child_node)) != 1
        return nothing
    end

    nf = NodeFusion((child_node, node, parent_node))
    push!(graph.possibleOperations.nodeFusions, nf)
    push!(child_node.nodeFusions, nf)
    node.nodeFusion = nf
    push!(parent_node.nodeFusions, nf)

    return nothing
end

"""
    find_fusions!(graph::DAG, node::ComputeTaskNode)

Find node fusions involving the given compute node. The function pushes the found [`NodeFusion`](@ref)s (if any) everywhere they need to be and returns nothing.
"""
function find_fusions!(graph::DAG, node::ComputeTaskNode)
    # just find fusions in neighbouring DataTaskNodes
    for child in children(node)
        find_fusions!(graph, child)
    end

    for parent in parents(node)
        find_fusions!(graph, parent)
    end

    return nothing
end

"""
    find_reductions!(graph::DAG, node::Node)

@@ -121,7 +66,7 @@ end
"""
    clean_node!(graph::DAG, node::Node)

Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way.
Sort this node's parent and child sets, then find reductions and splits involving it. Needs to be called after the node was changed in some way.
"""
function clean_node!(
    graph::DAG,
@@ -129,7 +74,6 @@ function clean_node!(
) where {TaskType <: AbstractTask}
    sort_node!(node)

    find_fusions!(graph, node)
    find_reductions!(graph, node)
    find_splits!(graph, node)

@@ -2,26 +2,6 @@

using Base.Threads

"""
    insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})

Insert the given node fusion into its input nodes' operation caches. For the compute nodes, locking via the given `locks` is employed to have safe multi-threading. For a large set of nodes, contention on the locks should be very small.
"""
function insert_operation!(nf::NodeFusion, locks::Dict{ComputeTaskNode, SpinLock})
    n1 = nf.input[1]
    n2 = nf.input[2]
    n3 = nf.input[3]

    lock(locks[n1]) do
        return push!(nf.input[1].nodeFusions, nf)
    end
    n2.nodeFusion = nf
    lock(locks[n3]) do
        return push!(nf.input[3].nodeFusions, nf)
    end
    return nothing
end

"""
    insert_operation!(nf::NodeReduction)

@@ -72,41 +52,6 @@ function nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Ve
    return nothing
end

"""
    nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})

Insert the node fusions into the graph and the nodes' caches. Employs multithreading for speedup.
"""
function nf_insertion!(graph::DAG, operations::PossibleOperations, nodeFusions::Vector{Vector{NodeFusion}})
    total_len = 0
    for vec in nodeFusions
        total_len += length(vec)
    end
    sizehint!(operations.nodeFusions, total_len)

    t = @task for vec in nodeFusions
        union!(operations.nodeFusions, Set(vec))
    end
    schedule(t)

    locks = Dict{ComputeTaskNode, SpinLock}()
    for n in graph.nodes
        if (typeof(n) <: ComputeTaskNode)
            locks[n] = SpinLock()
        end
    end

    @threads for vec in nodeFusions
        for op in vec
            insert_operation!(op, locks)
        end
    end

    wait(t)

    return nothing
end

"""
    ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplits}})

@@ -143,7 +88,6 @@ Generate all possible operations on the graph. Used initially when the graph is
Safely inserts all the found operations into the graph and its nodes.
"""
function generate_operations(graph::DAG)
    generatedFusions = [Vector{NodeFusion}() for _ in 1:nthreads()]
    generatedReductions = [Vector{NodeReduction}() for _ in 1:nthreads()]
    generatedSplits = [Vector{NodeSplit}() for _ in 1:nthreads()]

@@ -199,31 +143,6 @@ function generate_operations(graph::DAG)
    # remove duplicates
    nr_task = @spawn nr_insertion!(graph.possibleOperations, generatedReductions)

    # --- find possible node fusions ---
    @threads for node in nodeArray
        if (typeof(node) <: DataTaskNode)
            if length(parents(node)) != 1
                # data node can only have a single parent
                continue
            end
            parent_node = first(parents(node))

            if length(children(node)) != 1
                # this node is an entry node or has multiple children which should not be possible
                continue
            end
            child_node = first(children(node))
            if (length(parents(child_node)) != 1)
                continue
            end

            push!(generatedFusions[threadid()], NodeFusion((child_node, node, parent_node)))
        end
    end

    # launch thread for node fusion insertion
    nf_task = @spawn nf_insertion!(graph, graph.possibleOperations, generatedFusions)

    # find possible node splits
    @threads for node in nodeArray
        if (can_split(node))
@@ -237,7 +156,6 @@ function generate_operations(graph::DAG)
    empty!(graph.dirtyNodes)

    wait(nr_task)
    wait(nf_task)
    wait(ns_task)

    return nothing
@@ -2,8 +2,7 @@ import Base.iterate

const _POSSIBLE_OPERATIONS_FIELDS = fieldnames(PossibleOperations)

_POIteratorStateType =
    NamedTuple{(:result, :state), Tuple{Union{NodeFusion, NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}
_POIteratorStateType = NamedTuple{(:result, :state), Tuple{Union{NodeReduction, NodeSplit}, Tuple{Symbol, Int64}}}

@inline function iterate(possibleOperations::PossibleOperations)::Union{Nothing, _POIteratorStateType}
    for fieldname in _POSSIBLE_OPERATIONS_FIELDS
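After this change, iterating a `PossibleOperations` can only ever yield `NodeReduction` or `NodeSplit` results. A small sketch of the intended use (assumed, not shown in the diff):

```julia
for op in get_operations(graph)
    # op isa Union{NodeReduction, NodeSplit} after this commit
    println(typeof(op))
end
```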
@@ -4,11 +4,6 @@
Print a string representation of the set of possible operations to io.
"""
function show(io::IO, ops::PossibleOperations)
    print(io, length(ops.nodeFusions))
    println(io, " Node Fusions: ")
    for nf in ops.nodeFusions
        println(io, " - ", nf)
    end
    print(io, length(ops.nodeReductions))
    println(io, " Node Reductions: ")
    for nr in ops.nodeReductions
@@ -42,17 +37,3 @@ function show(io::IO, op::NodeSplit)
    print(io, "NS: ")
    return print(io, task(op.input))
end

"""
    show(io::IO, op::NodeFusion)

Print a string representation of the node fusion to io.
"""
function show(io::IO, op::NodeFusion)
    print(io, "NF: ")
    print(io, task(op.input[1]))
    print(io, "->")
    print(io, task(op.input[2]))
    print(io, "->")
    return print(io, task(op.input[3]))
end
@@ -4,7 +4,7 @@
Return whether `operations` is empty, i.e. all of its fields are empty.
"""
function isempty(operations::PossibleOperations)
    return isempty(operations.nodeFusions) && isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
    return isempty(operations.nodeReductions) && isempty(operations.nodeSplits)
end

"""
@@ -13,21 +13,7 @@ end
Return a named tuple with the number of each of the operation types as a named tuple. The fields are named the same as the [`PossibleOperations`](@ref)'.
"""
function length(operations::PossibleOperations)
    return (
        nodeFusions = length(operations.nodeFusions),
        nodeReductions = length(operations.nodeReductions),
        nodeSplits = length(operations.nodeSplits),
    )
end

"""
    delete!(operations::PossibleOperations, op::NodeFusion)

Delete the given node fusion from the possible operations.
"""
function delete!(operations::PossibleOperations, op::NodeFusion)
    delete!(operations.nodeFusions, op)
    return operations
    return (nodeReductions = length(operations.nodeReductions), nodeSplits = length(operations.nodeSplits))
end

"""
@@ -50,24 +36,6 @@ function delete!(operations::PossibleOperations, op::NodeSplit)
    return operations
end

"""
    can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Return whether the given nodes can be fused. See [`NodeFusion`](@ref) for the requirements.
"""
function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    if !is_child(n1, n2) || !is_child(n2, n3)
        # the checks are redundant but maybe a good sanity check
        return false
    end

    if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
        return false
    end

    return true
end

"""
    can_reduce(n1::Node, n2::Node)

@@ -136,23 +104,6 @@ function ==(op1::Operation, op2::Operation)
    return false
end

"""
    ==(op1::NodeFusion, op2::NodeFusion)

Equality comparison between two node fusions. Two node fusions are considered equal if they have the same inputs.
"""
function ==(
    op1::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
    op2::NodeFusion{ComputeTaskType1, DataTaskType, ComputeTaskType2},
) where {
    ComputeTaskType1 <: AbstractComputeTask,
    DataTaskType <: AbstractDataTask,
    ComputeTaskType2 <: AbstractComputeTask,
}
    # there can only be one node fusion on a given data task, so if the data task is the same, the fusion is the same
    return op1.input[2] == op2.input[2]
end

"""
    ==(op1::NodeReduction, op2::NodeReduction)

@@ -2,43 +2,6 @@
# should be called with @assert
# the functions throw their own errors though, to still have helpful error messages

"""
    is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)

Assert for a given node fusion input whether the nodes can be fused. For the requirements of a node fusion see [`NodeFusion`](@ref).

Intended for use with `@assert` or `@test`.
"""
function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
    if !(n1 in graph) || !(n2 in graph) || !(n3 in graph)
        throw(AssertionError("[Node Fusion] The given nodes are not part of the given graph"))
    end

    if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
        throw(
            AssertionError(
                "[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
            ),
        )
    end

    if length(n2.parents) > 1
        throw(AssertionError("[Node Fusion] The given data node has more than one parent"))
    end
    if length(n2.children) > 1
        throw(AssertionError("[Node Fusion] The given data node has more than one child"))
    end
    if length(n1.parents) > 1
        throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
    end

    @assert is_valid(graph, n1)
    @assert is_valid(graph, n2)
    @assert is_valid(graph, n3)

    return true
end

"""
    is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})

@@ -131,16 +94,3 @@ function is_valid(graph::DAG, ns::NodeSplit)
    #@assert ns in graph.possibleOperations.nodeSplits "NodeSplit is not part of the graph's possible operations!"
    return true
end

"""
    is_valid(graph::DAG, nr::NodeFusion)

Assert for a given [`NodeFusion`](@ref) whether it is a valid operation in the graph.

Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, nf::NodeFusion)
    @assert is_valid_node_fusion_input(graph, nf.input[1], nf.input[2], nf.input[3])
    #@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
    return true
end
@@ -1,36 +0,0 @@
"""
    FusionOptimizer

An optimizer that simply applies an available [`NodeFusion`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeFusion`](@ref)s in the graph.

See also: [`SplitOptimizer`](@ref), [`ReductionOptimizer`](@ref)
"""
struct FusionOptimizer <: AbstractOptimizer end

function optimize_step!(optimizer::FusionOptimizer, graph::DAG)
    # generate all options
    operations = get_operations(graph)
    if fixpoint_reached(optimizer, graph)
        return false
    end

    push_operation!(graph, first(operations.nodeFusions))

    return true
end

function fixpoint_reached(optimizer::FusionOptimizer, graph::DAG)
    operations = get_operations(graph)
    return isempty(operations.nodeFusions)
end

function optimize_to_fixpoint!(optimizer::FusionOptimizer, graph::DAG)
    while !fixpoint_reached(optimizer, graph)
        optimize_step!(optimizer, graph)
    end
    return nothing
end

function String(::FusionOptimizer)
    return "fusion_optimizer"
end
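The deleted `FusionOptimizer` above also documents the optimizer interface (`optimize_step!`, `fixpoint_reached`, `optimize_to_fixpoint!`, `String`). The remaining optimizers keep that shape; a sketch of how `ReductionOptimizer` presumably fills it in (its implementation is not part of this diff, so treat the body as an assumption):

```julia
function optimize_step!(optimizer::ReductionOptimizer, graph::DAG)
    # apply one available NodeReduction per step
    operations = get_operations(graph)
    if fixpoint_reached(optimizer, graph)
        return false
    end

    push_operation!(graph, first(operations.nodeReductions))

    return true
end

function fixpoint_reached(optimizer::ReductionOptimizer, graph::DAG)
    # fixpoint: no NodeReductions left in the graph
    operations = get_operations(graph)
    return isempty(operations.nodeReductions)
end
```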
@@ -26,16 +26,12 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG)
        if rand(r, Bool)
            # push

            # choose one of fuse/split/reduce
            # TODO refactor fusions so they actually work
            option = rand(r, 2:3)
            if option == 1 && !isempty(operations.nodeFusions)
                push_operation!(graph, rand(r, collect(operations.nodeFusions)))
                return true
            elseif option == 2 && !isempty(operations.nodeReductions)
            # choose one of split/reduce
            option = rand(r, 1:2)
            if option == 1 && !isempty(operations.nodeReductions)
                push_operation!(graph, rand(r, collect(operations.nodeReductions)))
                return true
            elseif option == 3 && !isempty(operations.nodeSplits)
            elseif option == 2 && !isempty(operations.nodeSplits)
                push_operation!(graph, rand(r, collect(operations.nodeSplits)))
                return true
            end

@@ -3,7 +3,7 @@

An optimizer that simply applies an available [`NodeReduction`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeReduction`](@ref)s in the graph.

See also: [`FusionOptimizer`](@ref), [`SplitOptimizer`](@ref)
See also: [`SplitOptimizer`](@ref)
"""
struct ReductionOptimizer <: AbstractOptimizer end


@@ -3,7 +3,7 @@

An optimizer that simply applies an available [`NodeSplit`](@ref) on each step. It implements [`optimize_to_fixpoint`](@ref). The fixpoint is reached when there are no more possible [`NodeSplit`](@ref)s in the graph.

See also: [`FusionOptimizer`](@ref), [`ReductionOptimizer`](@ref)
See also: [`ReductionOptimizer`](@ref)
"""
struct SplitOptimizer <: AbstractOptimizer end

@@ -1,31 +1,13 @@
using StaticArrays

"""
    compute(t::FusedComputeTask, data)

Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
"""
function compute(t::FusedComputeTask, data...)
    inter = compute(t.first_task)
    return compute(t.second_task, inter, data2...)
end

"""
    get_function_call(n::Node)
    get_function_call(t::AbstractTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)

For a node or a task together with necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task.

For ordinary compute or data tasks the vector will contain exactly one element, for a [`FusedComputeTask`](@ref) there can be any number of tasks greater than 1.
For ordinary compute or data tasks the vector will contain exactly one element.
"""
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
    # sort out the symbols to the correct tasks
    return [
        get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
        get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
    ]
end

function get_function_call(
    t::CompTask,
    device::AbstractDevice,
@@ -11,22 +11,3 @@ copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks
Return a copy of the given compute task.
"""
copy(t::AbstractComputeTask) = typeof(t)()

"""
    copy(t::FusedComputeTask)

Return a copy of the given [`FusedComputeTask`](@ref).
"""
function copy(t::FusedComputeTask)
    return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
end

function FusedComputeTask(
    T1::Type{<:AbstractComputeTask},
    T2::Type{<:AbstractComputeTask},
    t1_inputs::Vector{String},
    t1_output::String,
    t2_inputs::Vector{String},
)
    return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
end

@@ -47,34 +47,9 @@ Return the number of children of a data task (always 1).
"""
children(::DataTask) = 1

"""
    children(t::FusedComputeTask)

Return the number of children of a FusedComputeTask.
"""
function children(t::FusedComputeTask)
    return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
end

"""
    data(t::AbstractComputeTask)

Return the data of a compute task, always zero, regardless of the specific task.
"""
data(t::AbstractComputeTask)::Float64 = 0.0

"""
    compute_effort(t::FusedComputeTask)

Return the compute effort of a fused compute task.
"""
function compute_effort(t::FusedComputeTask)::Float64
    return compute_effort(t.first_task) + compute_effort(t.second_task)
end

"""
    get_types(::FusedComputeTask{T1, T2})

Return a tuple of the fused compute task's components' types.
"""
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))
@@ -27,21 +27,3 @@ Task representing a specific data transfer.
struct DataTask <: AbstractDataTask
    data::Float64
end

"""
    FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask

A fused compute task made up of the computation of first `T1` and then `T2`.

Also see: [`get_types`](@ref).
"""
struct FusedComputeTask <: AbstractComputeTask
    first_task::AbstractComputeTask
    second_task::AbstractComputeTask
    # the names of the inputs for T1
    t1_inputs::Vector{Symbol}
    # output name of T1
    t1_output::Symbol
    # t2_inputs doesn't include the output of t1, that's implicit
    t2_inputs::Vector{Symbol}
end
@@ -3,48 +3,15 @@ using Random

RNG = Random.MersenneTwister(321)

function test_known_graph(name::String, n, fusion_test = true)
function test_known_graph(name::String, n)
    @testset "Test $name Graph ($n)" begin
        graph = parse_dag(joinpath(@__DIR__, "..", "input", "$name.txt"), ABCModel())
        props = get_properties(graph)

        if (fusion_test)
            test_node_fusion(graph)
        end
        test_random_walk(RNG, graph, n)
    end
end

function test_node_fusion(g::DAG)
    @testset "Test Node Fusion" begin
        props = get_properties(g)

        options = get_operations(g)

        nodes_number = length(g.nodes)
        data = props.data
        compute_effort = props.computeEffort

        while !isempty(options.nodeFusions)
            fusion = first(options.nodeFusions)

            @test typeof(fusion) <: NodeFusion

            push_operation!(g, fusion)

            props = get_properties(g)
            @test props.data < data
            @test props.computeEffort == compute_effort

            nodes_number = length(g.nodes)
            data = props.data
            compute_effort = props.computeEffort

            options = get_operations(g)
        end
    end
end

function test_random_walk(RNG, g::DAG, n::Int64)
    @testset "Test Random Walk ($n)" begin
        # the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
@@ -60,13 +27,11 @@ function test_random_walk(RNG, g::DAG, n::Int64)
            # push
            opt = get_operations(g)

            # choose one of fuse/split/reduce
            option = rand(RNG, 1:3)
            if option == 1 && !isempty(opt.nodeFusions)
                push_operation!(g, rand(RNG, collect(opt.nodeFusions)))
            elseif option == 2 && !isempty(opt.nodeReductions)
            # choose one of split/reduce
            option = rand(RNG, 1:2)
            if option == 1 && !isempty(opt.nodeReductions)
                push_operation!(g, rand(RNG, collect(opt.nodeReductions)))
            elseif option == 3 && !isempty(opt.nodeSplits)
            elseif option == 2 && !isempty(opt.nodeSplits)
                push_operation!(g, rand(RNG, collect(opt.nodeSplits)))
            else
                i = i - 1
@@ -91,4 +56,4 @@ end

test_known_graph("AB->AB", 10000)
test_known_graph("AB->ABBB", 10000)
test_known_graph("AB->ABBBBB", 1000, false)
test_known_graph("AB->ABBBBB", 1000)
@@ -61,17 +61,14 @@ insert_edge!(graph, CD, C1C, track = false)

opt = get_operations(graph)

@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)

#println("Initial State:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)

nr = first(opt.nodeReductions)
@test Set(nr.input) == Set([B1C_1, B1C_2])
push_operation!(graph, nr)
opt = get_operations(graph)

@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
#println("After 1 Node Reduction:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)

nr = first(opt.nodeReductions)
@test Set(nr.input) == Set([B1D_1, B1D_2])
@@ -80,19 +77,16 @@ opt = get_operations(graph)

@test is_valid(graph)

@test length(opt) == (nodeFusions = 4, nodeReductions = 0, nodeSplits = 1)
#println("After 2 Node Reductions:\n", opt)
@test length(opt) == (nodeReductions = 0, nodeSplits = 1)

pop_operation!(graph)

opt = get_operations(graph)
@test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
#println("After reverting the second Node Reduction:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)

reset_graph!(graph)

opt = get_operations(graph)
@test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
#println("After reverting to the initial state:\n", opt)
@test length(opt) == (nodeReductions = 1, nodeSplits = 1)

@test is_valid(graph)
@@ -1,16 +1,5 @@
using MetagraphOptimization

function test_op_specific(estimator, graph, nf::NodeFusion)
    estimate = operation_effect(estimator, graph, nf)
    data_reduce = data(nf.input[2].task)

    @test isapprox(estimate.data, -data_reduce)
    @test isapprox(estimate.computeEffort, 0; atol = eps(Float64))
    @test isapprox(estimate.computeIntensity, 0; atol = eps(Float64))

    return nothing
end

function test_op_specific(estimator, graph, nr::NodeReduction)
    estimate = operation_effect(estimator, graph, nr)

@@ -74,13 +63,9 @@ end

@testset "Operation Cost" begin
    ops = get_operations(graph)
    nfs = copy(ops.nodeFusions)
    nrs = copy(ops.nodeReductions)
    nss = copy(ops.nodeSplits)

    for nf in nfs
        test_op(estimator, graph, nf)
    end
    for nr in nrs
        test_op(estimator, graph, nr)
    end

@@ -136,105 +136,6 @@ TODO: fix precision(?) issues
end
=#

@testset "AB->AB large sum fusion" begin
    for _ in 1:20
        graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())

        # push a fusion with the sum node
        ops = get_operations(graph)
        for fusion in ops.nodeFusions
            if isa(fusion.input[3].task, ComputeTaskABC_Sum)
                push_operation!(graph, fusion)
                break
            end
        end

        # push two more fusions with the fused node
        for _ in 1:15
            ops = get_operations(graph)
            for fusion in ops.nodeFusions
                if isa(fusion.input[3].task, FusedComputeTask)
                    push_operation!(graph, fusion)
                    break
                end
            end
        end

        # try execute
        @test is_valid(graph)
        expected_result = ground_truth_graph_result(particles_2_2)
        @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
    end
end


@testset "AB->AB large sum fusion" begin
    for _ in 1:20
        graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())

        # push a fusion with the sum node
        ops = get_operations(graph)
        for fusion in ops.nodeFusions
            if isa(fusion.input[3].task, ComputeTaskABC_Sum)
                push_operation!(graph, fusion)
                break
            end
        end

        # push two more fusions with the fused node
        for _ in 1:15
            ops = get_operations(graph)
            for fusion in ops.nodeFusions
                if isa(fusion.input[3].task, FusedComputeTask)
                    push_operation!(graph, fusion)
                    break
                end
            end
        end

        # try execute
        @test is_valid(graph)
        expected_result = ground_truth_graph_result(particles_2_2)
        @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
    end
end

@testset "AB->AB fusion edge case" begin
    for _ in 1:20
        graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())

        # push two fusions with ComputeTaskABC_V
        for _ in 1:2
            ops = get_operations(graph)
            for fusion in ops.nodeFusions
                if isa(fusion.input[1].task, ComputeTaskABC_V)
                    push_operation!(graph, fusion)
                    break
                end
            end
        end

        # push fusions until the end
        cont = true
        while cont
            cont = false
            ops = get_operations(graph)
            for fusion in ops.nodeFusions
                if isa(fusion.input[1].task, FusedComputeTask)
                    push_operation!(graph, fusion)
                    cont = true
                    break
                end
            end
        end

        # try execute
        @test is_valid(graph)
        expected_result = ground_truth_graph_result(particles_2_2)
        @test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = RTOL)
    end
end

@testset "$(process) after random walk" for process in ["ke->ke", "ke->kke", "ke->kkke"]
    process = parse_process("ke->kkke", QEDModel())
    inputs = [gen_process_input(process) for _ in 1:100]
@@ -13,7 +13,7 @@ graph = MetagraphOptimization.DAG()
@test length(graph.operationsToApply) == 0
@test length(graph.dirtyNodes) == 0
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
@test length(get_operations(graph)) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
@test length(get_operations(graph)) == (nodeReductions = 0, nodeSplits = 0)

# s to output (exit node)
d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
@@ -133,13 +133,12 @@ insert_edge!(graph, s0, d_exit, track = false)
@test length(siblings(s0)) == 1

operations = get_operations(graph)
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
@test length(graph.dirtyNodes) == 0

@test sum(length(operations)) == 10

@test operations == get_operations(graph)
nf = first(operations.nodeFusions)

properties = get_properties(graph)
@test properties.computeEffort == 28
@@ -148,46 +147,11 @@ properties = get_properties(graph)
@test properties.noNodes == 26
@test properties.noEdges == 25

push_operation!(graph, nf)
# **does not immediately apply the operation**

@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0
@test length(graph.operationsToApply) == 1
@test first(graph.operationsToApply) == nf
@test length(graph.dirtyNodes) == 0
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)

# this applies pending operations
properties = get_properties(graph)

@test length(graph.nodes) == 24
@test length(graph.appliedOperations) == 1
@test length(graph.operationsToApply) == 0
@test length(graph.dirtyNodes) != 0
@test properties.noNodes == 24
@test properties.noEdges == 23
@test properties.computeEffort == 28
@test properties.data < 62
@test properties.computeIntensity > 28 / 62

operations = get_operations(graph)
@test length(graph.dirtyNodes) == 0

@test length(operations) == (nodeFusions = 9, nodeReductions = 0, nodeSplits = 0)
@test !isempty(operations)

possibleNF = 9
while !isempty(operations.nodeFusions)
    push_operation!(graph, first(operations.nodeFusions))
    global operations = get_operations(graph)
    global possibleNF = possibleNF - 1
    @test length(operations) == (nodeFusions = possibleNF, nodeReductions = 0, nodeSplits = 0)
end

@test length(operations) == (nodeReductions = 0, nodeSplits = 0)
@test isempty(operations)

@test length(operations) == (nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
@test length(graph.dirtyNodes) == 0
@test length(graph.nodes) == 6
@test length(graph.appliedOperations) == 10
@@ -208,6 +172,6 @@ properties = get_properties(graph)
@test properties.computeIntensity ≈ 28 / 62

operations = get_operations(graph)
@test length(operations) == (nodeFusions = 10, nodeReductions = 0, nodeSplits = 0)
@test length(operations) == (nodeReductions = 0, nodeSplits = 0)

@test is_valid(graph)
@@ -6,8 +6,7 @@ RNG = Random.MersenneTwister(0)
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->ABBB.txt"), ABCModel())

# create the optimizers
FIXPOINT_OPTIMIZERS =
    [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer(), FusionOptimizer()]
FIXPOINT_OPTIMIZERS = [GreedyOptimizer(GlobalMetricEstimator()), ReductionOptimizer(), SplitOptimizer()]
NO_FIXPOINT_OPTIMIZERS = [RandomWalkOptimizer(RNG)]

@testset "Optimizer $optimizer" for optimizer in vcat(NO_FIXPOINT_OPTIMIZERS, FIXPOINT_OPTIMIZERS)