2 Commits

Author SHA1 Message Date
4592b76ec5 WIP 2024-02-02 03:03:08 +01:00
7bdc01b72a WIP add tape machine implementation 2024-01-31 23:00:51 +01:00
27 changed files with 111 additions and 485 deletions

View File

@ -14,6 +14,7 @@ export Edge
export ComputeTaskNode
export DataTaskNode
export AbstractTask
export AbstractComputeTask
export AbstractDataTask
export DataTask
export FusedComputeTask
@ -23,8 +24,8 @@ export GraphProperties
# graph functions
export make_node
export make_edge
export insert_node!
export insert_edge!
export insert_node
export insert_edge
export is_entry_node
export is_exit_node
export parents

View File

@ -1,5 +1,3 @@
using DataStructures
"""
in(node::Node, graph::DAG)
@ -21,54 +19,3 @@ function in(edge::Edge, graph::DAG)
return n1 in children(n2)
end
"""
is_dependent(graph::DAG, n1::Node, n2::Node)
Returns whether `n1` is dependent on `n2` in the given `graph`.
"""
function is_dependent(graph::DAG, n1::Node, n2::Node)
if !(n1 in graph) || !(n2 in graph)
return false
end
if n1 == n2
return false
end
queue = Deque{Node}()
push!(queue, n1)
# TODO: this is probably not the best way to do this, deduplication in the queue would help
while !isempty(queue)
current = popfirst!(queue)
if current == n2
return true
end
for c in current.children
push!(queue, c)
end
end
return false
end
"""
pairwise_independent(graph::DAG, nodes::Vector{Node})
Returns true if all given `nodes` are independent of each other, i.e., no path exists from any member of `nodes` to any other member.
See [`is_dependent`](@ref)
"""
function pairwise_independent(graph::DAG, nodes::Vector{<:Node})
checked_set = Vector{Node}()
for n in nodes
for c in checked_set
if is_dependent(graph, c, n) || is_dependent(graph, n, c)
return false
end
end
push!(checked_set, n)
end
return true
end

View File

@ -46,12 +46,12 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
"""
function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
@assert (node2 parents(node1)) && (node1 children(node2)) "Edge to insert already exists"
#@assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists"
# 1: mute
# edge points from child to parent
add_parent!(node1, node2)
add_child!(node2, node1)
push!(node1.parents, node2)
push!(node2.children, node1)
# 2: keep track
if (track)
@ -85,7 +85,7 @@ Remove the node from the graph.
See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
@assert node in graph.nodes "Trying to remove a node that's not in the graph"
#@assert node in graph.nodes "Trying to remove a node that's not in the graph"
# 1: mute
delete!(graph.nodes, node)
@ -123,8 +123,19 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
pre_length1 = length(node1.parents)
pre_length2 = length(node2.children)
remove_parent!(node1, node2)
remove_child!(node2, node1)
for i in eachindex(node1.parents)
if (node1.parents[i] == node2)
splice!(node1.parents, i)
break
end
end
for i in eachindex(node2.children)
if (node2.children[i] == node1)
splice!(node2.children, i)
break
end
end
#=@assert begin
removed = pre_length1 - length(node1.parents)
@ -162,17 +173,17 @@ function replace_children!(task::FusedComputeTask, before, after)
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
replace!(task.t1_inputs, before => after)
replace!(task.t2_inputs, before => after)
# recursively descend down the tree, but only in the tasks where we're replacing things
if replacedIn1 > 0
replace_children!(task.first_func, before, after)
replace_children!(task.first_task, before, after)
end
if replacedIn2 > 0
replace_children!(task.second_func, before, after)
replace_children!(task.second_task, before, after)
end
return nothing

View File

@ -10,7 +10,6 @@ mutable struct PossibleOperations
nodeFusions::Set{NodeFusion}
nodeReductions::Set{NodeReduction}
nodeSplits::Set{NodeSplit}
nodeVectorizations::Set{NodeVectorization}
end
"""
@ -53,7 +52,7 @@ end
Construct and return an empty [`PossibleOperations`](@ref) object.
"""
function PossibleOperations()
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}(), Set{NodeVectorization}())
return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}())
end
"""

View File

@ -181,7 +181,7 @@ function parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = fa
insert_edge!(graph, compute_S2, data_out, track = false, invalidate_cache = false)
insert_edge!(graph, data_out, sum_node, track = false, invalidate_cache = false)
add_child!(task(sum_node).func)
add_child!(task(sum_node))
elseif occursin(regex_plus, node)
if (verbose)
println("\rReading Nodes Complete ")

View File

@ -78,6 +78,7 @@ Return the number of children of a ComputeTaskABC_V (always 2).
"""
children(::ComputeTaskABC_V) = 2
"""
children(::ComputeTaskABC_Sum)

View File

@ -1,44 +1,44 @@
"""
ComputeTaskABC_S1 <: AbstractTaskFunction
ComputeTaskABC_S1 <: AbstractComputeTask
S task with a single child.
"""
struct ComputeTaskABC_S1 <: AbstractTaskFunction end
struct ComputeTaskABC_S1 <: AbstractComputeTask end
"""
ComputeTaskABC_S2 <: AbstractTaskFunction
ComputeTaskABC_S2 <: AbstractComputeTask
S task with two children.
"""
struct ComputeTaskABC_S2 <: AbstractTaskFunction end
struct ComputeTaskABC_S2 <: AbstractComputeTask end
"""
ComputeTaskABC_P <: AbstractTaskFunction
ComputeTaskABC_P <: AbstractComputeTask
P task with no children.
"""
struct ComputeTaskABC_P <: AbstractTaskFunction end
struct ComputeTaskABC_P <: AbstractComputeTask end
"""
ComputeTaskABC_V <: AbstractTaskFunction
ComputeTaskABC_V <: AbstractComputeTask
v task with two children.
"""
struct ComputeTaskABC_V <: AbstractTaskFunction end
struct ComputeTaskABC_V <: AbstractComputeTask end
"""
ComputeTaskABC_U <: AbstractTaskFunction
ComputeTaskABC_U <: AbstractComputeTask
u task with a single child.
"""
struct ComputeTaskABC_U <: AbstractTaskFunction end
struct ComputeTaskABC_U <: AbstractComputeTask end
"""
ComputeTaskABC_Sum <: AbstractTaskFunction
ComputeTaskABC_Sum <: AbstractComputeTask
Task that sums all its inputs, n children.
"""
mutable struct ComputeTaskABC_Sum <: AbstractTaskFunction
mutable struct ComputeTaskABC_Sum <: AbstractComputeTask
children_number::Int
end

View File

@ -66,7 +66,7 @@ function compute(
data1::ParticleValue{P1},
data2::ParticleValue{P2},
) where {P1 <: Union{AntiFermionStateful, FermionStateful}, P2 <: Union{AntiFermionStateful, FermionStateful}}
@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
#@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)"
inner = QED_inner_edge(propagation_result(P1)(momentum(data1.p)))

View File

@ -188,7 +188,7 @@ end
Apply a [`FeynmanVertex`](@ref) to the given vector of [`FeynmanParticle`](@ref)s.
"""
function apply_vertex!(particles::Vector{FeynmanParticle}, vertex::FeynmanVertex)
@assert can_apply_vertex(particles, vertex)
#@assert can_apply_vertex(particles, vertex)
length_before = length(particles)
filter!(x -> x != vertex.in1 && x != vertex.in2, particles)
push!(particles, vertex.out)

View File

@ -1,44 +1,44 @@
"""
ComputeTaskQED_S1 <: AbstractTaskFunction
ComputeTaskQED_S1 <: AbstractComputeTask
S task with a single child.
"""
struct ComputeTaskQED_S1 <: AbstractTaskFunction end
struct ComputeTaskQED_S1 <: AbstractComputeTask end
"""
ComputeTaskQED_S2 <: AbstractTaskFunction
ComputeTaskQED_S2 <: AbstractComputeTask
S task with two children.
"""
struct ComputeTaskQED_S2 <: AbstractTaskFunction end
struct ComputeTaskQED_S2 <: AbstractComputeTask end
"""
ComputeTaskQED_P <: AbstractTaskFunction
ComputeTaskQED_P <: AbstractComputeTask
P task with no children.
"""
struct ComputeTaskQED_P <: AbstractTaskFunction end
struct ComputeTaskQED_P <: AbstractComputeTask end
"""
ComputeTaskQED_V <: AbstractTaskFunction
ComputeTaskQED_V <: AbstractComputeTask
v task with two children.
"""
struct ComputeTaskQED_V <: AbstractTaskFunction end
struct ComputeTaskQED_V <: AbstractComputeTask end
"""
ComputeTaskQED_U <: AbstractTaskFunction
ComputeTaskQED_U <: AbstractComputeTask
u task with a single child.
"""
struct ComputeTaskQED_U <: AbstractTaskFunction end
struct ComputeTaskQED_U <: AbstractComputeTask end
"""
ComputeTaskQED_Sum <: AbstractTaskFunction
ComputeTaskQED_Sum <: AbstractComputeTask
Task that sums all its inputs, n children.
"""
mutable struct ComputeTaskQED_Sum <: AbstractTaskFunction
mutable struct ComputeTaskQED_Sum <: AbstractComputeTask
children_number::Int
end

View File

@ -1,22 +1,12 @@
DataTaskNode(t::AbstractDataTask, name = "") = DataTaskNode(
t,
Vector{Node}(),
Vector{Node}(),
UUIDs.uuid1(rng[threadid()]),
missing,
missing,
missing,
missing,
name,
)
DataTaskNode(t::AbstractDataTask, name = "") =
DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name)
ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
t, # task
Vector{Node}(), # parents
Vector{Node}(), # children
UUIDs.uuid1(rng[threadid()]), # id
missing, # node reduction
missing, # node vectorization
missing, # node split
Vector{NodeFusion}(), # node fusions
missing, # device
@ -49,14 +39,9 @@ end
Construct and return a new [`ComputeTaskNode`](@ref) with the given task.
"""
make_node(t::AbstractComputeTask) = ComputeTaskNode(t)
"""
make_node(t::AbstractTaskFunction)
Construct and return a new [`ComputeTaskNode`](@ref) with a default [`ComputeTask`](@ref) with the given task function.
"""
make_node(t::AbstractTaskFunction) = ComputeTaskNode(ComputeTask(t))
function make_node(t::AbstractComputeTask)
return ComputeTaskNode(t)
end
"""
make_edge(n1::Node, n2::Node)

View File

@ -121,157 +121,3 @@ Return whether the `potential_child` is a child of `node`.
function is_child(potential_child::Node, node::Node)::Bool
return potential_child in children(node)
end
function add_child!(n::DataTaskNode{<:AbstractDataTask}, child::Node)
push!(n.children, child)
return nothing
end
function add_parent!(n::DataTaskNode{<:AbstractDataTask}, parent::Node)
push!(n.parents, parent)
return nothing
end
function remove_child!(n::DataTaskNode{<:AbstractDataTask}, child::Node)
for i in eachindex(n.children)
if (n.children[i] == child)
splice!(n.children, i)
break
end
end
return nothing
end
function remove_parent!(n::DataTaskNode{<:AbstractDataTask}, parent::Node)
for i in eachindex(n.parents)
if (n.parents[i] == parent)
splice!(n.parents, i)
break
end
end
return nothing
end
"""
add_child(n::ComputeTaskNode{<:ComputeTask}, child::Node)
Add a child to the compute node.
"""
function add_child!(n::ComputeTaskNode{<:ComputeTask}, child::Node)
push!(n.children, child)
push!(n.task.arguments, Argument(child.id))
return nothing
end
function add_parent!(n::ComputeTaskNode{<:ComputeTask}, parent::Node)
push!(n.parents, parent)
return nothing
end
function remove_child!(n::ComputeTaskNode{<:ComputeTask}, child::Node)
for i in eachindex(n.children)
if (n.children[i] == child)
splice!(n.children, i)
@assert n.task.func.arguments[i].id == child.id
splice!(n.task.func.arguments, i)
break
end
end
return nothing
end
function remove_parent!(n::ComputeTaskNode{<:ComputeTask}, parent::Node)
for i in eachindex(n.parents)
if (n.parents[i] == parent)
splice!(n.parents, i)
break
end
end
return nothing
end
function add_child!(n::ComputeTaskNode{<:FusedComputeTask}, child::Node, which::Symbol = :first)
push!(n.children, child)
if which == :first || which == :both
push!(n.task.first_func_arguments, child.id)
if which == :second || which == :both
push!(n.task.second_func_arguments, child.id)
end
if which != :first && which != :second && which != :both
@assert false "Tried to add child to symbol $(which) of fused compute task, but only :both, :first, and :second are allowed"
end
return nothing
end
function add_parent!(n::ComputeTaskNode{<:FusedComputeTask}, parent::Node)
push!(n.parents, parent)
return nothing
end
function remove_child!(n::ComputeTaskNode{<:FusedComputeTask}, child::Node)
for i in eachindex(n.children)
if (n.children[i] == child)
splice!(n.children, i)
break
end
end
for field in (:first_func_arguments, :second_func_arguments)
for i in eachindex(getfield(n.task, field))
if (getfield(n.task, field)[i].id == child.id)
splice!(getfield(n.task, field), i)
break
end
end
end
return nothing
end
function remove_parent!(n::ComputeTaskNode{<:FusedComputeTask}, parent::Node)
for i in eachindex(n.parents)
if (n.parents[i] == parent)
splice!(n.parents, i)
break
end
end
return nothing
end
function add_child!(n::ComputeTaskNode{<:VectorizedComputeTask}, child::Node)
push!(n.children, child)
push!(n.task.arguments, Argument(child.id))
return nothing
end
function add_parent!(n::ComputeTaskNode{<:VectorizedComputeTask}, parent::Node)
push!(n.parents, parent)
return nothing
end
function remove_child!(n::ComputeTaskNode{<:VectorizedComputeTask}, child::Node)
for i in eachindex(n.children)
if (n.children[i] == child)
splice!(n.children, i)
@assert n.task.func.arguments[i].id == child.id
splice!(n.task.func.arguments, i)
break
end
end
return nothing
end
function remove_parent!(n::ComputeTaskNode{<:VectorizedComputeTask}, parent::Node)
for i in eachindex(n.parents)
if (n.parents[i] == parent)
splice!(n.parents, i)
break
end
end
return nothing
end

View File

@ -24,15 +24,14 @@ abstract type Operation end
Any node that transfers data and does no computation.
# Fields
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeVectorization`: Always `missing`, since a data node can't be vectorized.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
`.task`: The node's data task type. Usually [`DataTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
"""
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
task::TaskType
@ -49,9 +48,6 @@ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
# Can't use the NodeReduction type here because it's not yet defined
nodeReduction::Union{Operation, Missing}
# data nodes can't be vectorized
nodeVectorization::Missing
# the NodeSplit involving this node, if it exists
nodeSplit::Union{Operation, Missing}
@ -68,25 +64,22 @@ end
Any node that computes a result from inputs using an [`AbstractComputeTask`](@ref).
# Fields
`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeVectorization`: Either this node's [`NodeVectorization`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\
`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\
`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\
`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\
`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
"""
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
task::TaskType
parents::Vector{Node}
children::Vector{Node}
id::Base.UUID
nodeReduction::Union{Operation, Missing}
nodeVectorization::Union{Operation, Missing}
nodeSplit::Union{Operation, Missing}
# for ComputeTasks there can be multiple fusions, unlike the DataTasks

View File

@ -187,7 +187,7 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
remove_node!(graph, n3)
# create new node with the fused compute task
newNode = ComputeTaskNode(FusedComputeTask(n1Task.func, n3Task.func, Symbol(to_var_name(n2.id))))
newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
insert_node!(graph, newNode)
for child in n1Children

View File

@ -239,31 +239,6 @@ function generate_operations(graph::DAG)
empty!(graph.dirtyNodes)
# find node vectorizations
# assume that relevant tasks will be at the same depth in the graph
"""
nodes = Queue{Tuple{<:Node, Int64}}()
node_set = Set{Node}()
currentDepth = 0
push!(nodes, (get_exit_node(graph), 0))
while !isempty(nodes)
(node, depth) = popfirst!(nodes)
pushback!.(node.parents, Ref(depth + 1))
if depth == currentDepth
push!(node_set, node)
continue
end
# collected all nodes of the same depth, sort into compute task types and create vectorizations
#TODO
# reset node_set
node_set = Set{Node}()
push!(node_set, node)
end
"""
wait(nr_task)
wait(nf_task)
wait(ns_task)

View File

@ -56,15 +56,3 @@ function show(io::IO, op::NodeFusion)
print(io, "->")
return print(io, task(op.input[3]))
end
"""
show(io::IO, op::NodeVectorization)
Print a string representation of the node vectorization to io.
"""
function show(io::IO, op::NodeVectorization)
print(io, "NV: ")
print(io, length(op.input))
print(io, "x")
return print(io, task(op.input[1]))
end

View File

@ -120,12 +120,3 @@ struct AppliedNodeSplit{NodeType <: Node} <: AppliedOperation
operation::NodeSplit{NodeType}
diff::Diff
end
struct NodeVectorization{TaskType <: AbstractComputeTask} <: Operation
input::Vector{ComputeTaskNode{TaskType}}
end
struct AppliedNodeVectorization{TaskType <: AbstractComputeTask} <: AppliedOperation
operation::NodeVectorization{TaskType}
diff::Diff
end

View File

@ -171,13 +171,3 @@ Equality comparison between two node splits. Two node splits are considered equa
function ==(op1::NodeSplit, op2::NodeSplit)
return op1.input == op2.input
end
"""
==(op1::NodeVectorization, op2::NodeVectorization)
Equality comparison between two node vectorizations. Two node vectorizations are considered equal when they have the same inputs.
"""
function ==(op1::NodeVectorization, op2::NodeVectorization)
# node vectorizations are equal exactly if their first input is the same
return op1.input[1].id == op2.input[1].id
end

View File

@ -17,7 +17,7 @@ function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTas
if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1)
throw(
AssertionError(
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion\n[Node Fusion] n1 ($(n1)) parents: $(parents(n1))\n[Node Fusion] n2 ($(n2)) children: $(children(n2)), n2 parents: $(parents(n2))\n[Node Fusion] n3 ($(n3)) children: $(children(n3))",
"[Node Fusion] The given nodes are not connected by edges which is required for node fusion",
),
)
end
@ -106,35 +106,6 @@ function is_valid_node_split_input(graph::DAG, n1::Node)
return true
end
"""
is_valid_node_vectorization_input(graph::DAG, nodes::Vector{Node})
Assert for a gven node vectorization input whether the nodes can be vectorized. For the requirements of a node vectorization see [`NodeVectorization`](@ref).
Intended for use with `@assert` or `@test`.
"""
function is_valid_node_vectorization_input(graph::DAG, nodes::Vector{Node})
for n in nodes
if n graph
throw(AssertionError("[Node Vectorization] The given nodes are not part of the given graph"))
end
@assert is_valid(graph, n)
end
t = typeof(task(nodes[1]))
for n in nodes
if typeof(task(n)) != t
throw(AssertionError("[Node Vectorization] The given nodes are not of the same type"))
end
end
if !pairwise_independent(graph, nodes)
throw(AssertionError("[Node Vectorization] The given nodes are not pairwise independent of one another"))
end
return true
end
"""
is_valid(graph::DAG, nr::NodeReduction)
@ -173,15 +144,3 @@ function is_valid(graph::DAG, nf::NodeFusion)
#@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!"
return true
end
"""
is_valid(graph::DAG, nr::NodeVectorization)
Assert for a given [`NodeVectorization`](@ref) whether it is a valid operation in the graph.
Intended for use with `@assert` or `@test`.
"""
function is_valid(graph::DAG, nv::NodeVectorization)
@assert is_valid_node_vectorization_input(graph, nv.input)
return true
end

View File

@ -8,11 +8,11 @@ function ==(t1::AbstractTask, t2::AbstractTask)
end
"""
==(t1::AbstractTaskFunction, t2::AbstractTaskFunction)
==(t1::AbstractComputeTask, t2::AbstractComputeTask)
Equality comparison between two compute tasks.
"""
function ==(t1::AbstractTaskFunction, t2::AbstractTaskFunction)
function ==(t1::AbstractComputeTask, t2::AbstractComputeTask)
return typeof(t1) == typeof(t2)
end

View File

@ -6,7 +6,8 @@ using StaticArrays
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
"""
function compute(t::FusedComputeTask, data...)
@assert false
inter = compute(t.first_task)
return compute(t.second_task, inter, data2...)
end
"""
@ -20,8 +21,8 @@ For ordinary compute or data tasks the vector will contain exactly one element,
function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol)
# sort out the symbols to the correct tasks
return [
get_function_call(t.first_func, device, t.t1_inputs, t.t1_output)...,
get_function_call(t.second_func, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)...,
get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)...,
]
end

View File

@ -1,5 +1,3 @@
Argument(id::Base.UUID) = Argument(id, 0)
"""
copy(t::AbstractDataTask)
@ -12,12 +10,7 @@ copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks
Return a copy of the given compute task.
"""
copy(t::AbstractComputeTask) = typeof(t)(copy(t.func), copy(t.arguments))
copy(t::AbstractTaskFunction) = typeof(t)()
function ComputeTask(t::AbstractTaskFunction)
return ComputeTask(t, Vector{Argument}())
end
copy(t::AbstractComputeTask) = typeof(t)()
"""
copy(t::FusedComputeTask)
@ -25,9 +18,15 @@ end
Return a copy of th egiven [`FusedComputeTask`](@ref).
"""
function copy(t::FusedComputeTask)
return FusedComputeTask(copy(t.first_func), copy(t.second_func), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs))
end
function FusedComputeTask(T1::Type{<:AbstractComputeTask}, T2::Type{<:AbstractComputeTask}, t1_output::String)
return FusedComputeTask(T1(), T2(), t1_output)
function FusedComputeTask(
T1::Type{<:AbstractComputeTask},
T2::Type{<:AbstractComputeTask},
t1_inputs::Vector{String},
t1_output::String,
t2_inputs::Vector{String},
)
return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs)
end

View File

@ -3,21 +3,19 @@
Fallback implementation of the compute function of a compute task, throwing an error.
"""
function compute(t::AbstractTaskFunction, data...)
function compute(t::AbstractTask, data...)
return error("Need to implement compute()")
end
compute(t::ComputeTask, data...) = compute(t.func, data...) # TODO: is this correct?
"""
compute_effort(t::AbstractTask)
Fallback implementation of the compute effort of a task, throwing an error.
"""
function compute_effort(t::AbstractTaskFunction)::Float64
function compute_effort(t::AbstractTask)::Float64
# default implementation using compute
return error("Need to implement compute_effort()")
end
compute_effort(t::AbstractComputeTask)::Float64 = compute_effort(t.func)
"""
data(t::AbstractTask)
@ -62,7 +60,7 @@ children(::DataTask) = 1
Return the number of children of a FusedComputeTask.
"""
function children(t::FusedComputeTask)
return length(union(Set(t.first_func.arguments), Set(t.second_func.arguments))) - 1
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
end
"""
@ -71,7 +69,6 @@ end
Return the data of a compute task, always zero, regardless of the specific task.
"""
data(t::AbstractComputeTask)::Float64 = 0.0
data(t::AbstractTaskFunction)::Float64 = 0.0
"""
compute_effort(t::FusedComputeTask)
@ -79,7 +76,7 @@ data(t::AbstractTaskFunction)::Float64 = 0.0
Return the compute effort of a fused compute task.
"""
function compute_effort(t::FusedComputeTask)::Float64
return compute_effort(t.first_func) + compute_effort(t.second_func)
return compute_effort(t.first_task) + compute_effort(t.second_task)
end
"""
@ -87,4 +84,4 @@ end
Return a tuple of a the fused compute task's components' types.
"""
get_types(t::FusedComputeTask) = (typeof(t.first_func), typeof(t.second_func))
get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task))

View File

@ -1,15 +1,3 @@
"""
Argument
Representation of one of the arguments of a [`AbstractComputeTask`](@ref) in a Node.
Has a member `id::UUID`, which is the id of the compute task's parent the argument refers to, and an `index`, which is the index in the symbol's object. If `index` is 0, the whole object is the argument.
A compute task can have multiple arguments.
"""
struct Argument
id::Base.UUID
index::Int64
end
"""
AbstractTask
@ -18,33 +6,11 @@ The shared base type for any task.
abstract type AbstractTask end
"""
AbstractComputeTask
AbstractComputeTask <: AbstractTask
Base type for all compute tasks.
!!! note
A subtype of this *must* be templated with its [`AbstractTaskFunction`](@ref)s!
See also: [`ComputeTask`](@ref), [`FusedComputeTask`](@ref), [`VectorizedComputeTask`](@ref)
The shared base type for any compute task.
"""
abstract type AbstractComputeTask end
"""
AbstractTaskFunction
The base type for task functions used to dispatch computes.
"""
abstract type AbstractTaskFunction end
"""
ComputeTask <: AbstractTask
A compute task. Has a member [`AbstractTaskFunction`](@ref), and a list of [`Argument`](@ref)s.
"""
struct ComputeTask{TaskFunction <: AbstractTaskFunction} <: AbstractComputeTask
func::TaskFunction
arguments::Vector{Argument}
end
abstract type AbstractComputeTask <: AbstractTask end
"""
AbstractDataTask <: AbstractTask
@ -63,26 +29,19 @@ struct DataTask <: AbstractDataTask
end
"""
FusedComputeTask <: AbstractComputeTask
FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask
A fused compute task made up of the computation of first `T1` and then `T2`.
Also see: [`get_types`](@ref).
"""
struct FusedComputeTask{TaskFunction1 <: AbstractTaskFunction, TaskFunction2 <: AbstractTaskFunction} <:
AbstractComputeTask
first_func::TaskFunction1
second_func::TaskFunction2
struct FusedComputeTask <: AbstractComputeTask
first_task::AbstractComputeTask
second_task::AbstractComputeTask
# the names of the inputs for T1
t1_inputs::Vector{Symbol}
# output name of T1
t1_output::Symbol
first_func_arguments::Vector{Argument}
second_func_arguments::Vector{Argument}
end
"""
"""
struct VectorizedComputeTask{TaskFunction <: AbstractTaskFunction} <: AbstractComputeTask
task_type::TaskFunction
arguments::Vector{Argument}
# t2_inputs doesn't include the output of t1, that's implicit
t2_inputs::Vector{Symbol}
end

View File

@ -48,7 +48,7 @@ function insert_helper!(
trie::NodeIdTrie{NodeType},
node::NodeType,
depth::Int,
) where {CompTask <: AbstractComputeTask, NodeType <: Union{DataTaskNode{AbstractDataTask}, ComputeTaskNode{CompTask}}}
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
if (length(children(node)) == depth)
push!(trie.value, node)
return nothing
@ -71,7 +71,7 @@ Insert the given node into the trie. It's sorted by its type in the first layer,
function insert!(
trie::NodeTrie,
node::NodeType,
) where {CompTask <: AbstractComputeTask, NodeType <: Union{DataTaskNode{AbstractDataTask}, ComputeTaskNode{CompTask}}}
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
if (!haskey(trie.children, NodeType))
trie.children[NodeType] = NodeIdTrie{NodeType}()
end

View File

@ -5,8 +5,6 @@ import MetagraphOptimization.insert_edge!
import MetagraphOptimization.make_node
import MetagraphOptimization.siblings
import MetagraphOptimization.partners
import MetagraphOptimization.is_dependent
import MetagraphOptimization.pairwise_independent
graph = MetagraphOptimization.DAG()
@ -110,20 +108,6 @@ insert_edge!(graph, s0, d_exit, track = false)
@test length(graph.dirtyNodes) == 26
@test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0)
@test !is_dependent(graph, d_exit, d_exit)
@test is_dependent(graph, d_v0_s0, v0)
@test is_dependent(graph, d_v1_s0, v1)
@test !is_dependent(graph, d_v0_s0, v1)
@test is_dependent(graph, s0, d_PB)
@test !is_dependent(graph, v0, uBp)
@test !is_dependent(graph, PB, PA)
@test pairwise_independent(graph, [PB, PA, PBp, PAp])
@test pairwise_independent(graph, [uB, uA, d_uBp_v1])
@test !pairwise_independent(graph, [PB, PA, PBp, PAp, s0])
@test !pairwise_independent(graph, [d_uB_v0, v0])
@test !pairwise_independent(graph, [v0, d_v0_s0, s0, uB])
@test is_valid(graph)
@test is_entry_node(d_PB)

View File

@ -37,7 +37,7 @@ testparticleTypesPropagated = [
function compton_groundtruth(input::QEDProcessInput)
# p1k1 -> p2k2
# formula: (ie)^2 (u(p2) slashed(ε1) S(p2 k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
# formula: -(ie)^2 (u(p2) slashed(ε1) S(p2 - k1) slashed(ε2) u(p1) + u(p2) slashed(ε2) S(p1 + k1) slashed(ε1) u(p1))
p1 = input.inFerms[1]
p2 = input.outFerms[1]