Compare commits
1 Commits
main
...
heterogene
Author | SHA1 | Date | |
---|---|---|---|
e5d214a6fc |
@ -14,7 +14,6 @@ export Edge
|
||||
export ComputeTaskNode
|
||||
export DataTaskNode
|
||||
export AbstractTask
|
||||
export AbstractComputeTask
|
||||
export AbstractDataTask
|
||||
export DataTask
|
||||
export FusedComputeTask
|
||||
@ -24,8 +23,8 @@ export GraphProperties
|
||||
# graph functions
|
||||
export make_node
|
||||
export make_edge
|
||||
export insert_node
|
||||
export insert_edge
|
||||
export insert_node!
|
||||
export insert_edge!
|
||||
export is_entry_node
|
||||
export is_exit_node
|
||||
export parents
|
||||
|
@ -1,3 +1,5 @@
|
||||
using DataStructures
|
||||
|
||||
"""
|
||||
in(node::Node, graph::DAG)
|
||||
|
||||
@ -19,3 +21,54 @@ function in(edge::Edge, graph::DAG)
|
||||
|
||||
return n1 in children(n2)
|
||||
end
|
||||
|
||||
"""
|
||||
is_dependent(graph::DAG, n1::Node, n2::Node)
|
||||
|
||||
Returns whether `n1` is dependent on `n2` in the given `graph`.
|
||||
"""
|
||||
function is_dependent(graph::DAG, n1::Node, n2::Node)
|
||||
if !(n1 in graph) || !(n2 in graph)
|
||||
return false
|
||||
end
|
||||
if n1 == n2
|
||||
return false
|
||||
end
|
||||
|
||||
queue = Deque{Node}()
|
||||
push!(queue, n1)
|
||||
|
||||
# TODO: this is probably not the best way to do this, deduplication in the queue would help
|
||||
while !isempty(queue)
|
||||
current = popfirst!(queue)
|
||||
if current == n2
|
||||
return true
|
||||
end
|
||||
|
||||
for c in current.children
|
||||
push!(queue, c)
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
"""
|
||||
pairwise_independent(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
Returns true if all given `nodes` are independent of each other, i.e., no path exists from any member of `nodes` to any other member.
|
||||
|
||||
See [`is_dependent`](@ref)
|
||||
"""
|
||||
function pairwise_independent(graph::DAG, nodes::Vector{<:Node})
|
||||
checked_set = Vector{Node}()
|
||||
for n in nodes
|
||||
for c in checked_set
|
||||
if is_dependent(graph, c, n) || is_dependent(graph, n, c)
|
||||
return false
|
||||
end
|
||||
end
|
||||
push!(checked_set, n)
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
@ -46,12 +46,12 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
|
||||
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
|
||||
#@assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists"
|
||||
@assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists"
|
||||
|
||||
# 1: mute
|
||||
# edge points from child to parent
|
||||
push!(node1.parents, node2)
|
||||
push!(node2.children, node1)
|
||||
add_parent!(node1, node2)
|
||||
add_child!(node2, node1)
|
||||
|
||||
# 2: keep track
|
||||
if (track)
|
||||
@ -85,7 +85,7 @@ Remove the node from the graph.
|
||||
See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
|
||||
#@assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
@assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
|
||||
# 1: mute
|
||||
delete!(graph.nodes, node)
|
||||
@ -123,19 +123,8 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
|
||||
pre_length1 = length(node1.parents)
|
||||
pre_length2 = length(node2.children)
|
||||
|
||||
for i in eachindex(node1.parents)
|
||||
if (node1.parents[i] == node2)
|
||||
splice!(node1.parents, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
for i in eachindex(node2.children)
|
||||
if (node2.children[i] == node1)
|
||||
splice!(node2.children, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
remove_parent!(node1, node2)
|
||||
remove_child!(node2, node1)
|
||||
|
||||
#=@assert begin
|
||||
removed = pre_length1 - length(node1.parents)
|
||||
@ -173,17 +162,17 @@ function replace_children!(task::FusedComputeTask, before, after)
|
||||
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
|
||||
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
|
||||
|
||||
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
|
||||
replace!(task.t1_inputs, before => after)
|
||||
replace!(task.t2_inputs, before => after)
|
||||
|
||||
# recursively descend down the tree, but only in the tasks where we're replacing things
|
||||
if replacedIn1 > 0
|
||||
replace_children!(task.first_task, before, after)
|
||||
replace_children!(task.first_func, before, after)
|
||||
end
|
||||
if replacedIn2 > 0
|
||||
replace_children!(task.second_task, before, after)
|
||||
replace_children!(task.second_func, before, after)
|
||||
end
|
||||
|
||||
return nothing
|
||||
|
@ -10,6 +10,7 @@ mutable struct PossibleOperations
|
||||
nodeFusions::Set{NodeFusion}
|
||||
nodeReductions::Set{NodeReduction}
|
||||
nodeSplits::Set{NodeSplit}
|
||||
nodeVectorizations::Set{NodeVectorization}
|
||||
end
|
||||
|
||||
"""
|
||||
@ -52,7 +53,7 @@ end
|
||||
Construct and return an empty [`PossibleOperations`](@ref) object.
|
||||
"""
|
||||
function PossibleOperations()
|
||||
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
|
||||
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}(), Set{NodeVectorization}())
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -181,7 +181,7 @@ function parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = fa
|
||||
insert_edge!(graph, compute_S2, data_out, track = false, invalidate_cache = false)
|
||||
|
||||
insert_edge!(graph, data_out, sum_node, track = false, invalidate_cache = false)
|
||||
add_child!(task(sum_node))
|
||||
add_child!(task(sum_node).func)
|
||||
elseif occursin(regex_plus, node)
|
||||
if (verbose)
|
||||
println("\rReading Nodes Complete ")
|
||||
|
@ -78,7 +78,6 @@ Return the number of children of a ComputeTaskABC_V (always 2).
|
||||
"""
|
||||
children(::ComputeTaskABC_V) = 2
|
||||
|
||||
|
||||
"""
|
||||
children(::ComputeTaskABC_Sum)
|
||||
|
||||
|
@ -1,44 +1,44 @@
|
||||
"""
|
||||
ComputeTaskABC_S1 <: AbstractComputeTask
|
||||
ComputeTaskABC_S1 <: AbstractTaskFunction
|
||||
|
||||
S task with a single child.
|
||||
"""
|
||||
struct ComputeTaskABC_S1 <: AbstractComputeTask end
|
||||
struct ComputeTaskABC_S1 <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskABC_S2 <: AbstractComputeTask
|
||||
ComputeTaskABC_S2 <: AbstractTaskFunction
|
||||
|
||||
S task with two children.
|
||||
"""
|
||||
struct ComputeTaskABC_S2 <: AbstractComputeTask end
|
||||
struct ComputeTaskABC_S2 <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskABC_P <: AbstractComputeTask
|
||||
ComputeTaskABC_P <: AbstractTaskFunction
|
||||
|
||||
P task with no children.
|
||||
"""
|
||||
struct ComputeTaskABC_P <: AbstractComputeTask end
|
||||
struct ComputeTaskABC_P <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskABC_V <: AbstractComputeTask
|
||||
ComputeTaskABC_V <: AbstractTaskFunction
|
||||
|
||||
v task with two children.
|
||||
"""
|
||||
struct ComputeTaskABC_V <: AbstractComputeTask end
|
||||
struct ComputeTaskABC_V <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskABC_U <: AbstractComputeTask
|
||||
ComputeTaskABC_U <: AbstractTaskFunction
|
||||
|
||||
u task with a single child.
|
||||
"""
|
||||
struct ComputeTaskABC_U <: AbstractComputeTask end
|
||||
struct ComputeTaskABC_U <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskABC_Sum <: AbstractComputeTask
|
||||
ComputeTaskABC_Sum <: AbstractTaskFunction
|
||||
|
||||
Task that sums all its inputs, n children.
|
||||
"""
|
||||
mutable struct ComputeTaskABC_Sum <: AbstractComputeTask
|
||||
mutable struct ComputeTaskABC_Sum <: AbstractTaskFunction
|
||||
children_number::Int
|
||||
end
|
||||
|
||||
|
@ -66,7 +66,7 @@ function compute(
|
||||
data1::ParticleValue{P1},
|
||||
data2::ParticleValue{P2},
|
||||
) where {P1 <: Union{AntiFermionStateful, FermionStateful}, P2 <: Union{AntiFermionStateful, FermionStateful}}
|
||||
#@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
|
||||
@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
|
||||
|
||||
inner = QED_inner_edge(propagation_result(P1)(momentum(data1.p)))
|
||||
|
||||
|
@ -188,7 +188,7 @@ end
|
||||
Apply a [`FeynmanVertex`](@ref) to the given vector of [`FeynmanParticle`](@ref)s.
|
||||
"""
|
||||
function apply_vertex!(particles::Vector{FeynmanParticle}, vertex::FeynmanVertex)
|
||||
#@assert can_apply_vertex(particles, vertex)
|
||||
@assert can_apply_vertex(particles, vertex)
|
||||
length_before = length(particles)
|
||||
filter!(x -> x != vertex.in1 && x != vertex.in2, particles)
|
||||
push!(particles, vertex.out)
|
||||
|
@ -1,44 +1,44 @@
|
||||
"""
|
||||
ComputeTaskQED_S1 <: AbstractComputeTask
|
||||
ComputeTaskQED_S1 <: AbstractTaskFunction
|
||||
|
||||
S task with a single child.
|
||||
"""
|
||||
struct ComputeTaskQED_S1 <: AbstractComputeTask end
|
||||
struct ComputeTaskQED_S1 <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskQED_S2 <: AbstractComputeTask
|
||||
ComputeTaskQED_S2 <: AbstractTaskFunction
|
||||
|
||||
S task with two children.
|
||||
"""
|
||||
struct ComputeTaskQED_S2 <: AbstractComputeTask end
|
||||
struct ComputeTaskQED_S2 <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskQED_P <: AbstractComputeTask
|
||||
ComputeTaskQED_P <: AbstractTaskFunction
|
||||
|
||||
P task with no children.
|
||||
"""
|
||||
struct ComputeTaskQED_P <: AbstractComputeTask end
|
||||
struct ComputeTaskQED_P <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskQED_V <: AbstractComputeTask
|
||||
ComputeTaskQED_V <: AbstractTaskFunction
|
||||
|
||||
v task with two children.
|
||||
"""
|
||||
struct ComputeTaskQED_V <: AbstractComputeTask end
|
||||
struct ComputeTaskQED_V <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskQED_U <: AbstractComputeTask
|
||||
ComputeTaskQED_U <: AbstractTaskFunction
|
||||
|
||||
u task with a single child.
|
||||
"""
|
||||
struct ComputeTaskQED_U <: AbstractComputeTask end
|
||||
struct ComputeTaskQED_U <: AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTaskQED_Sum <: AbstractComputeTask
|
||||
ComputeTaskQED_Sum <: AbstractTaskFunction
|
||||
|
||||
Task that sums all its inputs, n children.
|
||||
"""
|
||||
mutable struct ComputeTaskQED_Sum <: AbstractComputeTask
|
||||
mutable struct ComputeTaskQED_Sum <: AbstractTaskFunction
|
||||
children_number::Int
|
||||
end
|
||||
|
||||
|
@ -1,12 +1,22 @@
|
||||
|
||||
DataTaskNode(t::AbstractDataTask, name = "") =
|
||||
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
|
||||
DataTaskNode(t::AbstractDataTask, name = "") = DataTaskNode(
|
||||
t,
|
||||
Vector{Node}(),
|
||||
Vector{Node}(),
|
||||
UUIDs.uuid1(rng[threadid()]),
|
||||
missing,
|
||||
missing,
|
||||
missing,
|
||||
missing,
|
||||
name,
|
||||
)
|
||||
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
t, # task
|
||||
Vector{Node}(), # parents
|
||||
Vector{Node}(), # children
|
||||
UUIDs.uuid1(rng[threadid()]), # id
|
||||
missing, # node reduction
|
||||
missing, # node vectorization
|
||||
missing, # node split
|
||||
Vector{NodeFusion}(), # node fusions
|
||||
missing, # device
|
||||
@ -39,9 +49,14 @@ end
|
||||
|
||||
Construct and return a new [`ComputeTaskNode`](@ref) with the given task.
|
||||
"""
|
||||
function make_node(t::AbstractComputeTask)
|
||||
return ComputeTaskNode(t)
|
||||
end
|
||||
make_node(t::AbstractComputeTask) = ComputeTaskNode(t)
|
||||
|
||||
"""
|
||||
make_node(t::AbstractTaskFunction)
|
||||
|
||||
Construct and return a new [`ComputeTaskNode`](@ref) with a default [`ComputeTask`](@ref) with the given task function.
|
||||
"""
|
||||
make_node(t::AbstractTaskFunction) = ComputeTaskNode(ComputeTask(t))
|
||||
|
||||
"""
|
||||
make_edge(n1::Node, n2::Node)
|
||||
|
@ -121,3 +121,157 @@ Return whether the `potential_child` is a child of `node`.
|
||||
function is_child(potential_child::Node, node::Node)::Bool
|
||||
return potential_child in children(node)
|
||||
end
|
||||
|
||||
function add_child!(n::DataTaskNode{<:AbstractDataTask}, child::Node)
|
||||
push!(n.children, child)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function add_parent!(n::DataTaskNode{<:AbstractDataTask}, parent::Node)
|
||||
push!(n.parents, parent)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_child!(n::DataTaskNode{<:AbstractDataTask}, child::Node)
|
||||
for i in eachindex(n.children)
|
||||
if (n.children[i] == child)
|
||||
splice!(n.children, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_parent!(n::DataTaskNode{<:AbstractDataTask}, parent::Node)
|
||||
for i in eachindex(n.parents)
|
||||
if (n.parents[i] == parent)
|
||||
splice!(n.parents, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
add_child(n::ComputeTaskNode{<:ComputeTask}, child::Node)
|
||||
|
||||
Add a child to the compute node.
|
||||
"""
|
||||
function add_child!(n::ComputeTaskNode{<:ComputeTask}, child::Node)
|
||||
push!(n.children, child)
|
||||
push!(n.task.arguments, Argument(child.id))
|
||||
return nothing
|
||||
end
|
||||
|
||||
function add_parent!(n::ComputeTaskNode{<:ComputeTask}, parent::Node)
|
||||
push!(n.parents, parent)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_child!(n::ComputeTaskNode{<:ComputeTask}, child::Node)
|
||||
for i in eachindex(n.children)
|
||||
if (n.children[i] == child)
|
||||
splice!(n.children, i)
|
||||
@assert n.task.func.arguments[i].id == child.id
|
||||
splice!(n.task.func.arguments, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_parent!(n::ComputeTaskNode{<:ComputeTask}, parent::Node)
|
||||
for i in eachindex(n.parents)
|
||||
if (n.parents[i] == parent)
|
||||
splice!(n.parents, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
|
||||
|
||||
function add_child!(n::ComputeTaskNode{<:FusedComputeTask}, child::Node, which::Symbol = :first)
|
||||
push!(n.children, child)
|
||||
|
||||
if which == :first || which == :both
|
||||
push!(n.task.first_func_arguments, child.id)
|
||||
if which == :second || which == :both
|
||||
push!(n.task.second_func_arguments, child.id)
|
||||
end
|
||||
|
||||
if which != :first && which != :second && which != :both
|
||||
@assert false "Tried to add child to symbol $(which) of fused compute task, but only :both, :first, and :second are allowed"
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function add_parent!(n::ComputeTaskNode{<:FusedComputeTask}, parent::Node)
|
||||
push!(n.parents, parent)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_child!(n::ComputeTaskNode{<:FusedComputeTask}, child::Node)
|
||||
for i in eachindex(n.children)
|
||||
if (n.children[i] == child)
|
||||
splice!(n.children, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
for field in (:first_func_arguments, :second_func_arguments)
|
||||
for i in eachindex(getfield(n.task, field))
|
||||
if (getfield(n.task, field)[i].id == child.id)
|
||||
splice!(getfield(n.task, field), i)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_parent!(n::ComputeTaskNode{<:FusedComputeTask}, parent::Node)
|
||||
for i in eachindex(n.parents)
|
||||
if (n.parents[i] == parent)
|
||||
splice!(n.parents, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
|
||||
|
||||
function add_child!(n::ComputeTaskNode{<:VectorizedComputeTask}, child::Node)
|
||||
push!(n.children, child)
|
||||
push!(n.task.arguments, Argument(child.id))
|
||||
return nothing
|
||||
end
|
||||
|
||||
function add_parent!(n::ComputeTaskNode{<:VectorizedComputeTask}, parent::Node)
|
||||
push!(n.parents, parent)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_child!(n::ComputeTaskNode{<:VectorizedComputeTask}, child::Node)
|
||||
for i in eachindex(n.children)
|
||||
if (n.children[i] == child)
|
||||
splice!(n.children, i)
|
||||
@assert n.task.func.arguments[i].id == child.id
|
||||
splice!(n.task.func.arguments, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function remove_parent!(n::ComputeTaskNode{<:VectorizedComputeTask}, parent::Node)
|
||||
for i in eachindex(n.parents)
|
||||
if (n.parents[i] == parent)
|
||||
splice!(n.parents, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
return nothing
|
||||
end
|
||||
|
@ -24,14 +24,15 @@ abstract type Operation end
|
||||
Any node that transfers data and does no computation.
|
||||
|
||||
# Fields
|
||||
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
|
||||
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
|
||||
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
|
||||
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
|
||||
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
|
||||
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
|
||||
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeVectorization`: Always `missing`, since a data node can't be vectorized.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
|
||||
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
|
||||
"""
|
||||
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
|
||||
task::TaskType
|
||||
@ -48,6 +49,9 @@ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
|
||||
# Can't use the NodeReduction type here because it's not yet defined
|
||||
nodeReduction::Union{Operation, Missing}
|
||||
|
||||
# data nodes can't be vectorized
|
||||
nodeVectorization::Missing
|
||||
|
||||
# the NodeSplit involving this node, if it exists
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
@ -64,22 +68,25 @@ end
|
||||
Any node that computes a result from inputs using an [`AbstractComputeTask`](@ref).
|
||||
|
||||
# Fields
|
||||
`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\
|
||||
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
|
||||
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
|
||||
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
|
||||
`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\
|
||||
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
|
||||
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
|
||||
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
|
||||
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeVectorization`: Either this node's [`NodeVectorization`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
|
||||
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
|
||||
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
|
||||
"""
|
||||
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
|
||||
task::TaskType
|
||||
|
||||
parents::Vector{Node}
|
||||
children::Vector{Node}
|
||||
id::Base.UUID
|
||||
|
||||
nodeReduction::Union{Operation, Missing}
|
||||
nodeVectorization::Union{Operation, Missing}
|
||||
nodeSplit::Union{Operation, Missing}
|
||||
|
||||
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
|
||||
|
@ -187,7 +187,7 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
|
||||
remove_node!(graph, n3)
|
||||
|
||||
# create new node with the fused compute task
|
||||
newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
|
||||
newNode = ComputeTaskNode(FusedComputeTask(n1Task.func, n3Task.func, Symbol(to_var_name(n2.id))))
|
||||
insert_node!(graph, newNode)
|
||||
|
||||
for child in n1Children
|
||||
|
@ -239,6 +239,31 @@ function generate_operations(graph::DAG)
|
||||
|
||||
empty!(graph.dirtyNodes)
|
||||
|
||||
# find node vectorizations
|
||||
# assume that relevant tasks will be at the same depth in the graph
|
||||
"""
|
||||
nodes = Queue{Tuple{<:Node, Int64}}()
|
||||
node_set = Set{Node}()
|
||||
currentDepth = 0
|
||||
push!(nodes, (get_exit_node(graph), 0))
|
||||
while !isempty(nodes)
|
||||
(node, depth) = popfirst!(nodes)
|
||||
pushback!.(node.parents, Ref(depth + 1))
|
||||
|
||||
if depth == currentDepth
|
||||
push!(node_set, node)
|
||||
continue
|
||||
end
|
||||
|
||||
# collected all nodes of the same depth, sort into compute task types and create vectorizations
|
||||
#TODO
|
||||
|
||||
# reset node_set
|
||||
node_set = Set{Node}()
|
||||
push!(node_set, node)
|
||||
end
|
||||
"""
|
||||
|
||||
wait(nr_task)
|
||||
wait(nf_task)
|
||||
wait(ns_task)
|
||||
|
@ -56,3 +56,15 @@ function show(io::IO, op::NodeFusion)
|
||||
print(io, "->")
|
||||
return print(io, task(op.input[3]))
|
||||
end
|
||||
|
||||
"""
|
||||
show(io::IO, op::NodeVectorization)
|
||||
|
||||
Print a string representation of the node vectorization to io.
|
||||
"""
|
||||
function show(io::IO, op::NodeVectorization)
|
||||
print(io, "NV: ")
|
||||
print(io, length(op.input))
|
||||
print(io, "x")
|
||||
return print(io, task(op.input[1]))
|
||||
end
|
||||
|
@ -120,3 +120,12 @@ struct AppliedNodeSplit{NodeType <: Node} <: AppliedOperation
|
||||
operation::NodeSplit{NodeType}
|
||||
diff::Diff
|
||||
end
|
||||
|
||||
struct NodeVectorization{TaskType <: AbstractComputeTask} <: Operation
|
||||
input::Vector{ComputeTaskNode{TaskType}}
|
||||
end
|
||||
|
||||
struct AppliedNodeVectorization{TaskType <: AbstractComputeTask} <: AppliedOperation
|
||||
operation::NodeVectorization{TaskType}
|
||||
diff::Diff
|
||||
end
|
||||
|
@ -171,3 +171,13 @@ Equality comparison between two node splits. Two node splits are considered equa
|
||||
function ==(op1::NodeSplit, op2::NodeSplit)
|
||||
return op1.input == op2.input
|
||||
end
|
||||
|
||||
"""
|
||||
==(op1::NodeVectorization, op2::NodeVectorization)
|
||||
|
||||
Equality comparison between two node vectorizations. Two node vectorizations are considered equal when they have the same inputs.
|
||||
"""
|
||||
function ==(op1::NodeVectorization, op2::NodeVectorization)
|
||||
# node vectorizations are equal exactly if their first input is the same
|
||||
return op1.input[1].id == op2.input[1].id
|
||||
end
|
||||
|
@ -17,7 +17,7 @@ function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTas
|
||||
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
|
||||
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion\n[Node Fusion] n1 ($(n1)) parents: $(parents(n1))\n[Node Fusion] n2 ($(n2)) children: $(children(n2)), n2 parents: $(parents(n2))\n[Node Fusion] n3 ($(n3)) children: $(children(n3))",
|
||||
),
|
||||
)
|
||||
end
|
||||
@ -106,6 +106,35 @@ function is_valid_node_split_input(graph::DAG, n1::Node)
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
is_valid_node_vectorization_input(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
Assert for a gven node vectorization input whether the nodes can be vectorized. For the requirements of a node vectorization see [`NodeVectorization`](@ref).
|
||||
|
||||
Intended for use with `@assert` or `@test`.
|
||||
"""
|
||||
function is_valid_node_vectorization_input(graph::DAG, nodes::Vector{Node})
|
||||
for n in nodes
|
||||
if n ∉ graph
|
||||
throw(AssertionError("[Node Vectorization] The given nodes are not part of the given graph"))
|
||||
end
|
||||
@assert is_valid(graph, n)
|
||||
end
|
||||
|
||||
t = typeof(task(nodes[1]))
|
||||
for n in nodes
|
||||
if typeof(task(n)) != t
|
||||
throw(AssertionError("[Node Vectorization] The given nodes are not of the same type"))
|
||||
end
|
||||
end
|
||||
|
||||
if !pairwise_independent(graph, nodes)
|
||||
throw(AssertionError("[Node Vectorization] The given nodes are not pairwise independent of one another"))
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
is_valid(graph::DAG, nr::NodeReduction)
|
||||
|
||||
@ -144,3 +173,15 @@ function is_valid(graph::DAG, nf::NodeFusion)
|
||||
#@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
is_valid(graph::DAG, nr::NodeVectorization)
|
||||
|
||||
Assert for a given [`NodeVectorization`](@ref) whether it is a valid operation in the graph.
|
||||
|
||||
Intended for use with `@assert` or `@test`.
|
||||
"""
|
||||
function is_valid(graph::DAG, nv::NodeVectorization)
|
||||
@assert is_valid_node_vectorization_input(graph, nv.input)
|
||||
return true
|
||||
end
|
||||
|
@ -8,11 +8,11 @@ function ==(t1::AbstractTask, t2::AbstractTask)
|
||||
end
|
||||
|
||||
"""
|
||||
==(t1::AbstractComputeTask, t2::AbstractComputeTask)
|
||||
==(t1::AbstractTaskFunction, t2::AbstractTaskFunction)
|
||||
|
||||
Equality comparison between two compute tasks.
|
||||
"""
|
||||
function ==(t1::AbstractComputeTask, t2::AbstractComputeTask)
|
||||
function ==(t1::AbstractTaskFunction, t2::AbstractTaskFunction)
|
||||
return typeof(t1) == typeof(t2)
|
||||
end
|
||||
|
||||
|
@ -6,8 +6,7 @@ using StaticArrays
|
||||
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
|
||||
"""
|
||||
function compute(t::FusedComputeTask, data...)
|
||||
inter = compute(t.first_task)
|
||||
return compute(t.second_task, inter, data2...)
|
||||
@assert false
|
||||
end
|
||||
|
||||
"""
|
||||
@ -21,8 +20,8 @@ For ordinary compute or data tasks the vector will contain exactly one element,
|
||||
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
|
||||
# sort out the symbols to the correct tasks
|
||||
return [
|
||||
get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
|
||||
get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
|
||||
get_function_call(t.first_func, device, t.t1_inputs, t.t1_output)...,
|
||||
get_function_call(t.second_func, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
|
||||
]
|
||||
end
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
Argument(id::Base.UUID) = Argument(id, 0)
|
||||
|
||||
"""
|
||||
copy(t::AbstractDataTask)
|
||||
|
||||
@ -10,7 +12,12 @@ copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks
|
||||
|
||||
Return a copy of the given compute task.
|
||||
"""
|
||||
copy(t::AbstractComputeTask) = typeof(t)()
|
||||
copy(t::AbstractComputeTask) = typeof(t)(copy(t.func), copy(t.arguments))
|
||||
copy(t::AbstractTaskFunction) = typeof(t)()
|
||||
|
||||
function ComputeTask(t::AbstractTaskFunction)
|
||||
return ComputeTask(t, Vector{Argument}())
|
||||
end
|
||||
|
||||
"""
|
||||
copy(t::FusedComputeTask)
|
||||
@ -18,15 +25,9 @@ copy(t::AbstractComputeTask) = typeof(t)()
|
||||
Return a copy of th egiven [`FusedComputeTask`](@ref).
|
||||
"""
|
||||
function copy(t::FusedComputeTask)
|
||||
return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
|
||||
return FusedComputeTask(copy(t.first_func), copy(t.second_func), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
|
||||
end
|
||||
|
||||
function FusedComputeTask(
|
||||
T1::Type{<:AbstractComputeTask},
|
||||
T2::Type{<:AbstractComputeTask},
|
||||
t1_inputs::Vector{String},
|
||||
t1_output::String,
|
||||
t2_inputs::Vector{String},
|
||||
)
|
||||
return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
|
||||
function FusedComputeTask(T1::Type{<:AbstractComputeTask}, T2::Type{<:AbstractComputeTask}, t1_output::String)
|
||||
return FusedComputeTask(T1(), T2(), t1_output)
|
||||
end
|
||||
|
@ -3,19 +3,21 @@
|
||||
|
||||
Fallback implementation of the compute function of a compute task, throwing an error.
|
||||
"""
|
||||
function compute(t::AbstractTask, data...)
|
||||
function compute(t::AbstractTaskFunction, data...)
|
||||
return error("Need to implement compute()")
|
||||
end
|
||||
compute(t::ComputeTask, data...) = compute(t.func, data...) # TODO: is this correct?
|
||||
|
||||
"""
|
||||
compute_effort(t::AbstractTask)
|
||||
|
||||
Fallback implementation of the compute effort of a task, throwing an error.
|
||||
"""
|
||||
function compute_effort(t::AbstractTask)::Float64
|
||||
function compute_effort(t::AbstractTaskFunction)::Float64
|
||||
# default implementation using compute
|
||||
return error("Need to implement compute_effort()")
|
||||
end
|
||||
compute_effort(t::AbstractComputeTask)::Float64 = compute_effort(t.func)
|
||||
|
||||
"""
|
||||
data(t::AbstractTask)
|
||||
@ -60,7 +62,7 @@ children(::DataTask) = 1
|
||||
Return the number of children of a FusedComputeTask.
|
||||
"""
|
||||
function children(t::FusedComputeTask)
|
||||
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
|
||||
return length(union(Set(t.first_func.arguments), Set(t.second_func.arguments))) - 1
|
||||
end
|
||||
|
||||
"""
|
||||
@ -69,6 +71,7 @@ end
|
||||
Return the data of a compute task, always zero, regardless of the specific task.
|
||||
"""
|
||||
data(t::AbstractComputeTask)::Float64 = 0.0
|
||||
data(t::AbstractTaskFunction)::Float64 = 0.0
|
||||
|
||||
"""
|
||||
compute_effort(t::FusedComputeTask)
|
||||
@ -76,7 +79,7 @@ data(t::AbstractComputeTask)::Float64 = 0.0
|
||||
Return the compute effort of a fused compute task.
|
||||
"""
|
||||
function compute_effort(t::FusedComputeTask)::Float64
|
||||
return compute_effort(t.first_task) + compute_effort(t.second_task)
|
||||
return compute_effort(t.first_func) + compute_effort(t.second_func)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -84,4 +87,4 @@ end
|
||||
|
||||
Return a tuple of a the fused compute task's components' types.
|
||||
"""
|
||||
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))
|
||||
get_types(t::FusedComputeTask) = (typeof(t.first_func), typeof(t.second_func))
|
||||
|
@ -1,3 +1,15 @@
|
||||
"""
|
||||
Argument
|
||||
|
||||
Representation of one of the arguments of a [`AbstractComputeTask`](@ref) in a Node.
|
||||
Has a member `id::UUID`, which is the id of the compute task's parent the argument refers to, and an `index`, which is the index in the symbol's object. If `index` is 0, the whole object is the argument.
|
||||
A compute task can have multiple arguments.
|
||||
"""
|
||||
struct Argument
|
||||
id::Base.UUID
|
||||
index::Int64
|
||||
end
|
||||
|
||||
"""
|
||||
AbstractTask
|
||||
|
||||
@ -6,11 +18,33 @@ The shared base type for any task.
|
||||
abstract type AbstractTask end
|
||||
|
||||
"""
|
||||
AbstractComputeTask <: AbstractTask
|
||||
AbstractComputeTask
|
||||
|
||||
The shared base type for any compute task.
|
||||
Base type for all compute tasks.
|
||||
|
||||
!!! note
|
||||
A subtype of this *must* be templated with its [`AbstractTaskFunction`](@ref)s!
|
||||
|
||||
See also: [`ComputeTask`](@ref), [`FusedComputeTask`](@ref), [`VectorizedComputeTask`](@ref)
|
||||
"""
|
||||
abstract type AbstractComputeTask <: AbstractTask end
|
||||
abstract type AbstractComputeTask end
|
||||
|
||||
"""
|
||||
AbstractTaskFunction
|
||||
|
||||
The base type for task functions used to dispatch computes.
|
||||
"""
|
||||
abstract type AbstractTaskFunction end
|
||||
|
||||
"""
|
||||
ComputeTask <: AbstractTask
|
||||
|
||||
A compute task. Has a member [`AbstractTaskFunction`](@ref), and a list of [`Argument`](@ref)s.
|
||||
"""
|
||||
struct ComputeTask{TaskFunction <: AbstractTaskFunction} <: AbstractComputeTask
|
||||
func::TaskFunction
|
||||
arguments::Vector{Argument}
|
||||
end
|
||||
|
||||
"""
|
||||
AbstractDataTask <: AbstractTask
|
||||
@ -29,19 +63,26 @@ struct DataTask <: AbstractDataTask
|
||||
end
|
||||
|
||||
"""
|
||||
FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
|
||||
FusedComputeTask <: AbstractComputeTask
|
||||
|
||||
A fused compute task made up of the computation of first `T1` and then `T2`.
|
||||
|
||||
Also see: [`get_types`](@ref).
|
||||
"""
|
||||
struct FusedComputeTask <: AbstractComputeTask
|
||||
first_task::AbstractComputeTask
|
||||
second_task::AbstractComputeTask
|
||||
# the names of the inputs for T1
|
||||
t1_inputs::Vector{Symbol}
|
||||
# output name of T1
|
||||
struct FusedComputeTask{TaskFunction1 <: AbstractTaskFunction, TaskFunction2 <: AbstractTaskFunction} <:
|
||||
AbstractComputeTask
|
||||
first_func::TaskFunction1
|
||||
second_func::TaskFunction2
|
||||
t1_output::Symbol
|
||||
# t2_inputs doesn't include the output of t1, that's implicit
|
||||
t2_inputs::Vector{Symbol}
|
||||
first_func_arguments::Vector{Argument}
|
||||
second_func_arguments::Vector{Argument}
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
|
||||
"""
|
||||
struct VectorizedComputeTask{TaskFunction <: AbstractTaskFunction} <: AbstractComputeTask
|
||||
task_type::TaskFunction
|
||||
arguments::Vector{Argument}
|
||||
end
|
||||
|
@ -48,7 +48,7 @@ function insert_helper!(
|
||||
trie::NodeIdTrie{NodeType},
|
||||
node::NodeType,
|
||||
depth::Int,
|
||||
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
|
||||
) where {CompTask <: AbstractComputeTask, NodeType <: Union{DataTaskNode{AbstractDataTask}, ComputeTaskNode{CompTask}}}
|
||||
if (length(children(node)) == depth)
|
||||
push!(trie.value, node)
|
||||
return nothing
|
||||
@ -71,7 +71,7 @@ Insert the given node into the trie. It's sorted by its type in the first layer,
|
||||
function insert!(
|
||||
trie::NodeTrie,
|
||||
node::NodeType,
|
||||
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
|
||||
) where {CompTask <: AbstractComputeTask, NodeType <: Union{DataTaskNode{AbstractDataTask}, ComputeTaskNode{CompTask}}}
|
||||
if (!haskey(trie.children, NodeType))
|
||||
trie.children[NodeType] = NodeIdTrie{NodeType}()
|
||||
end
|
||||
|
@ -5,6 +5,8 @@ import MetagraphOptimization.insert_edge!
|
||||
import MetagraphOptimization.make_node
|
||||
import MetagraphOptimization.siblings
|
||||
import MetagraphOptimization.partners
|
||||
import MetagraphOptimization.is_dependent
|
||||
import MetagraphOptimization.pairwise_independent
|
||||
|
||||
graph = MetagraphOptimization.DAG()
|
||||
|
||||
@ -108,6 +110,20 @@ insert_edge!(graph, s0, d_exit, track = false)
|
||||
@test length(graph.dirtyNodes) == 26
|
||||
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
|
||||
|
||||
@test !is_dependent(graph, d_exit, d_exit)
|
||||
@test is_dependent(graph, d_v0_s0, v0)
|
||||
@test is_dependent(graph, d_v1_s0, v1)
|
||||
@test !is_dependent(graph, d_v0_s0, v1)
|
||||
@test is_dependent(graph, s0, d_PB)
|
||||
@test !is_dependent(graph, v0, uBp)
|
||||
@test !is_dependent(graph, PB, PA)
|
||||
|
||||
@test pairwise_independent(graph, [PB, PA, PBp, PAp])
|
||||
@test pairwise_independent(graph, [uB, uA, d_uBp_v1])
|
||||
@test !pairwise_independent(graph, [PB, PA, PBp, PAp, s0])
|
||||
@test !pairwise_independent(graph, [d_uB_v0, v0])
|
||||
@test !pairwise_independent(graph, [v0, d_v0_s0, s0, uB])
|
||||
|
||||
@test is_valid(graph)
|
||||
|
||||
@test is_entry_node(d_PB)
|
||||
|
Loading…
x
Reference in New Issue
Block a user