Compare commits

...

1 Commits

Author SHA1 Message Date
e5d214a6fc WIP 2024-02-28 13:52:46 +01:00
26 changed files with 484 additions and 110 deletions

View File

@ -14,7 +14,6 @@ export Edge
export ComputeTaskNode
export DataTaskNode
export AbstractTask
export AbstractComputeTask
export AbstractDataTask
export DataTask
export FusedComputeTask
@ -24,8 +23,8 @@ export GraphProperties
# graph functions
export make_node
export make_edge
export insert_node
export insert_edge
export insert_node!
export insert_edge!
export is_entry_node
export is_exit_node
export parents

View File

@ -1,3 +1,5 @@
using DataStructures
"""
in(node::Node, graph::DAG)
@ -19,3 +21,54 @@ function in(edge::Edge, graph::DAG)
return n1 in children(n2)
end
"""
is_dependent(graph::DAG, n1::Node, n2::Node)
Returns whether `n1` is dependent on `n2` in the given `graph`.
"""
function is_dependent(graph::DAG, n1::Node, n2::Node)
if !(n1 in graph) || !(n2 in graph)
return false
end
if n1 == n2
return false
end
queue = Deque{Node}()
push!(queue, n1)
# TODO: this is probably not the best way to do this, deduplication in the queue would help
while !isempty(queue)
current = popfirst!(queue)
if current == n2
return true
end
for c in current.children
push!(queue, c)
end
end
return false
end
"""
pairwise_independent(graph::DAG, nodes::Vector{Node})
Returns true if all given `nodes` are independent of each other, i.e., no path exists from any member of `nodes` to any other member.
See [`is_dependent`](@ref)
"""
function pairwise_independent(graph::DAG, nodes::Vector{<:Node})
checked_set = Vector{Node}()
for n in nodes
for c in checked_set
if is_dependent(graph, c, n) || is_dependent(graph, n, c)
return false
end
end
push!(checked_set, n)
end
return true
end

View File

@ -46,12 +46,12 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
"""
function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
#@assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists"
@assert (node2 parents(node1)) && (node1 children(node2)) "Edge to insert already exists"
# 1: mute
# edge points from child to parent
push!(node1.parents, node2)
push!(node2.children, node1)
add_parent!(node1, node2)
add_child!(node2, node1)
# 2: keep track
if (track)
@ -85,7 +85,7 @@ Remove the node from the graph.
See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
#@assert node in graph.nodes "Trying to remove a node that's not in the graph"
@assert node in graph.nodes "Trying to remove a node that's not in the graph"
# 1: mute
delete!(graph.nodes, node)
@ -123,19 +123,8 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
pre_length1 = length(node1.parents)
pre_length2 = length(node2.children)
for i in eachindex(node1.parents)
if (node1.parents[i] == node2)
splice!(node1.parents, i)
break
end
end
for i in eachindex(node2.children)
if (node2.children[i] == node1)
splice!(node2.children, i)
break
end
end
remove_parent!(node1, node2)
remove_child!(node2, node1)
#=@assert begin
removed = pre_length1 - length(node1.parents)
@ -173,17 +162,17 @@ function replace_children!(task::FusedComputeTask, before, after)
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
replace!(task.t1_inputs, before => after)
replace!(task.t2_inputs, before => after)
# recursively descend down the tree, but only in the tasks where we're replacing things
if replacedIn1 > 0
replace_children!(task.first_task, before, after)
replace_children!(task.first_func, before, after)
end
if replacedIn2 > 0
replace_children!(task.second_task, before, after)
replace_children!(task.second_func, before, after)
end
return nothing

View File

@ -10,6 +10,7 @@ mutable struct PossibleOperations
nodeFusions::Set{NodeFusion}
nodeReductions::Set{NodeReduction}
nodeSplits::Set{NodeSplit}
nodeVectorizations::Set{NodeVectorization}
end
"""
@ -52,7 +53,7 @@ end
Construct and return an empty [`PossibleOperations`](@ref) object.
"""
function PossibleOperations()
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}(), Set{NodeVectorization}())
end
"""

View File

@ -181,7 +181,7 @@ function parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = fa
insert_edge!(graph, compute_S2, data_out, track = false, invalidate_cache = false)
insert_edge!(graph, data_out, sum_node, track = false, invalidate_cache = false)
add_child!(task(sum_node))
add_child!(task(sum_node).func)
elseif occursin(regex_plus, node)
if (verbose)
println("\rReading Nodes Complete ")

View File

@ -78,7 +78,6 @@ Return the number of children of a ComputeTaskABC_V (always 2).
"""
children(::ComputeTaskABC_V) = 2
"""
children(::ComputeTaskABC_Sum)

View File

@ -1,44 +1,44 @@
"""
ComputeTaskABC_S1 <: AbstractComputeTask
ComputeTaskABC_S1 <: AbstractTaskFunction
S task with a single child.
"""
struct ComputeTaskABC_S1 <: AbstractComputeTask end
struct ComputeTaskABC_S1 <: AbstractTaskFunction end
"""
ComputeTaskABC_S2 <: AbstractComputeTask
ComputeTaskABC_S2 <: AbstractTaskFunction
S task with two children.
"""
struct ComputeTaskABC_S2 <: AbstractComputeTask end
struct ComputeTaskABC_S2 <: AbstractTaskFunction end
"""
ComputeTaskABC_P <: AbstractComputeTask
ComputeTaskABC_P <: AbstractTaskFunction
P task with no children.
"""
struct ComputeTaskABC_P <: AbstractComputeTask end
struct ComputeTaskABC_P <: AbstractTaskFunction end
"""
ComputeTaskABC_V <: AbstractComputeTask
ComputeTaskABC_V <: AbstractTaskFunction
v task with two children.
"""
struct ComputeTaskABC_V <: AbstractComputeTask end
struct ComputeTaskABC_V <: AbstractTaskFunction end
"""
ComputeTaskABC_U <: AbstractComputeTask
ComputeTaskABC_U <: AbstractTaskFunction
u task with a single child.
"""
struct ComputeTaskABC_U <: AbstractComputeTask end
struct ComputeTaskABC_U <: AbstractTaskFunction end
"""
ComputeTaskABC_Sum <: AbstractComputeTask
ComputeTaskABC_Sum <: AbstractTaskFunction
Task that sums all its inputs, n children.
"""
mutable struct ComputeTaskABC_Sum <: AbstractComputeTask
mutable struct ComputeTaskABC_Sum <: AbstractTaskFunction
children_number::Int
end

View File

@ -66,7 +66,7 @@ function compute(
data1::ParticleValue{P1},
data2::ParticleValue{P2},
) where {P1 <: Union{AntiFermionStateful, FermionStateful}, P2 <: Union{AntiFermionStateful, FermionStateful}}
#@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
inner = QED_inner_edge(propagation_result(P1)(momentum(data1.p)))

View File

@ -188,7 +188,7 @@ end
Apply a [`FeynmanVertex`](@ref) to the given vector of [`FeynmanParticle`](@ref)s.
"""
function apply_vertex!(particles::Vector{FeynmanParticle}, vertex::FeynmanVertex)
#@assert can_apply_vertex(particles, vertex)
@assert can_apply_vertex(particles, vertex)
length_before = length(particles)
filter!(x -> x != vertex.in1 && x != vertex.in2, particles)
push!(particles, vertex.out)

View File

@ -1,44 +1,44 @@
"""
ComputeTaskQED_S1 <: AbstractComputeTask
ComputeTaskQED_S1 <: AbstractTaskFunction
S task with a single child.
"""
struct ComputeTaskQED_S1 <: AbstractComputeTask end
struct ComputeTaskQED_S1 <: AbstractTaskFunction end
"""
ComputeTaskQED_S2 <: AbstractComputeTask
ComputeTaskQED_S2 <: AbstractTaskFunction
S task with two children.
"""
struct ComputeTaskQED_S2 <: AbstractComputeTask end
struct ComputeTaskQED_S2 <: AbstractTaskFunction end
"""
ComputeTaskQED_P <: AbstractComputeTask
ComputeTaskQED_P <: AbstractTaskFunction
P task with no children.
"""
struct ComputeTaskQED_P <: AbstractComputeTask end
struct ComputeTaskQED_P <: AbstractTaskFunction end
"""
ComputeTaskQED_V <: AbstractComputeTask
ComputeTaskQED_V <: AbstractTaskFunction
v task with two children.
"""
struct ComputeTaskQED_V <: AbstractComputeTask end
struct ComputeTaskQED_V <: AbstractTaskFunction end
"""
ComputeTaskQED_U <: AbstractComputeTask
ComputeTaskQED_U <: AbstractTaskFunction
u task with a single child.
"""
struct ComputeTaskQED_U <: AbstractComputeTask end
struct ComputeTaskQED_U <: AbstractTaskFunction end
"""
ComputeTaskQED_Sum <: AbstractComputeTask
ComputeTaskQED_Sum <: AbstractTaskFunction
Task that sums all its inputs, n children.
"""
mutable struct ComputeTaskQED_Sum <: AbstractComputeTask
mutable struct ComputeTaskQED_Sum <: AbstractTaskFunction
children_number::Int
end

View File

@ -1,12 +1,22 @@
DataTaskNode(t::AbstractDataTask, name = "") =
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
DataTaskNode(t::AbstractDataTask, name = "") = DataTaskNode(
t,
Vector{Node}(),
Vector{Node}(),
UUIDs.uuid1(rng[threadid()]),
missing,
missing,
missing,
missing,
name,
)
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
t, # task
Vector{Node}(), # parents
Vector{Node}(), # children
UUIDs.uuid1(rng[threadid()]), # id
missing, # node reduction
missing, # node vectorization
missing, # node split
Vector{NodeFusion}(), # node fusions
missing, # device
@ -39,9 +49,14 @@ end
Construct and return a new [`ComputeTaskNode`](@ref) with the given task.
"""
function make_node(t::AbstractComputeTask)
return ComputeTaskNode(t)
end
make_node(t::AbstractComputeTask) = ComputeTaskNode(t)
"""
make_node(t::AbstractTaskFunction)
Construct and return a new [`ComputeTaskNode`](@ref) with a default [`ComputeTask`](@ref) with the given task function.
"""
make_node(t::AbstractTaskFunction) = ComputeTaskNode(ComputeTask(t))
"""
make_edge(n1::Node, n2::Node)

View File

@ -121,3 +121,157 @@ Return whether the `potential_child` is a child of `node`.
function is_child(potential_child::Node, node::Node)::Bool
return potential_child in children(node)
end
function add_child!(n::DataTaskNode{<:AbstractDataTask}, child::Node)
push!(n.children, child)
return nothing
end
function add_parent!(n::DataTaskNode{<:AbstractDataTask}, parent::Node)
push!(n.parents, parent)
return nothing
end
function remove_child!(n::DataTaskNode{<:AbstractDataTask}, child::Node)
for i in eachindex(n.children)
if (n.children[i] == child)
splice!(n.children, i)
break
end
end
return nothing
end
function remove_parent!(n::DataTaskNode{<:AbstractDataTask}, parent::Node)
for i in eachindex(n.parents)
if (n.parents[i] == parent)
splice!(n.parents, i)
break
end
end
return nothing
end
"""
add_child(n::ComputeTaskNode{<:ComputeTask}, child::Node)
Add a child to the compute node.
"""
function add_child!(n::ComputeTaskNode{<:ComputeTask}, child::Node)
push!(n.children, child)
push!(n.task.arguments, Argument(child.id))
return nothing
end
function add_parent!(n::ComputeTaskNode{<:ComputeTask}, parent::Node)
push!(n.parents, parent)
return nothing
end
function remove_child!(n::ComputeTaskNode{<:ComputeTask}, child::Node)
for i in eachindex(n.children)
if (n.children[i] == child)
splice!(n.children, i)
@assert n.task.func.arguments[i].id == child.id
splice!(n.task.func.arguments, i)
break
end
end
return nothing
end
function remove_parent!(n::ComputeTaskNode{<:ComputeTask}, parent::Node)
for i in eachindex(n.parents)
if (n.parents[i] == parent)
splice!(n.parents, i)
break
end
end
return nothing
end
function add_child!(n::ComputeTaskNode{<:FusedComputeTask}, child::Node, which::Symbol = :first)
push!(n.children, child)
if which == :first || which == :both
push!(n.task.first_func_arguments, child.id)
if which == :second || which == :both
push!(n.task.second_func_arguments, child.id)
end
if which != :first && which != :second && which != :both
@assert false "Tried to add child to symbol $(which) of fused compute task, but only :both, :first, and :second are allowed"
end
return nothing
end
function add_parent!(n::ComputeTaskNode{<:FusedComputeTask}, parent::Node)
push!(n.parents, parent)
return nothing
end
function remove_child!(n::ComputeTaskNode{<:FusedComputeTask}, child::Node)
for i in eachindex(n.children)
if (n.children[i] == child)
splice!(n.children, i)
break
end
end
for field in (:first_func_arguments, :second_func_arguments)
for i in eachindex(getfield(n.task, field))
if (getfield(n.task, field)[i].id == child.id)
splice!(getfield(n.task, field), i)
break
end
end
end
return nothing
end
function remove_parent!(n::ComputeTaskNode{<:FusedComputeTask}, parent::Node)
for i in eachindex(n.parents)
if (n.parents[i] == parent)
splice!(n.parents, i)
break
end
end
return nothing
end
function add_child!(n::ComputeTaskNode{<:VectorizedComputeTask}, child::Node)
push!(n.children, child)
push!(n.task.arguments, Argument(child.id))
return nothing
end
function add_parent!(n::ComputeTaskNode{<:VectorizedComputeTask}, parent::Node)
push!(n.parents, parent)
return nothing
end
function remove_child!(n::ComputeTaskNode{<:VectorizedComputeTask}, child::Node)
for i in eachindex(n.children)
if (n.children[i] == child)
splice!(n.children, i)
@assert n.task.func.arguments[i].id == child.id
splice!(n.task.func.arguments, i)
break
end
end
return nothing
end
function remove_parent!(n::ComputeTaskNode{<:VectorizedComputeTask}, parent::Node)
for i in eachindex(n.parents)
if (n.parents[i] == parent)
splice!(n.parents, i)
break
end
end
return nothing
end

View File

@ -24,14 +24,15 @@ abstract type Operation end
Any node that transfers data and does no computation.
# Fields
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeVectorization`: Always `missing`, since a data node can't be vectorized.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
"""
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
task::TaskType
@ -48,6 +49,9 @@ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
# Can't use the NodeReduction type here because it's not yet defined
nodeReduction::Union{Operation, Missing}
# data nodes can't be vectorized
nodeVectorization::Missing
# the NodeSplit involving this node, if it exists
nodeSplit::Union{Operation, Missing}
@ -64,22 +68,25 @@ end
Any node that computes a result from inputs using an [`AbstractComputeTask`](@ref).
# Fields
`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeVectorization`: Either this node's [`NodeVectorization`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
"""
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
task::TaskType
parents::Vector{Node}
children::Vector{Node}
id::Base.UUID
nodeReduction::Union{Operation, Missing}
nodeVectorization::Union{Operation, Missing}
nodeSplit::Union{Operation, Missing}
# for ComputeTasks there can be multiple fusions, unlike the DataTasks

View File

@ -187,7 +187,7 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
remove_node!(graph, n3)
# create new node with the fused compute task
newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
newNode = ComputeTaskNode(FusedComputeTask(n1Task.func, n3Task.func, Symbol(to_var_name(n2.id))))
insert_node!(graph, newNode)
for child in n1Children

View File

@ -239,6 +239,31 @@ function generate_operations(graph::DAG)
empty!(graph.dirtyNodes)
# find node vectorizations
# assume that relevant tasks will be at the same depth in the graph
"""
nodes = Queue{Tuple{<:Node, Int64}}()
node_set = Set{Node}()
currentDepth = 0
push!(nodes, (get_exit_node(graph), 0))
while !isempty(nodes)
(node, depth) = popfirst!(nodes)
pushback!.(node.parents, Ref(depth + 1))
if depth == currentDepth
push!(node_set, node)
continue
end
# collected all nodes of the same depth, sort into compute task types and create vectorizations
#TODO
# reset node_set
node_set = Set{Node}()
push!(node_set, node)
end
"""
wait(nr_task)
wait(nf_task)
wait(ns_task)

View File

@ -56,3 +56,15 @@ function show(io::IO, op::NodeFusion)
print(io, "->")
return print(io, task(op.input[3]))
end
"""
show(io::IO, op::NodeVectorization)
Print a string representation of the node vectorization to io.
"""
function show(io::IO, op::NodeVectorization)
print(io, "NV: ")
print(io, length(op.input))
print(io, "x")
return print(io, task(op.input[1]))
end

View File

@ -120,3 +120,12 @@ struct AppliedNodeSplit{NodeType <: Node} <: AppliedOperation
operation::NodeSplit{NodeType}
diff::Diff
end
struct NodeVectorization{TaskType <: AbstractComputeTask} <: Operation
input::Vector{ComputeTaskNode{TaskType}}
end
struct AppliedNodeVectorization{TaskType <: AbstractComputeTask} <: AppliedOperation
operation::NodeVectorization{TaskType}
diff::Diff
end

View File

@ -171,3 +171,13 @@ Equality comparison between two node splits. Two node splits are considered equa
function ==(op1::NodeSplit, op2::NodeSplit)
return op1.input == op2.input
end
"""
==(op1::NodeVectorization, op2::NodeVectorization)
Equality comparison between two node vectorizations. Two node vectorizations are considered equal when they have the same inputs.
"""
function ==(op1::NodeVectorization, op2::NodeVectorization)
# node vectorizations are equal exactly if their first input is the same
return op1.input[1].id == op2.input[1].id
end

View File

@ -17,7 +17,7 @@ function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTas
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
throw(
AssertionError(
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion\n[Node Fusion] n1 ($(n1)) parents: $(parents(n1))\n[Node Fusion] n2 ($(n2)) children: $(children(n2)), n2 parents: $(parents(n2))\n[Node Fusion] n3 ($(n3)) children: $(children(n3))",
),
)
end
@ -106,6 +106,35 @@ function is_valid_node_split_input(graph::DAG, n1::Node)
return true
end
"""
is_valid_node_vectorization_input(graph::DAG, nodes::Vector{Node})
Assert for a gven node vectorization input whether the nodes can be vectorized. For the requirements of a node vectorization see [`NodeVectorization`](@ref).
Intended for use with `@assert` or `@test`.
"""
function is_valid_node_vectorization_input(graph::DAG, nodes::Vector{Node})
for n in nodes
if n graph
throw(AssertionError("[Node Vectorization] The given nodes are not part of the given graph"))
end
@assert is_valid(graph, n)
end
t = typeof(task(nodes[1]))
for n in nodes
if typeof(task(n)) != t
throw(AssertionError("[Node Vectorization] The given nodes are not of the same type"))
end
end
if !pairwise_independent(graph, nodes)
throw(AssertionError("[Node Vectorization] The given nodes are not pairwise independent of one another"))
end
return true
end
"""
is_valid(graph::DAG, nr::NodeReduction)
@ -144,3 +173,15 @@ function is_valid(graph::DAG, nf::NodeFusion)
#@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
return true
end
"""
is_valid(graph::DAG, nr::NodeVectorization)
Assert for a given [`NodeVectorization`](@ref) whether it is a valid operation in the graph.
Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, nv::NodeVectorization)
@assert is_valid_node_vectorization_input(graph, nv.input)
return true
end

View File

@ -8,11 +8,11 @@ function ==(t1::AbstractTask, t2::AbstractTask)
end
"""
==(t1::AbstractComputeTask, t2::AbstractComputeTask)
==(t1::AbstractTaskFunction, t2::AbstractTaskFunction)
Equality comparison between two compute tasks.
"""
function ==(t1::AbstractComputeTask, t2::AbstractComputeTask)
function ==(t1::AbstractTaskFunction, t2::AbstractTaskFunction)
return typeof(t1) == typeof(t2)
end

View File

@ -6,8 +6,7 @@ using StaticArrays
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
"""
function compute(t::FusedComputeTask, data...)
inter = compute(t.first_task)
return compute(t.second_task, inter, data2...)
@assert false
end
"""
@ -21,8 +20,8 @@ For ordinary compute or data tasks the vector will contain exactly one element,
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
# sort out the symbols to the correct tasks
return [
get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
get_function_call(t.first_func, device, t.t1_inputs, t.t1_output)...,
get_function_call(t.second_func, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
]
end

View File

@ -1,3 +1,5 @@
Argument(id::Base.UUID) = Argument(id, 0)
"""
copy(t::AbstractDataTask)
@ -10,7 +12,12 @@ copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks
Return a copy of the given compute task.
"""
copy(t::AbstractComputeTask) = typeof(t)()
copy(t::AbstractComputeTask) = typeof(t)(copy(t.func), copy(t.arguments))
copy(t::AbstractTaskFunction) = typeof(t)()
function ComputeTask(t::AbstractTaskFunction)
return ComputeTask(t, Vector{Argument}())
end
"""
copy(t::FusedComputeTask)
@ -18,15 +25,9 @@ copy(t::AbstractComputeTask) = typeof(t)()
Return a copy of th egiven [`FusedComputeTask`](@ref).
"""
function copy(t::FusedComputeTask)
return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
return FusedComputeTask(copy(t.first_func), copy(t.second_func), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
end
function FusedComputeTask(
T1::Type{<:AbstractComputeTask},
T2::Type{<:AbstractComputeTask},
t1_inputs::Vector{String},
t1_output::String,
t2_inputs::Vector{String},
)
return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
function FusedComputeTask(T1::Type{<:AbstractComputeTask}, T2::Type{<:AbstractComputeTask}, t1_output::String)
return FusedComputeTask(T1(), T2(), t1_output)
end

View File

@ -3,19 +3,21 @@
Fallback implementation of the compute function of a compute task, throwing an error.
"""
function compute(t::AbstractTask, data...)
function compute(t::AbstractTaskFunction, data...)
return error("Need to implement compute()")
end
compute(t::ComputeTask, data...) = compute(t.func, data...) # TODO: is this correct?
"""
compute_effort(t::AbstractTask)
Fallback implementation of the compute effort of a task, throwing an error.
"""
function compute_effort(t::AbstractTask)::Float64
function compute_effort(t::AbstractTaskFunction)::Float64
# default implementation using compute
return error("Need to implement compute_effort()")
end
compute_effort(t::AbstractComputeTask)::Float64 = compute_effort(t.func)
"""
data(t::AbstractTask)
@ -60,7 +62,7 @@ children(::DataTask) = 1
Return the number of children of a FusedComputeTask.
"""
function children(t::FusedComputeTask)
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
return length(union(Set(t.first_func.arguments), Set(t.second_func.arguments))) - 1
end
"""
@ -69,6 +71,7 @@ end
Return the data of a compute task, always zero, regardless of the specific task.
"""
data(t::AbstractComputeTask)::Float64 = 0.0
data(t::AbstractTaskFunction)::Float64 = 0.0
"""
compute_effort(t::FusedComputeTask)
@ -76,7 +79,7 @@ data(t::AbstractComputeTask)::Float64 = 0.0
Return the compute effort of a fused compute task.
"""
function compute_effort(t::FusedComputeTask)::Float64
return compute_effort(t.first_task) + compute_effort(t.second_task)
return compute_effort(t.first_func) + compute_effort(t.second_func)
end
"""
@ -84,4 +87,4 @@ end
Return a tuple of a the fused compute task's components' types.
"""
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))
get_types(t::FusedComputeTask) = (typeof(t.first_func), typeof(t.second_func))

View File

@ -1,3 +1,15 @@
"""
Argument
Representation of one of the arguments of a [`AbstractComputeTask`](@ref) in a Node.
Has a member `id::UUID`, which is the id of the compute task's parent the argument refers to, and an `index`, which is the index in the symbol's object. If `index` is 0, the whole object is the argument.
A compute task can have multiple arguments.
"""
struct Argument
id::Base.UUID
index::Int64
end
"""
AbstractTask
@ -6,11 +18,33 @@ The shared base type for any task.
abstract type AbstractTask end
"""
AbstractComputeTask <: AbstractTask
AbstractComputeTask
The shared base type for any compute task.
Base type for all compute tasks.
!!! note
A subtype of this *must* be templated with its [`AbstractTaskFunction`](@ref)s!
See also: [`ComputeTask`](@ref), [`FusedComputeTask`](@ref), [`VectorizedComputeTask`](@ref)
"""
abstract type AbstractComputeTask <: AbstractTask end
abstract type AbstractComputeTask end
"""
AbstractTaskFunction
The base type for task functions used to dispatch computes.
"""
abstract type AbstractTaskFunction end
"""
ComputeTask <: AbstractTask
A compute task. Has a member [`AbstractTaskFunction`](@ref), and a list of [`Argument`](@ref)s.
"""
struct ComputeTask{TaskFunction <: AbstractTaskFunction} <: AbstractComputeTask
func::TaskFunction
arguments::Vector{Argument}
end
"""
AbstractDataTask <: AbstractTask
@ -29,19 +63,26 @@ struct DataTask <: AbstractDataTask
end
"""
FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
FusedComputeTask <: AbstractComputeTask
A fused compute task made up of the computation of first `T1` and then `T2`.
Also see: [`get_types`](@ref).
"""
struct FusedComputeTask <: AbstractComputeTask
first_task::AbstractComputeTask
second_task::AbstractComputeTask
# the names of the inputs for T1
t1_inputs::Vector{Symbol}
# output name of T1
struct FusedComputeTask{TaskFunction1 <: AbstractTaskFunction, TaskFunction2 <: AbstractTaskFunction} <:
AbstractComputeTask
first_func::TaskFunction1
second_func::TaskFunction2
t1_output::Symbol
# t2_inputs doesn't include the output of t1, that's implicit
t2_inputs::Vector{Symbol}
first_func_arguments::Vector{Argument}
second_func_arguments::Vector{Argument}
end
"""
"""
struct VectorizedComputeTask{TaskFunction <: AbstractTaskFunction} <: AbstractComputeTask
task_type::TaskFunction
arguments::Vector{Argument}
end

View File

@ -48,7 +48,7 @@ function insert_helper!(
trie::NodeIdTrie{NodeType},
node::NodeType,
depth::Int,
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
) where {CompTask <: AbstractComputeTask, NodeType <: Union{DataTaskNode{AbstractDataTask}, ComputeTaskNode{CompTask}}}
if (length(children(node)) == depth)
push!(trie.value, node)
return nothing
@ -71,7 +71,7 @@ Insert the given node into the trie. It's sorted by its type in the first layer,
function insert!(
trie::NodeTrie,
node::NodeType,
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
) where {CompTask <: AbstractComputeTask, NodeType <: Union{DataTaskNode{AbstractDataTask}, ComputeTaskNode{CompTask}}}
if (!haskey(trie.children, NodeType))
trie.children[NodeType] = NodeIdTrie{NodeType}()
end

View File

@ -5,6 +5,8 @@ import MetagraphOptimization.insert_edge!
import MetagraphOptimization.make_node
import MetagraphOptimization.siblings
import MetagraphOptimization.partners
import MetagraphOptimization.is_dependent
import MetagraphOptimization.pairwise_independent
graph = MetagraphOptimization.DAG()
@ -108,6 +110,20 @@ insert_edge!(graph, s0, d_exit, track = false)
@test length(graph.dirtyNodes) == 26
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
@test !is_dependent(graph, d_exit, d_exit)
@test is_dependent(graph, d_v0_s0, v0)
@test is_dependent(graph, d_v1_s0, v1)
@test !is_dependent(graph, d_v0_s0, v1)
@test is_dependent(graph, s0, d_PB)
@test !is_dependent(graph, v0, uBp)
@test !is_dependent(graph, PB, PA)
@test pairwise_independent(graph, [PB, PA, PBp, PAp])
@test pairwise_independent(graph, [uB, uA, d_uBp_v1])
@test !pairwise_independent(graph, [PB, PA, PBp, PAp, s0])
@test !pairwise_independent(graph, [d_uB_v0, v0])
@test !pairwise_independent(graph, [v0, d_v0_s0, s0, uB])
@test is_valid(graph)
@test is_entry_node(d_PB)