diff --git a/src/MetagraphOptimization.jl b/src/MetagraphOptimization.jl index 76172d9..966a567 100644 --- a/src/MetagraphOptimization.jl +++ b/src/MetagraphOptimization.jl @@ -14,7 +14,6 @@ export Edge export ComputeTaskNode export DataTaskNode export AbstractTask -export AbstractComputeTask export AbstractDataTask export DataTask export FusedComputeTask @@ -24,8 +23,8 @@ export GraphProperties # graph functions export make_node export make_edge -export insert_node -export insert_edge +export insert_node! +export insert_edge! export is_entry_node export is_exit_node export parents diff --git a/src/graph/compare.jl b/src/graph/compare.jl index 7b4f206..b847f77 100644 --- a/src/graph/compare.jl +++ b/src/graph/compare.jl @@ -1,3 +1,5 @@ +using DataStructures + """ in(node::Node, graph::DAG) @@ -19,3 +21,54 @@ function in(edge::Edge, graph::DAG) return n1 in children(n2) end + +""" + is_dependent(graph::DAG, n1::Node, n2::Node) + +Returns whether `n1` is dependent on `n2` in the given `graph`. +""" +function is_dependent(graph::DAG, n1::Node, n2::Node) + if !(n1 in graph) || !(n2 in graph) + return false + end + if n1 == n2 + return false + end + + queue = Deque{Node}() + push!(queue, n1) + + # TODO: this is probably not the best way to do this, deduplication in the queue would help + while !isempty(queue) + current = popfirst!(queue) + if current == n2 + return true + end + + for c in current.children + push!(queue, c) + end + end + + return false +end + +""" + pairwise_independent(graph::DAG, nodes::Vector{Node}) + +Returns true if all given `nodes` are independent of each other, i.e., no path exists from any member of `nodes` to any other member. + +See [`is_dependent`](@ref) +""" +function pairwise_independent(graph::DAG, nodes::Vector{<:Node}) + checked_set = Vector{Node}() + for n in nodes + for c in checked_set + if is_dependent(graph, c, n) || is_dependent(graph, n, c) + return false + end + end + push!(checked_set, n) + end + return true +end diff --git a/src/graph/mute.jl b/src/graph/mute.jl index d5ad22c..bdd1431 100644 --- a/src/graph/mute.jl +++ b/src/graph/mute.jl @@ -46,12 +46,12 @@ Insert the edge between node1 (child) and node2 (parent) into the graph. See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref) """ function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true) - #@assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists" + @assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists" # 1: mute # edge points from child to parent - push!(node1.parents, node2) - push!(node2.children, node1) + add_parent!(node1, node2) + add_child!(node2, node1) # 2: keep track if (track) @@ -85,7 +85,7 @@ Remove the node from the graph. See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref) """ function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true) - #@assert node in graph.nodes "Trying to remove a node that's not in the graph" + @assert node in graph.nodes "Trying to remove a node that's not in the graph" # 1: mute delete!(graph.nodes, node) @@ -123,19 +123,8 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali pre_length1 = length(node1.parents) pre_length2 = length(node2.children) - for i in eachindex(node1.parents) - if (node1.parents[i] == node2) - splice!(node1.parents, i) - break - end - end - - for i in eachindex(node2.children) - if (node2.children[i] == node1) - splice!(node2.children, i) - break - end - end + remove_parent!(node1, node2) + remove_child!(node2, node1) #=@assert begin removed = pre_length1 - length(node1.parents) @@ -173,17 +162,17 @@ function replace_children!(task::FusedComputeTask, before, after) replacedIn1 = length(findall(x -> x == before, task.t1_inputs)) replacedIn2 = length(findall(x -> x == before, task.t2_inputs)) - #@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)" + @assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)" replace!(task.t1_inputs, before => after) replace!(task.t2_inputs, before => after) # recursively descend down the tree, but only in the tasks where we're replacing things if replacedIn1 > 0 - replace_children!(task.first_task, before, after) + replace_children!(task.first_func, before, after) end if replacedIn2 > 0 - replace_children!(task.second_task, before, after) + replace_children!(task.second_func, before, after) end return nothing diff --git a/src/graph/type.jl b/src/graph/type.jl index e895b36..148b7e5 100644 --- a/src/graph/type.jl +++ b/src/graph/type.jl @@ -10,6 +10,7 @@ mutable struct PossibleOperations nodeFusions::Set{NodeFusion} nodeReductions::Set{NodeReduction} nodeSplits::Set{NodeSplit} + nodeVectorizations::Set{NodeVectorization} end """ @@ -52,7 +53,7 @@ end Construct and return an empty [`PossibleOperations`](@ref) object. """ function PossibleOperations() - return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}()) + return PossibleOperations(Set{NodeFusion}(), Set{NodeReduction}(), Set{NodeSplit}(), Set{NodeVectorization}()) end """ diff --git a/src/models/abc/parse.jl b/src/models/abc/parse.jl index c16554c..c2865ec 100644 --- a/src/models/abc/parse.jl +++ b/src/models/abc/parse.jl @@ -181,7 +181,7 @@ function parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = fa insert_edge!(graph, compute_S2, data_out, track = false, invalidate_cache = false) insert_edge!(graph, data_out, sum_node, track = false, invalidate_cache = false) - add_child!(task(sum_node)) + add_child!(task(sum_node).func) elseif occursin(regex_plus, node) if (verbose) println("\rReading Nodes Complete ") diff --git a/src/models/abc/properties.jl b/src/models/abc/properties.jl index d772cd9..25cccfd 100644 --- a/src/models/abc/properties.jl +++ b/src/models/abc/properties.jl @@ -78,7 +78,6 @@ Return the number of children of a ComputeTaskABC_V (always 2). """ children(::ComputeTaskABC_V) = 2 - """ children(::ComputeTaskABC_Sum) diff --git a/src/models/abc/types.jl b/src/models/abc/types.jl index cff7c72..9655b0d 100644 --- a/src/models/abc/types.jl +++ b/src/models/abc/types.jl @@ -1,44 +1,44 @@ """ - ComputeTaskABC_S1 <: AbstractComputeTask + ComputeTaskABC_S1 <: AbstractTaskFunction S task with a single child. """ -struct ComputeTaskABC_S1 <: AbstractComputeTask end +struct ComputeTaskABC_S1 <: AbstractTaskFunction end """ - ComputeTaskABC_S2 <: AbstractComputeTask + ComputeTaskABC_S2 <: AbstractTaskFunction S task with two children. """ -struct ComputeTaskABC_S2 <: AbstractComputeTask end +struct ComputeTaskABC_S2 <: AbstractTaskFunction end """ - ComputeTaskABC_P <: AbstractComputeTask + ComputeTaskABC_P <: AbstractTaskFunction P task with no children. """ -struct ComputeTaskABC_P <: AbstractComputeTask end +struct ComputeTaskABC_P <: AbstractTaskFunction end """ - ComputeTaskABC_V <: AbstractComputeTask + ComputeTaskABC_V <: AbstractTaskFunction v task with two children. """ -struct ComputeTaskABC_V <: AbstractComputeTask end +struct ComputeTaskABC_V <: AbstractTaskFunction end """ - ComputeTaskABC_U <: AbstractComputeTask + ComputeTaskABC_U <: AbstractTaskFunction u task with a single child. """ -struct ComputeTaskABC_U <: AbstractComputeTask end +struct ComputeTaskABC_U <: AbstractTaskFunction end """ - ComputeTaskABC_Sum <: AbstractComputeTask + ComputeTaskABC_Sum <: AbstractTaskFunction Task that sums all its inputs, n children. """ -mutable struct ComputeTaskABC_Sum <: AbstractComputeTask +mutable struct ComputeTaskABC_Sum <: AbstractTaskFunction children_number::Int end diff --git a/src/models/qed/compute.jl b/src/models/qed/compute.jl index 726f9d1..0c43410 100644 --- a/src/models/qed/compute.jl +++ b/src/models/qed/compute.jl @@ -66,7 +66,7 @@ function compute( data1::ParticleValue{P1}, data2::ParticleValue{P2}, ) where {P1 <: Union{AntiFermionStateful, FermionStateful}, P2 <: Union{AntiFermionStateful, FermionStateful}} - #@assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)" + @assert isapprox(data1.p.momentum, data2.p.momentum, rtol = sqrt(eps()), atol = sqrt(eps())) "$(data1.p.momentum) vs. $(data2.p.momentum)" inner = QED_inner_edge(propagation_result(P1)(momentum(data1.p))) diff --git a/src/models/qed/diagrams.jl b/src/models/qed/diagrams.jl index 073fe4e..ad05215 100644 --- a/src/models/qed/diagrams.jl +++ b/src/models/qed/diagrams.jl @@ -188,7 +188,7 @@ end Apply a [`FeynmanVertex`](@ref) to the given vector of [`FeynmanParticle`](@ref)s. """ function apply_vertex!(particles::Vector{FeynmanParticle}, vertex::FeynmanVertex) - #@assert can_apply_vertex(particles, vertex) + @assert can_apply_vertex(particles, vertex) length_before = length(particles) filter!(x -> x != vertex.in1 && x != vertex.in2, particles) push!(particles, vertex.out) diff --git a/src/models/qed/types.jl b/src/models/qed/types.jl index 9923014..2943121 100644 --- a/src/models/qed/types.jl +++ b/src/models/qed/types.jl @@ -1,44 +1,44 @@ """ - ComputeTaskQED_S1 <: AbstractComputeTask + ComputeTaskQED_S1 <: AbstractTaskFunction S task with a single child. """ -struct ComputeTaskQED_S1 <: AbstractComputeTask end +struct ComputeTaskQED_S1 <: AbstractTaskFunction end """ - ComputeTaskQED_S2 <: AbstractComputeTask + ComputeTaskQED_S2 <: AbstractTaskFunction S task with two children. """ -struct ComputeTaskQED_S2 <: AbstractComputeTask end +struct ComputeTaskQED_S2 <: AbstractTaskFunction end """ - ComputeTaskQED_P <: AbstractComputeTask + ComputeTaskQED_P <: AbstractTaskFunction P task with no children. """ -struct ComputeTaskQED_P <: AbstractComputeTask end +struct ComputeTaskQED_P <: AbstractTaskFunction end """ - ComputeTaskQED_V <: AbstractComputeTask + ComputeTaskQED_V <: AbstractTaskFunction v task with two children. """ -struct ComputeTaskQED_V <: AbstractComputeTask end +struct ComputeTaskQED_V <: AbstractTaskFunction end """ - ComputeTaskQED_U <: AbstractComputeTask + ComputeTaskQED_U <: AbstractTaskFunction u task with a single child. """ -struct ComputeTaskQED_U <: AbstractComputeTask end +struct ComputeTaskQED_U <: AbstractTaskFunction end """ - ComputeTaskQED_Sum <: AbstractComputeTask + ComputeTaskQED_Sum <: AbstractTaskFunction Task that sums all its inputs, n children. """ -mutable struct ComputeTaskQED_Sum <: AbstractComputeTask +mutable struct ComputeTaskQED_Sum <: AbstractTaskFunction children_number::Int end diff --git a/src/node/create.jl b/src/node/create.jl index 0b69885..98af941 100644 --- a/src/node/create.jl +++ b/src/node/create.jl @@ -1,12 +1,22 @@ -DataTaskNode(t::AbstractDataTask, name = "") = - DataTaskNode(t, Vector{Node}(), Vector{Node}(), UUIDs.uuid1(rng[threadid()]), missing, missing, missing, name) +DataTaskNode(t::AbstractDataTask, name = "") = DataTaskNode( + t, + Vector{Node}(), + Vector{Node}(), + UUIDs.uuid1(rng[threadid()]), + missing, + missing, + missing, + missing, + name, +) ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode( t, # task Vector{Node}(), # parents Vector{Node}(), # children UUIDs.uuid1(rng[threadid()]), # id missing, # node reduction + missing, # node vectorization missing, # node split Vector{NodeFusion}(), # node fusions missing, # device @@ -39,9 +49,14 @@ end Construct and return a new [`ComputeTaskNode`](@ref) with the given task. """ -function make_node(t::AbstractComputeTask) - return ComputeTaskNode(t) -end +make_node(t::AbstractComputeTask) = ComputeTaskNode(t) + +""" + make_node(t::AbstractTaskFunction) + +Construct and return a new [`ComputeTaskNode`](@ref) with a default [`ComputeTask`](@ref) with the given task function. +""" +make_node(t::AbstractTaskFunction) = ComputeTaskNode(ComputeTask(t)) """ make_edge(n1::Node, n2::Node) diff --git a/src/node/properties.jl b/src/node/properties.jl index e2de923..56bb2a1 100644 --- a/src/node/properties.jl +++ b/src/node/properties.jl @@ -121,3 +121,157 @@ Return whether the `potential_child` is a child of `node`. function is_child(potential_child::Node, node::Node)::Bool return potential_child in children(node) end + +function add_child!(n::DataTaskNode{<:AbstractDataTask}, child::Node) + push!(n.children, child) + return nothing +end + +function add_parent!(n::DataTaskNode{<:AbstractDataTask}, parent::Node) + push!(n.parents, parent) + return nothing +end + +function remove_child!(n::DataTaskNode{<:AbstractDataTask}, child::Node) + for i in eachindex(n.children) + if (n.children[i] == child) + splice!(n.children, i) + break + end + end + return nothing +end + +function remove_parent!(n::DataTaskNode{<:AbstractDataTask}, parent::Node) + for i in eachindex(n.parents) + if (n.parents[i] == parent) + splice!(n.parents, i) + break + end + end + return nothing +end + +""" + add_child(n::ComputeTaskNode{<:ComputeTask}, child::Node) + +Add a child to the compute node. +""" +function add_child!(n::ComputeTaskNode{<:ComputeTask}, child::Node) + push!(n.children, child) + push!(n.task.arguments, Argument(child.id)) + return nothing +end + +function add_parent!(n::ComputeTaskNode{<:ComputeTask}, parent::Node) + push!(n.parents, parent) + return nothing +end + +function remove_child!(n::ComputeTaskNode{<:ComputeTask}, child::Node) + for i in eachindex(n.children) + if (n.children[i] == child) + splice!(n.children, i) + @assert n.task.func.arguments[i].id == child.id + splice!(n.task.func.arguments, i) + break + end + end + return nothing +end + +function remove_parent!(n::ComputeTaskNode{<:ComputeTask}, parent::Node) + for i in eachindex(n.parents) + if (n.parents[i] == parent) + splice!(n.parents, i) + break + end + end + return nothing +end + + + +function add_child!(n::ComputeTaskNode{<:FusedComputeTask}, child::Node, which::Symbol = :first) + push!(n.children, child) + + if which == :first || which == :both + push!(n.task.first_func_arguments, child.id) + if which == :second || which == :both + push!(n.task.second_func_arguments, child.id) + end + + if which != :first && which != :second && which != :both + @assert false "Tried to add child to symbol $(which) of fused compute task, but only :both, :first, and :second are allowed" + end + return nothing +end + +function add_parent!(n::ComputeTaskNode{<:FusedComputeTask}, parent::Node) + push!(n.parents, parent) + return nothing +end + +function remove_child!(n::ComputeTaskNode{<:FusedComputeTask}, child::Node) + for i in eachindex(n.children) + if (n.children[i] == child) + splice!(n.children, i) + break + end + end + + for field in (:first_func_arguments, :second_func_arguments) + for i in eachindex(getfield(n.task, field)) + if (getfield(n.task, field)[i].id == child.id) + splice!(getfield(n.task, field), i) + break + end + end + end + return nothing +end + +function remove_parent!(n::ComputeTaskNode{<:FusedComputeTask}, parent::Node) + for i in eachindex(n.parents) + if (n.parents[i] == parent) + splice!(n.parents, i) + break + end + end + return nothing +end + + + +function add_child!(n::ComputeTaskNode{<:VectorizedComputeTask}, child::Node) + push!(n.children, child) + push!(n.task.arguments, Argument(child.id)) + return nothing +end + +function add_parent!(n::ComputeTaskNode{<:VectorizedComputeTask}, parent::Node) + push!(n.parents, parent) + return nothing +end + +function remove_child!(n::ComputeTaskNode{<:VectorizedComputeTask}, child::Node) + for i in eachindex(n.children) + if (n.children[i] == child) + splice!(n.children, i) + @assert n.task.func.arguments[i].id == child.id + splice!(n.task.func.arguments, i) + break + end + end + return nothing +end + +function remove_parent!(n::ComputeTaskNode{<:VectorizedComputeTask}, parent::Node) + for i in eachindex(n.parents) + if (n.parents[i] == parent) + splice!(n.parents, i) + break + end + end + return nothing +end diff --git a/src/node/type.jl b/src/node/type.jl index 39283d0..3f5faba 100644 --- a/src/node/type.jl +++ b/src/node/type.jl @@ -24,14 +24,15 @@ abstract type Operation end Any node that transfers data and does no computation. # Fields -`.task`: The node's data task type. Usually [`DataTask`](@ref).\\ -`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\ -`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\ -`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\ -`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ -`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ -`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\ -`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\ +`.task`: The node's data task type. Usually [`DataTask`](@ref).\\ +`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\ +`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\ +`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\ +`.nodeVectorization`: Always `missing`, since a data node can't be vectorized.\\ +`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ +`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ +`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\ +`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\ """ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node task::TaskType @@ -48,6 +49,9 @@ mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node # Can't use the NodeReduction type here because it's not yet defined nodeReduction::Union{Operation, Missing} + # data nodes can't be vectorized + nodeVectorization::Missing + # the NodeSplit involving this node, if it exists nodeSplit::Union{Operation, Missing} @@ -64,22 +68,25 @@ end Any node that computes a result from inputs using an [`AbstractComputeTask`](@ref). # Fields -`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\ -`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\ -`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\ -`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\ -`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ -`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ -`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\ -`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref). +`.task`: The node's compute task type. A concrete subtype of [`AbstractComputeTask`](@ref).\\ +`.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\ +`.children`: A vector of the node's children (i.e. nodes that this one depends on).\\ +`.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\ +`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ +`.nodeVectorization`: Either this node's [`NodeVectorization`](@ref) or `missing`, if none. There can only be at most one.\\ +`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ +`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\ +`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref). """ mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node task::TaskType + parents::Vector{Node} children::Vector{Node} id::Base.UUID nodeReduction::Union{Operation, Missing} + nodeVectorization::Union{Operation, Missing} nodeSplit::Union{Operation, Missing} # for ComputeTasks there can be multiple fusions, unlike the DataTasks diff --git a/src/operation/apply.jl b/src/operation/apply.jl index 164b67d..d945f3c 100644 --- a/src/operation/apply.jl +++ b/src/operation/apply.jl @@ -187,7 +187,7 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com remove_node!(graph, n3) # create new node with the fused compute task - newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs)) + newNode = ComputeTaskNode(FusedComputeTask(n1Task.func, n3Task.func, Symbol(to_var_name(n2.id)))) insert_node!(graph, newNode) for child in n1Children diff --git a/src/operation/find.jl b/src/operation/find.jl index 141443b..c254152 100644 --- a/src/operation/find.jl +++ b/src/operation/find.jl @@ -239,6 +239,31 @@ function generate_operations(graph::DAG) empty!(graph.dirtyNodes) + # find node vectorizations + # assume that relevant tasks will be at the same depth in the graph + """ + nodes = Queue{Tuple{<:Node, Int64}}() + node_set = Set{Node}() + currentDepth = 0 + push!(nodes, (get_exit_node(graph), 0)) + while !isempty(nodes) + (node, depth) = popfirst!(nodes) + pushback!.(node.parents, Ref(depth + 1)) + + if depth == currentDepth + push!(node_set, node) + continue + end + + # collected all nodes of the same depth, sort into compute task types and create vectorizations + #TODO + + # reset node_set + node_set = Set{Node}() + push!(node_set, node) + end + """ + wait(nr_task) wait(nf_task) wait(ns_task) diff --git a/src/operation/print.jl b/src/operation/print.jl index d4a1acd..96ce7a5 100644 --- a/src/operation/print.jl +++ b/src/operation/print.jl @@ -56,3 +56,15 @@ function show(io::IO, op::NodeFusion) print(io, "->") return print(io, task(op.input[3])) end + +""" + show(io::IO, op::NodeVectorization) + +Print a string representation of the node vectorization to io. +""" +function show(io::IO, op::NodeVectorization) + print(io, "NV: ") + print(io, length(op.input)) + print(io, "x") + return print(io, task(op.input[1])) +end diff --git a/src/operation/type.jl b/src/operation/type.jl index 606b101..0662014 100644 --- a/src/operation/type.jl +++ b/src/operation/type.jl @@ -120,3 +120,12 @@ struct AppliedNodeSplit{NodeType <: Node} <: AppliedOperation operation::NodeSplit{NodeType} diff::Diff end + +struct NodeVectorization{TaskType <: AbstractComputeTask} <: Operation + input::Vector{ComputeTaskNode{TaskType}} +end + +struct AppliedNodeVectorization{TaskType <: AbstractComputeTask} <: AppliedOperation + operation::NodeVectorization{TaskType} + diff::Diff +end diff --git a/src/operation/utility.jl b/src/operation/utility.jl index 0ccafb3..de09004 100644 --- a/src/operation/utility.jl +++ b/src/operation/utility.jl @@ -171,3 +171,13 @@ Equality comparison between two node splits. Two node splits are considered equa function ==(op1::NodeSplit, op2::NodeSplit) return op1.input == op2.input end + +""" + ==(op1::NodeVectorization, op2::NodeVectorization) + +Equality comparison between two node vectorizations. Two node vectorizations are considered equal when they have the same inputs. +""" +function ==(op1::NodeVectorization, op2::NodeVectorization) + # node vectorizations are equal exactly if their first input is the same + return op1.input[1].id == op2.input[1].id +end diff --git a/src/operation/validate.jl b/src/operation/validate.jl index ede35f3..412cbb2 100644 --- a/src/operation/validate.jl +++ b/src/operation/validate.jl @@ -17,7 +17,7 @@ function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTas if !is_child(n1, n2) || !is_child(n2, n3) || !is_parent(n3, n2) || !is_parent(n2, n1) throw( AssertionError( - "[Node Fusion] The given nodes are not connected by edges which is required for node fusion", + "[Node Fusion] The given nodes are not connected by edges which is required for node fusion\n[Node Fusion] n1 ($(n1)) parents: $(parents(n1))\n[Node Fusion] n2 ($(n2)) children: $(children(n2)), n2 parents: $(parents(n2))\n[Node Fusion] n3 ($(n3)) children: $(children(n3))", ), ) end @@ -106,6 +106,35 @@ function is_valid_node_split_input(graph::DAG, n1::Node) return true end +""" + is_valid_node_vectorization_input(graph::DAG, nodes::Vector{Node}) + +Assert for a gven node vectorization input whether the nodes can be vectorized. For the requirements of a node vectorization see [`NodeVectorization`](@ref). + +Intended for use with `@assert` or `@test`. +""" +function is_valid_node_vectorization_input(graph::DAG, nodes::Vector{Node}) + for n in nodes + if n ∉ graph + throw(AssertionError("[Node Vectorization] The given nodes are not part of the given graph")) + end + @assert is_valid(graph, n) + end + + t = typeof(task(nodes[1])) + for n in nodes + if typeof(task(n)) != t + throw(AssertionError("[Node Vectorization] The given nodes are not of the same type")) + end + end + + if !pairwise_independent(graph, nodes) + throw(AssertionError("[Node Vectorization] The given nodes are not pairwise independent of one another")) + end + + return true +end + """ is_valid(graph::DAG, nr::NodeReduction) @@ -144,3 +173,15 @@ function is_valid(graph::DAG, nf::NodeFusion) #@assert nf in graph.possibleOperations.nodeFusions "NodeFusion is not part of the graph's possible operations!" return true end + +""" + is_valid(graph::DAG, nr::NodeVectorization) + +Assert for a given [`NodeVectorization`](@ref) whether it is a valid operation in the graph. + +Intended for use with `@assert` or `@test`. +""" +function is_valid(graph::DAG, nv::NodeVectorization) + @assert is_valid_node_vectorization_input(graph, nv.input) + return true +end diff --git a/src/task/compare.jl b/src/task/compare.jl index e960550..3298b4f 100644 --- a/src/task/compare.jl +++ b/src/task/compare.jl @@ -8,11 +8,11 @@ function ==(t1::AbstractTask, t2::AbstractTask) end """ - ==(t1::AbstractComputeTask, t2::AbstractComputeTask) + ==(t1::AbstractTaskFunction, t2::AbstractTaskFunction) Equality comparison between two compute tasks. """ -function ==(t1::AbstractComputeTask, t2::AbstractComputeTask) +function ==(t1::AbstractTaskFunction, t2::AbstractTaskFunction) return typeof(t1) == typeof(t2) end diff --git a/src/task/compute.jl b/src/task/compute.jl index c1ed265..1f429e6 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -6,8 +6,7 @@ using StaticArrays Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead. """ function compute(t::FusedComputeTask, data...) - inter = compute(t.first_task) - return compute(t.second_task, inter, data2...) + @assert false end """ @@ -21,8 +20,8 @@ For ordinary compute or data tasks the vector will contain exactly one element, function get_function_call(t::FusedComputeTask, device::AbstractDevice, inSymbols::AbstractVector, outSymbol::Symbol) # sort out the symbols to the correct tasks return [ - get_function_call(t.first_task, device, t.t1_inputs, t.t1_output)..., - get_function_call(t.second_task, device, [t.t2_inputs..., t.t1_output], outSymbol)..., + get_function_call(t.first_func, device, t.t1_inputs, t.t1_output)..., + get_function_call(t.second_func, device, [t.t2_inputs..., t.t1_output], outSymbol)..., ] end diff --git a/src/task/create.jl b/src/task/create.jl index 147bfc1..7ba99b8 100644 --- a/src/task/create.jl +++ b/src/task/create.jl @@ -1,3 +1,5 @@ +Argument(id::Base.UUID) = Argument(id, 0) + """ copy(t::AbstractDataTask) @@ -10,7 +12,12 @@ copy(t::AbstractDataTask) = error("Need to implement copying for your data tasks Return a copy of the given compute task. """ -copy(t::AbstractComputeTask) = typeof(t)() +copy(t::AbstractComputeTask) = typeof(t)(copy(t.func), copy(t.arguments)) +copy(t::AbstractTaskFunction) = typeof(t)() + +function ComputeTask(t::AbstractTaskFunction) + return ComputeTask(t, Vector{Argument}()) +end """ copy(t::FusedComputeTask) @@ -18,15 +25,9 @@ copy(t::AbstractComputeTask) = typeof(t)() Return a copy of th egiven [`FusedComputeTask`](@ref). """ function copy(t::FusedComputeTask) - return FusedComputeTask(copy(t.first_task), copy(t.second_task), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs)) + return FusedComputeTask(copy(t.first_func), copy(t.second_func), copy(t.t1_inputs), t.t1_output, copy(t.t2_inputs)) end -function FusedComputeTask( - T1::Type{<:AbstractComputeTask}, - T2::Type{<:AbstractComputeTask}, - t1_inputs::Vector{String}, - t1_output::String, - t2_inputs::Vector{String}, -) - return FusedComputeTask(T1(), T2(), t1_inputs, t1_output, t2_inputs) +function FusedComputeTask(T1::Type{<:AbstractComputeTask}, T2::Type{<:AbstractComputeTask}, t1_output::String) + return FusedComputeTask(T1(), T2(), t1_output) end diff --git a/src/task/properties.jl b/src/task/properties.jl index c608843..2850faf 100644 --- a/src/task/properties.jl +++ b/src/task/properties.jl @@ -3,19 +3,21 @@ Fallback implementation of the compute function of a compute task, throwing an error. """ -function compute(t::AbstractTask, data...) +function compute(t::AbstractTaskFunction, data...) return error("Need to implement compute()") end +compute(t::ComputeTask, data...) = compute(t.func, data...) # TODO: is this correct? """ compute_effort(t::AbstractTask) Fallback implementation of the compute effort of a task, throwing an error. """ -function compute_effort(t::AbstractTask)::Float64 +function compute_effort(t::AbstractTaskFunction)::Float64 # default implementation using compute return error("Need to implement compute_effort()") end +compute_effort(t::AbstractComputeTask)::Float64 = compute_effort(t.func) """ data(t::AbstractTask) @@ -60,7 +62,7 @@ children(::DataTask) = 1 Return the number of children of a FusedComputeTask. """ function children(t::FusedComputeTask) - return length(union(Set(t.t1_inputs), Set(t.t2_inputs))) + return length(union(Set(t.first_func.arguments), Set(t.second_func.arguments))) - 1 end """ @@ -69,6 +71,7 @@ end Return the data of a compute task, always zero, regardless of the specific task. """ data(t::AbstractComputeTask)::Float64 = 0.0 +data(t::AbstractTaskFunction)::Float64 = 0.0 """ compute_effort(t::FusedComputeTask) @@ -76,7 +79,7 @@ data(t::AbstractComputeTask)::Float64 = 0.0 Return the compute effort of a fused compute task. """ function compute_effort(t::FusedComputeTask)::Float64 - return compute_effort(t.first_task) + compute_effort(t.second_task) + return compute_effort(t.first_func) + compute_effort(t.second_func) end """ @@ -84,4 +87,4 @@ end Return a tuple of a the fused compute task's components' types. """ -get_types(t::FusedComputeTask) = (typeof(t.first_task), typeof(t.second_task)) +get_types(t::FusedComputeTask) = (typeof(t.first_func), typeof(t.second_func)) diff --git a/src/task/type.jl b/src/task/type.jl index fb12b5b..f9f985f 100644 --- a/src/task/type.jl +++ b/src/task/type.jl @@ -1,3 +1,15 @@ +""" + Argument + +Representation of one of the arguments of a [`AbstractComputeTask`](@ref) in a Node. +Has a member `id::UUID`, which is the id of the compute task's parent the argument refers to, and an `index`, which is the index in the symbol's object. If `index` is 0, the whole object is the argument. +A compute task can have multiple arguments. +""" +struct Argument + id::Base.UUID + index::Int64 +end + """ AbstractTask @@ -6,11 +18,33 @@ The shared base type for any task. abstract type AbstractTask end """ - AbstractComputeTask <: AbstractTask + AbstractComputeTask -The shared base type for any compute task. +Base type for all compute tasks. + +!!! note + A subtype of this *must* be templated with its [`AbstractTaskFunction`](@ref)s! + +See also: [`ComputeTask`](@ref), [`FusedComputeTask`](@ref), [`VectorizedComputeTask`](@ref) """ -abstract type AbstractComputeTask <: AbstractTask end +abstract type AbstractComputeTask end + +""" + AbstractTaskFunction + +The base type for task functions used to dispatch computes. +""" +abstract type AbstractTaskFunction end + +""" + ComputeTask <: AbstractTask + +A compute task. Has a member [`AbstractTaskFunction`](@ref), and a list of [`Argument`](@ref)s. +""" +struct ComputeTask{TaskFunction <: AbstractTaskFunction} <: AbstractComputeTask + func::TaskFunction + arguments::Vector{Argument} +end """ AbstractDataTask <: AbstractTask @@ -29,19 +63,26 @@ struct DataTask <: AbstractDataTask end """ - FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <: AbstractComputeTask + FusedComputeTask <: AbstractComputeTask A fused compute task made up of the computation of first `T1` and then `T2`. Also see: [`get_types`](@ref). """ -struct FusedComputeTask <: AbstractComputeTask - first_task::AbstractComputeTask - second_task::AbstractComputeTask - # the names of the inputs for T1 - t1_inputs::Vector{Symbol} - # output name of T1 +struct FusedComputeTask{TaskFunction1 <: AbstractTaskFunction, TaskFunction2 <: AbstractTaskFunction} <: + AbstractComputeTask + first_func::TaskFunction1 + second_func::TaskFunction2 t1_output::Symbol - # t2_inputs doesn't include the output of t1, that's implicit - t2_inputs::Vector{Symbol} + first_func_arguments::Vector{Argument} + second_func_arguments::Vector{Argument} +end + +""" + + +""" +struct VectorizedComputeTask{TaskFunction <: AbstractTaskFunction} <: AbstractComputeTask + task_type::TaskFunction + arguments::Vector{Argument} end diff --git a/src/trie.jl b/src/trie.jl index b3babca..1ac50bc 100644 --- a/src/trie.jl +++ b/src/trie.jl @@ -48,7 +48,7 @@ function insert_helper!( trie::NodeIdTrie{NodeType}, node::NodeType, depth::Int, -) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}} +) where {CompTask <: AbstractComputeTask, NodeType <: Union{DataTaskNode{AbstractDataTask}, ComputeTaskNode{CompTask}}} if (length(children(node)) == depth) push!(trie.value, node) return nothing @@ -71,7 +71,7 @@ Insert the given node into the trie. It's sorted by its type in the first layer, function insert!( trie::NodeTrie, node::NodeType, -) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}} +) where {CompTask <: AbstractComputeTask, NodeType <: Union{DataTaskNode{AbstractDataTask}, ComputeTaskNode{CompTask}}} if (!haskey(trie.children, NodeType)) trie.children[NodeType] = NodeIdTrie{NodeType}() end diff --git a/test/unit_tests_graph.jl b/test/unit_tests_graph.jl index c6e59ff..23dbbe4 100644 --- a/test/unit_tests_graph.jl +++ b/test/unit_tests_graph.jl @@ -5,6 +5,8 @@ import MetagraphOptimization.insert_edge! import MetagraphOptimization.make_node import MetagraphOptimization.siblings import MetagraphOptimization.partners +import MetagraphOptimization.is_dependent +import MetagraphOptimization.pairwise_independent graph = MetagraphOptimization.DAG() @@ -108,6 +110,20 @@ insert_edge!(graph, s0, d_exit, track = false) @test length(graph.dirtyNodes) == 26 @test length(graph.diff) == (addedNodes = 0, removedNodes = 0, addedEdges = 0, removedEdges = 0) +@test !is_dependent(graph, d_exit, d_exit) +@test is_dependent(graph, d_v0_s0, v0) +@test is_dependent(graph, d_v1_s0, v1) +@test !is_dependent(graph, d_v0_s0, v1) +@test is_dependent(graph, s0, d_PB) +@test !is_dependent(graph, v0, uBp) +@test !is_dependent(graph, PB, PA) + +@test pairwise_independent(graph, [PB, PA, PBp, PAp]) +@test pairwise_independent(graph, [uB, uA, d_uBp_v1]) +@test !pairwise_independent(graph, [PB, PA, PBp, PAp, s0]) +@test !pairwise_independent(graph, [d_uB_v0, v0]) +@test !pairwise_independent(graph, [v0, d_v0_s0, s0, uB]) + @test is_valid(graph) @test is_entry_node(d_PB)