Make FusedComputeTasks usable in execution

This commit is contained in:
2023-09-25 08:39:59 +02:00
committed by Anton Reinhard
parent f8a591991c
commit c428613c80
16 changed files with 548 additions and 283 deletions

View File

@@ -4,8 +4,14 @@
A named tuple representing a difference of added and removed nodes and edges on a [`DAG`](@ref).
"""
const Diff = NamedTuple{
(:addedNodes, :removedNodes, :addedEdges, :removedEdges),
Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}},
(:addedNodes, :removedNodes, :addedEdges, :removedEdges, :updatedChildren),
Tuple{
Vector{Node},
Vector{Node},
Vector{Edge},
Vector{Edge},
Vector{Tuple{Node, String, String}},
},
}
function Diff()
@@ -14,5 +20,8 @@ function Diff()
removedNodes = Vector{Node}(),
addedEdges = Vector{Edge}(),
removedEdges = Vector{Edge}(),
# children were updated from updatedChildren[2] to updatedChildren[3] in node updatedChildren[1]
updatedChildren = Vector{Tuple{Node, String, String}}(),
)::Diff
end

View File

@@ -17,7 +17,7 @@ See also: [`remove_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
function insert_node!(
graph::DAG,
node::Node,
node::Node;
track = true,
invalidate_cache = true,
)
@@ -53,7 +53,7 @@ See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
function insert_edge!(
graph::DAG,
node1::Node,
node2::Node,
node2::Node;
track = true,
invalidate_cache = true,
)
@@ -97,7 +97,7 @@ See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
function remove_node!(
graph::DAG,
node::Node,
node::Node;
track = true,
invalidate_cache = true,
)
@@ -137,7 +137,7 @@ See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`insert_edge!`](@ref)
function remove_edge!(
graph::DAG,
node1::Node,
node2::Node,
node2::Node;
track = true,
invalidate_cache = true,
)
@@ -181,6 +181,27 @@ function remove_edge!(
return nothing
end
function update_child!(
graph::DAG,
n::Node,
child_before::String,
child_after::String;
track = true,
)
# only need to update fused compute tasks
if !(typeof(n.task) <: FusedComputeTask)
return nothing
end
replace!(n.task.t1_inputs, child_before => child_after)
replace!(n.task.t2_inputs, child_before => child_after)
# keep track
if (track)
push!(graph.diff.updatedChildren, (n, child_before, child_after))
end
end
"""
get_snapshot_diff(graph::DAG)

View File

@@ -80,125 +80,113 @@ function compute(t::FusedComputeTask, data)
end
"""
get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskP, inExprs::Vector{String}, outExpr::String)
Generate and return code evaluating [`ComputeTaskP`](@ref) on `inSymbol`, providing the output on `outSymbol`.
Generate and return code evaluating [`ComputeTaskP`](@ref) on `inExpr`, providing the output on `outExpr`.
"""
function get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskP(), $inSymbol)")
function get_expression(
::ComputeTaskP,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse("$outExpr = compute(ComputeTaskP(), $(inExprs[1]))")
end
"""
get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskU, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskU`](@ref) on `inSymbol`, providing the output on `outSymbol`.
`inSymbol` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
Generate code evaluating [`ComputeTaskU`](@ref) on `inExpr`, providing the output on `outExpr`.
`inExpr` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
"""
function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskU(), $inSymbol)")
function get_expression(
::ComputeTaskU,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse("$outExpr = compute(ComputeTaskU(), $(inExprs[1]))")
end
"""
get_expression(::ComputeTaskV, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskV, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskV`](@ref) on `inSymbol1` and `inSymbol2`, providing the output on `outSymbol`.
`inSymbol1` and `inSymbol2` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
Generate code evaluating [`ComputeTaskV`](@ref) on `inExpr1` and `inExpr2`, providing the output on `outExpr`.
`inExpr1` and `inExpr2` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
"""
function get_expression(
::ComputeTaskV,
inSymbol1::Symbol,
inSymbol2::Symbol,
outSymbol::Symbol,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse(
"$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)",
"$outExpr = compute(ComputeTaskV(), $(inExprs[1]), $(inExprs[2]))",
)
end
"""
get_expression(::ComputeTaskS2, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskS2, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskS2`](@ref) on `inSymbol1` and `inSymbol2`, providing the output on `outSymbol`.
`inSymbol1` and `inSymbol2` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type `Float64`.
Generate code evaluating [`ComputeTaskS2`](@ref) on `inExpr1` and `inExpr2`, providing the output on `outExpr`.
`inExpr1` and `inExpr2` should be of type [`ParticleValue`](@ref), `outExpr` will be of type `Float64`.
"""
function get_expression(
::ComputeTaskS2,
inSymbol1::Symbol,
inSymbol2::Symbol,
outSymbol::Symbol,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse(
"$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)",
"$outExpr = compute(ComputeTaskS2(), $(inExprs[1]), $(inExprs[2]))",
)
end
"""
get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskS1, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskS1`](@ref) on `inSymbol`, providing the output on `outSymbol`.
`inSymbol` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
Generate code evaluating [`ComputeTaskS1`](@ref) on `inExpr`, providing the output on `outExpr`.
`inExpr` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
"""
function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskS1(), $inSymbol)")
function get_expression(
::ComputeTaskS1,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse("$outExpr = compute(ComputeTaskS1(), $(inExprs[1]))")
end
"""
get_expression(::ComputeTaskSum, inSymbols::Vector{Symbol}, outSymbol::Symbol)
get_expression(::ComputeTaskSum, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskSum`](@ref) on `inSymbols`, providing the output on `outSymbol`.
`inSymbols` should be of type [`Float64`], `outSymbol` will be of type [`Float64`].
Generate code evaluating [`ComputeTaskSum`](@ref) on `inExprs`, providing the output on `outExpr`.
`inExprs` should be of type [`Float64`], `outExpr` will be of type [`Float64`].
"""
function get_expression(
::ComputeTaskSum,
inSymbols::Vector{Symbol},
outSymbol::Symbol,
inExprs::Vector{String},
outExpr::String,
)
return quote
$outSymbol = compute(ComputeTaskSum(), [$(inSymbols...)])
end
return Meta.parse(
"$outExpr = compute(ComputeTaskSum(), [$(unroll_string_vector(inExprs))])",
)
end
"""
get_expression(t::FusedComputeTask, inSymbols::Vector{Symbol}, outSymbol::Symbol)
get_expression(t::FusedComputeTask, inExprs::Vector{String}, outExpr::String)
Generate code evaluating a [`FusedComputeTask`](@ref) on `inSymbols`, providing the output on `outSymbol`.
`inSymbols` should be of the correct types and may be heterogeneous. `outSymbol` will be of the type of the output of `T2` of t.
Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing the output on `outExpr`.
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
"""
function get_expression(
t::FusedComputeTask,
inSymbols::Vector{Symbol},
outSymbol::Symbol,
inExprs::Vector{String},
outExpr::String,
)
(T1, T2) = get_types(t)
c1 = children(T1())
c2 = children(T2())
c1 = length(t.t1_inputs)
c2 = length(t.t2_inputs) + 1
expr1 = nothing
expr2 = nothing
# TODO need to figure out how to know which inputs belong to which subtask
# since we order the vectors with the child nodes we can't just split
if (c1 == 1)
expr1 = get_expression(T1(), inSymbols[begin], :intermediate)
elseif (c1 == 2)
expr1 =
get_expression(T1(), inSymbols[begin], inSymbols[2], :intermediate)
else
expr1 = get_expression(T1(), inSymbols[begin:c1], :intermediate)
end
if (c2 == 1)
expr2 = get_expression(T2(), :intermediate, outSymbol)
elseif c2 == 2
expr2 =
get_expression(T2(), :intermediate, inSymbols[c1 + 1], outSymbol)
else
expr2 = get_expression(
T2(),
:intermediate * inSymbols[(c1 + 1):end],
outSymbol,
)
end
expr1 = get_expression(t.first_task, t.t1_inputs, t.t1_output)
expr2 =
get_expression(t.second_task, [t.t2_inputs..., t.t1_output], outExpr)
return Expr(:block, expr1, expr2)
end
@@ -210,24 +198,27 @@ Generate and return code for a given [`ComputeTaskNode`](@ref).
"""
function get_expression(node::ComputeTaskNode)
t = typeof(node.task)
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
symbolIn = Symbol("data_$(to_var_name(node.children[1].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn, symbolOut)
symbolIn = "data_$(to_var_name(node.children[1].id))"
symbolOut = "data_$(to_var_name(node.id))"
return get_expression(node.task, [symbolIn], symbolOut)
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))")
symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
elseif (t <: ComputeTaskSum || t <: FusedComputeTask) # vector input
inSymbols = Vector{Symbol}()
symbolIn1 = "data_$(to_var_name(node.children[1].id))"
symbolIn2 = "data_$(to_var_name(node.children[2].id))"
symbolOut = "data_$(to_var_name(node.id))"
return get_expression(node.task, [symbolIn1, symbolIn2], symbolOut)
elseif (t <: ComputeTaskSum) # vector input
inExprs = Vector{String}()
for child in node.children
push!(inSymbols, Symbol("data_$(to_var_name(child.id))"))
push!(inExprs, "data_$(to_var_name(child.id))")
end
outSymbol = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), inSymbols, outSymbol)
outExpr = "data_$(to_var_name(node.id))"
return get_expression(node.task, inExprs, outExpr)
elseif t <: FusedComputeTask # fused compute task knows its inputs
outExpr = "data_$(to_var_name(node.id))"
return get_expression(node.task, Vector{String}(), outExpr)
else
error("Unknown compute task")
end
@@ -242,15 +233,15 @@ function get_expression(node::DataTaskNode)
# TODO: do things to transport data from/to gpu, between numa nodes, etc.
@assert length(node.children) <= 1
inSymbol = nothing
inExpr = nothing
if (length(node.children) == 1)
inSymbol = Symbol("data_$(to_var_name(node.children[1].id))")
inExpr = "data_$(to_var_name(node.children[1].id))"
else
inSymbol = Symbol("data_$(to_var_name(node.id))_in")
inExpr = "data_$(to_var_name(node.id))_in"
end
outSymbol = Symbol("data_$(to_var_name(node.id))")
outExpr = "data_$(to_var_name(node.id))"
dataTransportExp = Meta.parse("$outSymbol = $inSymbol")
dataTransportExp = Meta.parse("$outExpr = $inExpr")
return dataTransportExp
end

View File

@@ -63,10 +63,25 @@ function parse_abc(filename::String, verbose::Bool = false)
end
sizehint!(graph.nodes, estimate_no_nodes)
sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false)
global_data_out =
insert_node!(graph, make_node(DataTask(FLOAT_SIZE)), false, false)
insert_edge!(graph, sum_node, global_data_out, false, false)
sum_node = insert_node!(
graph,
make_node(ComputeTaskSum()),
track = false,
invalidate_cache = false,
)
global_data_out = insert_node!(
graph,
make_node(DataTask(FLOAT_SIZE)),
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
sum_node,
global_data_out,
track = false,
invalidate_cache = false,
)
# remember the data out nodes for connection
dataOutNodes = Dict()
@@ -93,30 +108,62 @@ function parse_abc(filename::String, verbose::Bool = false)
data_in = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE), string(node)),
false,
false,
track = false,
invalidate_cache = false,
) # read particle data node
compute_P =
insert_node!(graph, make_node(ComputeTaskP()), false, false) # compute P node
compute_P = insert_node!(
graph,
make_node(ComputeTaskP()),
track = false,
invalidate_cache = false,
) # compute P node
data_Pu = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
) # transfer data from P to u (one ParticleValue object)
compute_u =
insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node
compute_u = insert_node!(
graph,
make_node(ComputeTaskU()),
track = false,
invalidate_cache = false,
) # compute U node
data_out = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
) # transfer data out from u (one ParticleValue object)
insert_edge!(graph, data_in, compute_P, false, false)
insert_edge!(graph, compute_P, data_Pu, false, false)
insert_edge!(graph, data_Pu, compute_u, false, false)
insert_edge!(graph, compute_u, data_out, false, false)
insert_edge!(
graph,
data_in,
compute_P,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_P,
data_Pu,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
data_Pu,
compute_u,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_u,
data_out,
track = false,
invalidate_cache = false,
)
# remember the data_out node for future edges
dataOutNodes[node] = data_out
@@ -126,13 +173,17 @@ function parse_abc(filename::String, verbose::Bool = false)
in1 = capt.captures[1]
in2 = capt.captures[2]
compute_v =
insert_node!(graph, make_node(ComputeTaskV()), false, false)
compute_v = insert_node!(
graph,
make_node(ComputeTaskV()),
track = false,
invalidate_cache = false,
)
data_out = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
)
if (occursin(regex_c, in1))
@@ -140,22 +191,46 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_S = insert_node!(
graph,
make_node(ComputeTaskS1()),
false,
false,
track = false,
invalidate_cache = false,
)
data_S_v = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, dataOutNodes[in1], compute_S, false, false)
insert_edge!(graph, compute_S, data_S_v, false, false)
insert_edge!(
graph,
dataOutNodes[in1],
compute_S,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_S,
data_S_v,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, data_S_v, compute_v, false, false)
insert_edge!(
graph,
data_S_v,
compute_v,
track = false,
invalidate_cache = false,
)
else
insert_edge!(graph, dataOutNodes[in1], compute_v, false, false)
insert_edge!(
graph,
dataOutNodes[in1],
compute_v,
track = false,
invalidate_cache = false,
)
end
if (occursin(regex_c, in2))
@@ -164,25 +239,55 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_S = insert_node!(
graph,
make_node(ComputeTaskS1()),
false,
false,
track = false,
invalidate_cache = false,
)
data_S_v = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, dataOutNodes[in2], compute_S, false, false)
insert_edge!(graph, compute_S, data_S_v, false, false)
insert_edge!(
graph,
dataOutNodes[in2],
compute_S,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_S,
data_S_v,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, data_S_v, compute_v, false, false)
insert_edge!(
graph,
data_S_v,
compute_v,
track = false,
invalidate_cache = false,
)
else
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
insert_edge!(
graph,
dataOutNodes[in2],
compute_v,
track = false,
invalidate_cache = false,
)
end
insert_edge!(graph, compute_v, data_out, false, false)
insert_edge!(
graph,
compute_v,
data_out,
track = false,
invalidate_cache = false,
)
dataOutNodes[node] = data_out
elseif occursin(regex_m, node)
@@ -193,34 +298,84 @@ function parse_abc(filename::String, verbose::Bool = false)
in3 = capt.captures[3]
# in2 + in3 with a v
compute_v =
insert_node!(graph, make_node(ComputeTaskV()), false, false)
compute_v = insert_node!(
graph,
make_node(ComputeTaskV()),
track = false,
invalidate_cache = false,
)
data_v = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
insert_edge!(graph, dataOutNodes[in3], compute_v, false, false)
insert_edge!(graph, compute_v, data_v, false, false)
insert_edge!(
graph,
dataOutNodes[in2],
compute_v,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
dataOutNodes[in3],
compute_v,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_v,
data_v,
track = false,
invalidate_cache = false,
)
# combine with the v of the combined other input
compute_S2 =
insert_node!(graph, make_node(ComputeTaskS2()), false, false)
compute_S2 = insert_node!(
graph,
make_node(ComputeTaskS2()),
track = false,
invalidate_cache = false,
)
data_out = insert_node!(
graph,
make_node(DataTask(FLOAT_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
) # output of a S2 task is only a float
insert_edge!(graph, data_v, compute_S2, false, false)
insert_edge!(graph, dataOutNodes[in1], compute_S2, false, false)
insert_edge!(graph, compute_S2, data_out, false, false)
insert_edge!(
graph,
data_v,
compute_S2,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
dataOutNodes[in1],
compute_S2,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_S2,
data_out,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, data_out, sum_node, false, false)
insert_edge!(
graph,
data_out,
sum_node,
track = false,
invalidate_cache = false,
)
elseif occursin(regex_plus, node)
if (verbose)
println("\rReading Nodes Complete ")

View File

@@ -21,6 +21,7 @@ A struct describing a particle of the ABC-Model. It has the 4 momentum parts P0.
`sizeof(Particle())` = 40 Byte
"""
struct Particle
# SFourMomentum
P0::Float64
P1::Float64
P2::Float64

View File

@@ -157,9 +157,8 @@ children(::ComputeTaskSum) = -1
"""
children(t::FusedComputeTask)
Return the number of children of a FusedComputeTask. It's the sum of the children of both tasks minus one.
Return the number of children of a FusedComputeTask.
"""
function children(t::FusedComputeTask)
(T1, T2) = get_types(t)
return children(T1()) + children(T2()) - 1 # one of the inputs is the output of T1 and thus not a child of the node
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
end

View File

@@ -124,17 +124,25 @@ function revert_diff!(graph::DAG, diff::Diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
remove_edge!(graph, edge.edge[1], edge.edge[2], false)
remove_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for node in diff.addedNodes
remove_node!(graph, node, false)
remove_node!(graph, node, track = false)
end
for node in diff.removedNodes
insert_node!(graph, node, false)
insert_node!(graph, node, track = false)
end
for edge in diff.removedEdges
insert_edge!(graph, edge.edge[1], edge.edge[2], false)
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for (node, before, after) in diff.updatedChildren
# node must be fused compute task at this point
@assert typeof(node.task) <: FusedComputeTask
replace!(node.task.t1_inputs, after => before)
replace!(node.task.t2_inputs, after => before)
end
graph.properties -= GraphProperties(diff)
@@ -175,9 +183,27 @@ function node_fusion!(
n3_children = children(n3)
remove_node!(graph, n3)
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
n1_inputs = Vector{String}()
for child in n1_children
push!(n1_inputs, "data_$(to_var_name(child.id))")
end
n3_inputs = Vector{String}()
for child in n3_children
push!(n3_inputs, "data_$(to_var_name(child.id))")
end
# create new node with the fused compute task
new_node =
ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
new_node = ComputeTaskNode(
FusedComputeTask(
n1.task,
n3.task,
n1_inputs,
"data_$(to_var_name(n2.id))",
n3_inputs,
),
)
insert_node!(graph, new_node)
for child in n1_children
@@ -195,6 +221,15 @@ function node_fusion!(
for parent in n3_parents
remove_edge!(graph, n3, parent)
insert_edge!(graph, new_node, parent)
# important! update the parent node's child names in case they are fused compute tasks
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
update_child!(
graph,
parent,
"data_$(to_var_name(n3.id))",
"data_$(to_var_name(new_node.id))",
)
end
return get_snapshot_diff(graph)
@@ -217,7 +252,9 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
n1_children = children(n1)
n1_parents = Set(n1.parents)
new_parents = Set{Node}()
# set of the new parents of n1, together with the names of the previous children that n1 now replaces
new_parents = Set{Tuple{Node, String}}()
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
@@ -230,7 +267,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
remove_edge!(graph, n, parent)
# collect all parents
push!(new_parents, parent)
push!(new_parents, (parent, "data_$(to_var_name(n.id))"))
end
remove_node!(graph, n)
@@ -238,9 +275,11 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
setdiff!(new_parents, n1_parents)
for parent in new_parents
for (parent, prev_child) in new_parents
# now add parents of all input nodes to n1 without duplicates
insert_edge!(graph, n1, parent)
update_child!(graph, parent, prev_child, "data_$(to_var_name(n1.id))")
end
return get_snapshot_diff(graph)
@@ -275,6 +314,13 @@ function node_split!(graph::DAG, n1::Node)
insert_node!(graph, n_copy)
insert_edge!(graph, n_copy, parent)
update_child!(
graph,
parent,
"data_$(to_var_name(n1.id))",
"data_$(to_var_name(n_copy.id))",
)
for child in n1_children
insert_edge!(graph, child, n_copy)
end

View File

@@ -12,3 +12,25 @@ copy(t::AbstractDataTask) =
Return a copy of the given compute task.
"""
copy(t::AbstractComputeTask) = typeof(t)()
"""
copy(t::FusedComputeTask)
Return a copy of th egiven [`FusedComputeTask`](@ref).
"""
function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
return FusedComputeTask{T1, T2}(
t.first_task,
t.second_task,
t.t1_inputs,
t.t1_output,
t.t2_inputs,
)
end
FusedComputeTask{T1, T2}(
t1_inputs::Vector{String},
t1_output::String,
t2_inputs::Vector{String},
) where {T1, T2} =
FusedComputeTask{T1, T2}(T1(), T2(), t1_inputs, t1_output, t2_inputs)

View File

@@ -71,8 +71,7 @@ data(t::AbstractComputeTask) = 0
Return the compute effort of a fused compute task.
"""
function compute_effort(t::FusedComputeTask)
(T1, T2) = collect(typeof(t).parameters)
return compute_effort(T1()) + compute_effort(T2())
return compute_effort(t.first_task) + compute_effort(t.second_task)
end
"""
@@ -81,30 +80,3 @@ end
Return a tuple of a the fused compute task's components' types.
"""
get_types(::FusedComputeTask{T1, T2}) where {T1, T2} = (T1, T2)
"""
get_expression(t::AbstractTask)
Return an expression evaluating the given task on the :dataIn symbol
"""
function get_expression(t::AbstractTask)
return quote
dataOut = compute($t, dataIn)
end
end
"""
get_expression()
"""
function get_expression(
t::FusedComputeTask,
inSymbol::Symbol,
outSymbol::Symbol,
)
#TODO
computeExp = quote
$outSymbol = compute($t, $inSymbol)
end
return computeExp
end

View File

@@ -27,4 +27,13 @@ A fused compute task made up of the computation of first `T1` and then `T2`.
Also see: [`get_types`](@ref).
"""
struct FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <:
AbstractComputeTask end
AbstractComputeTask
first_task::T1
second_task::T2
# the names of the inputs for T1
t1_inputs::Vector{String}
# output name of T1
t1_output::String
# t2_inputs doesn't include the output of t1, that's implicit
t2_inputs::Vector{String}
end

View File

@@ -87,3 +87,19 @@ Return the memory footprint of the node in Byte. Used in [`mem(graph::DAG)`](@re
function mem(node::Node)
return Base.summarysize(node, exclude = Union{Node, Operation})
end
"""
unroll_string_vector(vec::Vector{String})
Return the given vector as single String without quotation marks or brackets.
"""
function unroll_string_vector(vec::Vector{String})
result = ""
for s in vec
if (result != "")
result *= ", "
end
result *= s
end
return result
end