Fun with type stability

This commit is contained in:
Anton Reinhard 2023-11-21 01:48:59 +01:00
parent 9d947a49ce
commit 705bfb30fe
25 changed files with 164 additions and 133 deletions

View File

@ -31,6 +31,7 @@ export children
export compute
export data
export compute_effort
export task
export get_properties
export get_exit_node
export is_valid, is_scheduled

View File

@ -63,15 +63,15 @@ end
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeReduction)
s = length(operation.input) - 1
return (
data = s * -data(operation.input[1].task),
computeEffort = s * -compute_effort(operation.input[1].task),
data = s * -data(task(operation.input[1])),
computeEffort = s * -compute_effort(task(operation.input[1])),
computeIntensity = typeof(operation.input) <: DataTaskNode ? 0.0 : Inf,
)::CDCost
end
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeSplit)
s = length(operation.input.parents) - 1
d = s * data(operation.input.task)
ce = s * compute_effort(operation.input.task)
s::Float64 = length(parents(operation.input)) - 1
d::Float64 = s * data(task(operation.input))
ce::Float64 = s * compute_effort(task(operation.input))
return (data = d, computeEffort = ce, computeIntensity = ce / d)::CDCost
end

View File

@ -17,7 +17,7 @@ function in(edge::Edge, graph::DAG)
return false
end
return n1 in n2.children
return n1 in children(n2)
end
"""

View File

@ -46,7 +46,7 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
"""
function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
@assert (node2 node1.parents) && (node1 node2.children) "Edge to insert already exists"
@assert (node2 parents(node1)) && (node1 children(node2)) "Edge to insert already exists"
# 1: mute
# edge points from child to parent
@ -133,7 +133,7 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
end "removed more than one node from node1's parents"
@assert begin
removed = pre_length2 - length(node2.children)
removed = pre_length2 - length(children(node2))
removed <= 1
end "removed more than one node from node2's children"
@ -185,28 +185,28 @@ end
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
# only need to update fused compute tasks
if !(typeof(n.task) <: FusedComputeTask)
if !(typeof(task(n)) <: FusedComputeTask)
return nothing
end
taskBefore = copy(n.task)
taskBefore = copy(task(n))
if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
println("------------------ Nothing to replace!! ------------------")
child_ids = Vector{String}()
for child in n.children
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end
replace_children!(n.task, child_before, child_after)
replace_children!(task(n), child_before, child_after)
if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
println("------------------ Did not replace anything!! ------------------")
child_ids = Vector{String}()
for child in n.children
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")

View File

@ -30,10 +30,10 @@ function show(io::IO, graph::DAG)
nodeDict = Dict{Type, Int64}()
noEdges = 0
for node in graph.nodes
if haskey(nodeDict, typeof(node.task))
nodeDict[typeof(node.task)] = nodeDict[typeof(node.task)] + 1
if haskey(nodeDict, typeof(task(node)))
nodeDict[typeof(task(node))] = nodeDict[typeof(task(node))] + 1
else
nodeDict[typeof(node.task)] = 1
nodeDict[typeof(task(node))] = 1
end
noEdges += length(parents(node))
end

View File

@ -181,7 +181,7 @@ function parse_dag(filename::AbstractString, model::ABCModel, verbose::Bool = fa
insert_edge!(graph, compute_S2, data_out, track = false, invalidate_cache = false)
insert_edge!(graph, data_out, sum_node, track = false, invalidate_cache = false)
add_child!(sum_node.task)
add_child!(task(sum_node))
elseif occursin(regex_plus, node)
if (verbose)
println("\rReading Nodes Complete ")

View File

@ -13,8 +13,8 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
)
copy(m::Missing) = missing
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task))
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), n.name)
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(task(n)))
copy(n::DataTaskNode) = DataTaskNode(copy(task(n)), n.name)
"""
make_node(t::AbstractTask)

View File

@ -4,7 +4,7 @@
Print a short string representation of the node to io.
"""
function show(io::IO, n::Node)
return print(io, "Node(", n.task, ")")
return print(io, "Node(", task(n), ")")
end
"""

View File

@ -3,14 +3,26 @@
Return whether this node is an entry node in its graph, i.e., it has no children.
"""
is_entry_node(node::Node) = length(node.children) == 0
is_entry_node(node::Node) = length(children(node)) == 0
"""
is_exit_node(node::Node)
Return whether this node is an exit node of its graph, i.e., it has no parents.
"""
is_exit_node(node::Node) = length(node.parents) == 0
is_exit_node(node::Node)::Bool = length(parents(node)) == 0
"""
task(node::Node)
Return the node's task.
"""
function task(node::DataTaskNode{TaskType})::TaskType where {TaskType <: AbstractDataTask}
return node.task
end
function task(node::ComputeTaskNode{TaskType})::TaskType where {TaskType <: AbstractComputeTask}
return node.task
end
"""
children(node::Node)
@ -19,8 +31,11 @@ Return a copy of the node's children so it can safely be muted without changing
A node's children are its prerequisite nodes, nodes that need to execute before the task of this node.
"""
function children(node::Node)
return copy(node.children)
function children(node::DataTaskNode)::Vector{ComputeTaskNode}
return node.children
end
function children(node::ComputeTaskNode)::Vector{DataTaskNode}
return node.children
end
"""
@ -30,8 +45,11 @@ Return a copy of the node's parents so it can safely be muted without changing t
A node's parents are its subsequent nodes, nodes that need this node to execute.
"""
function parents(node::Node)
return copy(node.parents)
function parents(node::DataTaskNode)::Vector{ComputeTaskNode}
return node.parents
end
function parents(node::ComputeTaskNode)::Vector{DataTaskNode}
return node.parents
end
"""
@ -41,11 +59,11 @@ Return a vector of all siblings of this node.
A node's siblings are all children of any of its parents. The result contains no duplicates and includes the node itself.
"""
function siblings(node::Node)
function siblings(node::Node)::Set{Node}
result = Set{Node}()
push!(result, node)
for parent in node.parents
union!(result, parent.children)
for parent in parents(node)
union!(result, children(parent))
end
return result
@ -61,11 +79,11 @@ A node's partners are all parents of any of its children. The result contains no
Note: This is very slow when there are multiple children with many parents.
This is less of a problem in [`siblings(node::Node)`](@ref) because (depending on the model) there are no nodes with a large number of children, or only a single one.
"""
function partners(node::Node)
function partners(node::Node)::Set{Node}
result = Set{Node}()
push!(result, node)
for child in node.children
union!(result, child.parents)
for child in children(node)
union!(result, parents(child))
end
return result
@ -78,8 +96,8 @@ Alternative version to [`partners(node::Node)`](@ref), avoiding allocation of a
"""
function partners(node::Node, set::Set{Node})
push!(set, node)
for child in node.children
union!(set, child.parents)
for child in children(node)
union!(set, parents(child))
end
return nothing
end
@ -89,8 +107,8 @@ end
Return whether the `potential_parent` is a parent of `node`.
"""
function is_parent(potential_parent::Node, node::Node)
return potential_parent in node.parents
function is_parent(potential_parent::Node, node::Node)::Bool
return potential_parent in parents(node)
end
"""
@ -98,6 +116,6 @@ end
Return whether the `potential_child` is a child of `node`.
"""
function is_child(potential_child::Node, node::Node)
return potential_child in node.children
function is_child(potential_child::Node, node::Node)::Bool
return potential_child in children(node)
end

View File

@ -33,8 +33,8 @@ Any node that transfers data and does no computation.
`.nodeFusion`: Either this node's [`NodeFusion`](@ref) or `missing`, if none. There can only be at most one for DataTaskNodes.\\
`.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\
"""
mutable struct DataTaskNode <: Node
task::AbstractDataTask
mutable struct DataTaskNode{TaskType <: AbstractDataTask} <: Node
task::TaskType
# use vectors as sets have way too much memory overhead
parents::Vector{Node}
@ -73,8 +73,8 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re
`.nodeFusions`: A vector of this node's [`NodeFusion`](@ref)s. For a `ComputeTaskNode` there can be any number of these, unlike the [`DataTaskNode`](@ref)s.\\
`.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref).
"""
mutable struct ComputeTaskNode <: Node
task::AbstractComputeTask
mutable struct ComputeTaskNode{TaskType <: AbstractComputeTask} <: Node
task::TaskType
parents::Vector{Node}
children::Vector{Node}
id::Base.UUID
@ -83,7 +83,7 @@ mutable struct ComputeTaskNode <: Node
nodeSplit::Union{Operation, Missing}
# for ComputeTasks there can be multiple fusions, unlike the DataTasks
nodeFusions::Vector{Operation}
nodeFusions::Vector{<:Operation}
# the device this node is assigned to execute on
device::Union{AbstractDevice, Missing}

View File

@ -29,7 +29,7 @@ function is_valid_node(graph::DAG, node::Node)
@assert is_valid(graph, node.nodeSplit)
end=#
if !(typeof(node.task) <: FusedComputeTask)
if !(typeof(task(node)) <: FusedComputeTask)
# the remaining checks are only necessary for fused compute tasks
return true
end
@ -37,7 +37,7 @@ function is_valid_node(graph::DAG, node::Node)
# every child must be in some input of the task
for child in node.children
str = Symbol(to_var_name(child.id))
@assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
@assert (str in task(node).t1_inputs) || (str in task(node).t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(task(node).t1_inputs)\nt2_inputs: $(task(node).t2_inputs)"
end
return true

View File

@ -132,11 +132,11 @@ function revert_diff!(graph::DAG, diff::Diff)
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for (node, task) in diff.updatedChildren
for (node, t) in diff.updatedChildren
# node must be fused compute task at this point
@assert typeof(node.task) <: FusedComputeTask
@assert typeof(task(node)) <: FusedComputeTask
node.task = task
node.task = t
end
graph.properties -= GraphProperties(diff)
@ -158,11 +158,11 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
get_snapshot_diff(graph)
# save children and parents
n1Children = children(n1)
n3Parents = parents(n3)
n1Children = copy(children(n1))
n3Parents = copy(parents(n3))
n1Task = copy(n1.task)
n3Task = copy(n3.task)
n1Task = copy(task(n1))
n3Task = copy(task(n3))
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
n1Inputs = Vector{Symbol}()
@ -177,7 +177,7 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
remove_node!(graph, n2)
# get n3's children now so it automatically excludes n2
n3Children = children(n3)
n3Children = copy(children(n3))
n3Inputs = Vector{Symbol}()
for child in n3Children
@ -228,7 +228,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
get_snapshot_diff(graph)
n1 = nodes[1]
n1Children = children(n1)
n1Children = copy(children(n1))
n1Parents = Set(n1.parents)
@ -245,7 +245,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
remove_edge!(graph, child, n)
end
for parent in parents(n)
for parent in copy(parents(n))
remove_edge!(graph, n, parent)
# collect all parents
@ -278,14 +278,17 @@ Split the given node into one node per parent, return the applied difference to
For details see [`NodeSplit`](@ref).
"""
function node_split!(graph::DAG, n1::Node)
function node_split!(
graph::DAG,
n1::Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}},
) where {TaskType <: AbstractTask}
@assert is_valid_node_split_input(graph, n1)
# clear snapshot
get_snapshot_diff(graph)
n1Parents = parents(n1)
n1Children = children(n1)
n1Parents = copy(parents(n1))
n1Children = copy(children(n1))
for parent in n1Parents
remove_edge!(graph, n1, parent)

View File

@ -13,18 +13,18 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
return nothing
end
if length(node.parents) != 1 || length(node.children) != 1
if length(parents(node)) != 1 || length(children(node)) != 1
return nothing
end
child_node = first(node.children)
parent_node = first(node.parents)
child_node = first(children(node))
parent_node = first(parents(node))
if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end
if length(child_node.parents) != 1
if length(parents(child_node)) != 1
return nothing
end
@ -44,11 +44,11 @@ Find node fusions involving the given compute node. The function pushes the foun
"""
function find_fusions!(graph::DAG, node::ComputeTaskNode)
# just find fusions in neighbouring DataTaskNodes
for child in node.children
for child in children(node)
find_fusions!(graph, child)
end
for parent in node.parents
for parent in parents(node)
find_fusions!(graph, parent)
end
@ -123,7 +123,10 @@ end
Sort this node's parent and child sets, then find fusions, reductions and splits involving it. Needs to be called after the node was changed in some way.
"""
function clean_node!(graph::DAG, node::Node)
function clean_node!(
graph::DAG,
node::Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}},
) where {TaskType <: AbstractTask}
sort_node!(node)
find_fusions!(graph, node)

View File

@ -203,18 +203,18 @@ function generate_operations(graph::DAG)
# --- find possible node fusions ---
@threads for node in nodeArray
if (typeof(node) <: DataTaskNode)
if length(node.parents) != 1
if length(parents(node)) != 1
# data node can only have a single parent
continue
end
parent_node = first(node.parents)
parent_node = first(parents(node))
if length(node.children) != 1
if length(children(node)) != 1
# this node is an entry node or has multiple children which should not be possible
continue
end
child_node = first(node.children)
if (length(child_node.parents) != 1)
child_node = first(children(node))
if (length(parents(child_node)) != 1)
continue
end

View File

@ -14,9 +14,7 @@ function get_operations(graph::DAG)
generate_operations(graph)
end
for node in graph.dirtyNodes
clean_node!(graph, node)
end
clean_node!.(Ref(graph), graph.dirtyNodes)
empty!(graph.dirtyNodes)
return graph.possibleOperations

View File

@ -30,7 +30,7 @@ function show(io::IO, op::NodeReduction)
print(io, "NR: ")
print(io, length(op.input))
print(io, "x")
return print(io, op.input[1].task)
return print(io, task(op.input[1]))
end
"""
@ -40,7 +40,7 @@ Print a string representation of the node split to io.
"""
function show(io::IO, op::NodeSplit)
print(io, "NS: ")
return print(io, op.input.task)
return print(io, task(op.input))
end
"""
@ -50,9 +50,9 @@ Print a string representation of the node fusion to io.
"""
function show(io::IO, op::NodeFusion)
print(io, "NF: ")
print(io, op.input[1].task)
print(io, task(op.input[1]))
print(io, "->")
print(io, op.input[2].task)
print(io, task(op.input[2]))
print(io, "->")
return print(io, op.input[3].task)
return print(io, task(op.input[3]))
end

View File

@ -40,8 +40,9 @@ A chain of (n1, n2, n3) can be fused if:
See also: [`can_fuse`](@ref)
"""
struct NodeFusion <: Operation
input::Tuple{ComputeTaskNode, DataTaskNode, ComputeTaskNode}
struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
Operation
input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
end
"""
@ -49,8 +50,12 @@ end
The applied version of the [`NodeFusion`](@ref).
"""
struct AppliedNodeFusion <: AppliedOperation
operation::NodeFusion
struct AppliedNodeFusion{
TaskType1 <: AbstractComputeTask,
TaskType2 <: AbstractDataTask,
TaskType3 <: AbstractComputeTask,
} <: AppliedOperation
operation::NodeFusion{TaskType1, TaskType2, TaskType3}
diff::Diff
end
@ -73,8 +78,8 @@ A vector of nodes can be reduced if:
See also: [`can_reduce`](@ref)
"""
struct NodeReduction <: Operation
input::Vector{Node}
struct NodeReduction{NodeType <: Node} <: Operation
input::Vector{NodeType}
end
"""
@ -82,8 +87,8 @@ end
The applied version of the [`NodeReduction`](@ref).
"""
struct AppliedNodeReduction <: AppliedOperation
operation::NodeReduction
struct AppliedNodeReduction{NodeType <: Node} <: AppliedOperation
operation::NodeReduction{NodeType}
diff::Diff
end
@ -102,8 +107,8 @@ A node can be split if:
See also: [`can_split`](@ref)
"""
struct NodeSplit <: Operation
input::Node
struct NodeSplit{NodeType <: Node} <: Operation
input::NodeType
end
"""
@ -111,7 +116,7 @@ end
The applied version of the [`NodeSplit`](@ref).
"""
struct AppliedNodeSplit <: AppliedOperation
operation::NodeSplit
struct AppliedNodeSplit{NodeType <: Node} <: AppliedOperation
operation::NodeSplit{NodeType}
diff::Diff
end

View File

@ -61,7 +61,7 @@ function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
return false
end
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
if length(parents(n2)) != 1 || length(children(n2)) != 1 || length(parents(n1)) != 1
return false
end
@ -74,12 +74,15 @@ end
Return whether the given two nodes can be reduced. See [`NodeReduction`](@ref) for the requirements.
"""
function can_reduce(n1::Node, n2::Node)
if (n1.task != n2.task)
return false
end
return false
end
n1_length = length(n1.children)
n2_length = length(n2.children)
function can_reduce(
n1::NodeType,
n2::NodeType,
) where {TaskType <: AbstractTask, NodeType <: Union{DataTaskNode{TaskType}, ComputeTaskNode{TaskType}}}
n1_length = length(children(n1))
n2_length = length(children(n2))
if (n1_length != n2_length)
return false
@ -88,19 +91,19 @@ function can_reduce(n1::Node, n2::Node)
# this seems to be the most common case so do this first
# doing it manually is a lot faster than using the sets for a general solution
if (n1_length == 2)
if (n1.children[1] != n2.children[1])
if (n1.children[1] != n2.children[2])
if (children(n1)[1] != children(n2)[1])
if (children(n1)[1] != children(n2)[2])
return false
end
# 1_1 == 2_2
if (n1.children[2] != n2.children[1])
if (children(n1)[2] != children(n2)[1])
return false
end
return true
end
# 1_1 == 2_1
if (n1.children[2] != n2.children[2])
if (children(n1)[2] != children(n2)[2])
return false
end
return true
@ -108,11 +111,11 @@ function can_reduce(n1::Node, n2::Node)
# this is simple
if (n1_length == 1)
return n1.children[1] == n2.children[1]
return children(n1)[1] == children(n2)[1]
end
# this takes a long time
return Set(n1.children) == Set(n2.children)
return Set(children(n1)) == Set(children(n2))
end
"""

View File

@ -54,9 +54,9 @@ function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
@assert is_valid(graph, n)
end
t = typeof(nodes[1].task)
t = typeof(task(nodes[1]))
for n in nodes
if typeof(n.task) != t
if typeof(task(n)) != t
throw(AssertionError("[Node Reduction] The given nodes are not of the same type"))
end

View File

@ -20,8 +20,8 @@ function GraphProperties(graph::DAG)
ce = 0.0
ed = 0
for node in graph.nodes
d += data(node.task) * length(node.parents)
ce += compute_effort(node.task)
d += data(task(node)) * length(node.parents)
ce += compute_effort(task(node))
ed += length(node.parents)
end
@ -43,12 +43,12 @@ For reverting a diff, it's `get_properties(graph) - GraphProperties(diff)`.
"""
function GraphProperties(diff::Diff)
ce =
reduce(+, compute_effort(n.task) for n in diff.addedNodes; init = 0.0) -
reduce(+, compute_effort(n.task) for n in diff.removedNodes; init = 0.0)
reduce(+, compute_effort(task(n)) for n in diff.addedNodes; init = 0.0) -
reduce(+, compute_effort(task(n)) for n in diff.removedNodes; init = 0.0)
d =
reduce(+, data(n.task) for n in diff.addedNodes; init = 0.0) -
reduce(+, data(n.task) for n in diff.removedNodes; init = 0.0)
reduce(+, data(task(n)) for n in diff.addedNodes; init = 0.0) -
reduce(+, data(task(n)) for n in diff.removedNodes; init = 0.0)
return (
data = d,

View File

@ -32,14 +32,14 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
if (isa(node, ComputeTaskNode))
lowestDevice = peek(deviceAccCost)[1]
node.device = lowestDevice
deviceAccCost[lowestDevice] = compute_effort(node.task)
deviceAccCost[lowestDevice] = compute_effort(task(node))
end
push!(schedule, node)
for parent in node.parents
for parent in parents(node)
# reduce the priority of all parents by one
if (!haskey(nodeQueue, parent))
enqueue!(nodeQueue, parent => length(parent.children) - 1)
enqueue!(nodeQueue, parent => length(children(parent)) - 1)
else
nodeQueue[parent] = nodeQueue[parent] - 1
end

View File

@ -41,16 +41,16 @@ end
Generate and return code for a given [`ComputeTaskNode`](@ref).
"""
function get_expression(node::ComputeTaskNode)
@assert length(node.children) <= children(node.task) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
@assert length(children(node)) <= children(task(node)) "Node $(node) has too many children for its task: node has $(length(node.children)) versus task has $(children(task(node)))\nNode's children: $(getfield.(node.children, :children))"
@assert !ismissing(node.device) "Trying to get expression for an unscheduled ComputeTaskNode\nNode: $(node)"
inExprs = Vector()
for id in getfield.(node.children, :id)
for id in getfield.(children(node), :id)
push!(inExprs, gen_access_expr(node.device, Symbol(to_var_name(id))))
end
outExpr = gen_access_expr(node.device, Symbol(to_var_name(node.id)))
return get_expression(node.task, node.device, inExprs, outExpr)
return get_expression(task(node), node.device, inExprs, outExpr)
end
"""
@ -59,11 +59,11 @@ end
Generate and return code for a given [`DataTaskNode`](@ref).
"""
function get_expression(node::DataTaskNode)
@assert length(node.children) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
@assert length(children(node)) == 1 "Trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1"
# TODO: dispatch to device implementations generating the copy commands
child = node.children[1]
child = children(node)[1]
inExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(child.id))))
outExpr = eval(gen_access_expr(child.device, Symbol(to_var_name(node.id))))
dataTransportExp = Meta.parse("$outExpr = $inExpr")
@ -79,7 +79,7 @@ Generate and return code for the initial input reading expression for [`DataTask
See also: [`get_entry_nodes`](@ref)
"""
function get_init_expression(node::DataTaskNode, device::AbstractDevice)
@assert isempty(node.children) "Trying to call get_init_expression on a data task node that is not an entry node."
@assert isempty(children(node)) "Trying to call get_init_expression on a data task node that is not an entry node."
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))

View File

@ -30,7 +30,7 @@ compute(t::AbstractDataTask; data...) = data
Fallback implementation of the compute effort of a task, throwing an error.
"""
function compute_effort(t::AbstractTask)
function compute_effort(t::AbstractTask)::Float64
# default implementation using compute
return error("Need to implement compute_effort()")
end
@ -40,7 +40,7 @@ end
Fallback implementation of the data of a task, throwing an error.
"""
function data(t::AbstractTask)
function data(t::AbstractTask)::Float64
return error("Need to implement data()")
end
@ -49,28 +49,28 @@ end
Return the compute effort of a data task, always zero, regardless of the specific task.
"""
compute_effort(t::AbstractDataTask) = 0.0
compute_effort(t::AbstractDataTask)::Float64 = 0.0
"""
data(t::AbstractDataTask)
Return the data of a data task. Given by the task's `.data` field.
"""
data(t::AbstractDataTask) = getfield(t, :data)
data(t::AbstractDataTask)::Float64 = getfield(t, :data)
"""
data(t::AbstractComputeTask)
Return the data of a compute task, always zero, regardless of the specific task.
"""
data(t::AbstractComputeTask) = 0.0
data(t::AbstractComputeTask)::Float64 = 0.0
"""
compute_effort(t::FusedComputeTask)
Return the compute effort of a fused compute task.
"""
function compute_effort(t::FusedComputeTask)
function compute_effort(t::FusedComputeTask)::Float64
return compute_effort(t.first_task) + compute_effort(t.second_task)
end

View File

@ -45,13 +45,13 @@ end
Insert the given node into the trie. The depth is used to iterate through the trie layers, while the function calls itself recursively until it ran through all children of the node.
"""
function insert_helper!(trie::NodeIdTrie, node::Node, depth::Int)
if (length(node.children) == depth)
if (length(children(node)) == depth)
push!(trie.value, node)
return nothing
end
depth = depth + 1
id = node.children[depth].id
id = children(node)[depth].id
if (!haskey(trie.children, id))
trie.children[id] = NodeIdTrie()
@ -65,11 +65,11 @@ end
Insert the given node into the trie. It's sorted by its type in the first layer, then by its children in the following layers.
"""
function insert!(trie::NodeTrie, node::Node)
t = typeof(node.task)
t = typeof(task(node))
if (!haskey(trie.children, t))
trie.children[t] = NodeIdTrie()
end
return insert_helper!(trie.children[typeof(node.task)], node, 0)
return insert_helper!(trie.children[typeof(task(node))], node, 0)
end
"""

View File

@ -36,8 +36,8 @@ Sort the nodes' parents and children vectors. The vectors are mostly very short
Sorted nodes are required to make the finding of [`NodeReduction`](@ref)s a lot faster using the [`NodeTrie`](@ref) data structure.
"""
function sort_node!(node::Node)
sort!(node.children, lt = lt_nodes)
return sort!(node.parents, lt = lt_nodes)
sort!(children(node), lt = lt_nodes)
return sort!(parents(node), lt = lt_nodes)
end
"""