Make FusedComputeTasks usable in execution

This commit is contained in:
Anton Reinhard 2023-09-25 08:39:59 +02:00 committed by Anton Reinhard
parent f8a591991c
commit c428613c80
16 changed files with 548 additions and 283 deletions

View File

@ -4,7 +4,6 @@ authors = ["Anton Reinhard <anton.reinhard@proton.me>"]
version = "0.1.0"
[deps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

View File

@ -1,6 +1,6 @@
function test_random_walk(g::DAG, n::Int64)
# the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
function random_walk!(g::DAG, n::Int64)
# the purpose here is to do "random" operations on the graph to simulate an optimizer
reset_graph!(g)
properties = get_properties(g)
@ -32,7 +32,7 @@ function test_random_walk(g::DAG, n::Int64)
end
end
return reset_graph!(g)
return nothing
end
function reduce_all!(g::DAG)

View File

@ -4,8 +4,14 @@
A named tuple representing a difference of added and removed nodes and edges on a [`DAG`](@ref).
"""
const Diff = NamedTuple{
(:addedNodes, :removedNodes, :addedEdges, :removedEdges),
Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}},
(:addedNodes, :removedNodes, :addedEdges, :removedEdges, :updatedChildren),
Tuple{
Vector{Node},
Vector{Node},
Vector{Edge},
Vector{Edge},
Vector{Tuple{Node, String, String}},
},
}
function Diff()
@ -14,5 +20,8 @@ function Diff()
removedNodes = Vector{Node}(),
addedEdges = Vector{Edge}(),
removedEdges = Vector{Edge}(),
# children were updated from updatedChildren[2] to updatedChildren[3] in node updatedChildren[1]
updatedChildren = Vector{Tuple{Node, String, String}}(),
)::Diff
end

View File

@ -17,7 +17,7 @@ See also: [`remove_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
function insert_node!(
graph::DAG,
node::Node,
node::Node;
track = true,
invalidate_cache = true,
)
@ -53,7 +53,7 @@ See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
function insert_edge!(
graph::DAG,
node1::Node,
node2::Node,
node2::Node;
track = true,
invalidate_cache = true,
)
@ -97,7 +97,7 @@ See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
function remove_node!(
graph::DAG,
node::Node,
node::Node;
track = true,
invalidate_cache = true,
)
@ -137,7 +137,7 @@ See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`insert_edge!`](@ref)
function remove_edge!(
graph::DAG,
node1::Node,
node2::Node,
node2::Node;
track = true,
invalidate_cache = true,
)
@ -181,6 +181,27 @@ function remove_edge!(
return nothing
end
"""
    update_child!(graph::DAG, n::Node, child_before::String, child_after::String; track = true)

Rename the input `child_before` to `child_after` in the task of node `n`.

Only nodes whose task is a [`FusedComputeTask`](@ref) store the names of their inputs, so
this is a no-op for every other node. Both `t1_inputs` and `t2_inputs` of the fused task
are updated in place.

If `track` is `true`, the rename is recorded in `graph.diff.updatedChildren` as the tuple
`(n, child_before, child_after)`.

Always returns `nothing`.
"""
function update_child!(
    graph::DAG,
    n::Node,
    child_before::String,
    child_after::String;
    track = true,
)
    # only fused compute tasks keep track of the names of their inputs
    if !(n.task isa FusedComputeTask)
        return nothing
    end

    replace!(n.task.t1_inputs, child_before => child_after)
    replace!(n.task.t2_inputs, child_before => child_after)

    # record the rename in the graph's running diff
    if track
        push!(graph.diff.updatedChildren, (n, child_before, child_after))
    end

    # return `nothing` on every path instead of leaking the result of `push!`
    return nothing
end
"""
get_snapshot_diff(graph::DAG)

View File

@ -80,125 +80,113 @@ function compute(t::FusedComputeTask, data)
end
"""
get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskP, inExprs::Vector{String}, outExpr::String)
Generate and return code evaluating [`ComputeTaskP`](@ref) on `inSymbol`, providing the output on `outSymbol`.
Generate and return code evaluating [`ComputeTaskP`](@ref) on `inExprs[1]`, providing the output on `outExpr`.
"""
function get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskP(), $inSymbol)")
function get_expression(
::ComputeTaskP,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse("$outExpr = compute(ComputeTaskP(), $(inExprs[1]))")
end
"""
get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskU, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskU`](@ref) on `inSymbol`, providing the output on `outSymbol`.
`inSymbol` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
Generate code evaluating [`ComputeTaskU`](@ref) on `inExprs[1]`, providing the output on `outExpr`.
`inExprs[1]` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
"""
function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskU(), $inSymbol)")
function get_expression(
::ComputeTaskU,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse("$outExpr = compute(ComputeTaskU(), $(inExprs[1]))")
end
"""
get_expression(::ComputeTaskV, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskV, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskV`](@ref) on `inSymbol1` and `inSymbol2`, providing the output on `outSymbol`.
`inSymbol1` and `inSymbol2` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
Generate code evaluating [`ComputeTaskV`](@ref) on `inExprs[1]` and `inExprs[2]`, providing the output on `outExpr`.
`inExprs[1]` and `inExprs[2]` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
"""
function get_expression(
::ComputeTaskV,
inSymbol1::Symbol,
inSymbol2::Symbol,
outSymbol::Symbol,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse(
"$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)",
"$outExpr = compute(ComputeTaskV(), $(inExprs[1]), $(inExprs[2]))",
)
end
"""
get_expression(::ComputeTaskS2, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskS2, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskS2`](@ref) on `inSymbol1` and `inSymbol2`, providing the output on `outSymbol`.
`inSymbol1` and `inSymbol2` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type `Float64`.
Generate code evaluating [`ComputeTaskS2`](@ref) on `inExprs[1]` and `inExprs[2]`, providing the output on `outExpr`.
`inExprs[1]` and `inExprs[2]` should be of type [`ParticleValue`](@ref), `outExpr` will be of type `Float64`.
"""
function get_expression(
::ComputeTaskS2,
inSymbol1::Symbol,
inSymbol2::Symbol,
outSymbol::Symbol,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse(
"$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)",
"$outExpr = compute(ComputeTaskS2(), $(inExprs[1]), $(inExprs[2]))",
)
end
"""
get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
get_expression(::ComputeTaskS1, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskS1`](@ref) on `inSymbol`, providing the output on `outSymbol`.
`inSymbol` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
Generate code evaluating [`ComputeTaskS1`](@ref) on `inExprs[1]`, providing the output on `outExpr`.
`inExprs[1]` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
"""
function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
return Meta.parse("$outSymbol = compute(ComputeTaskS1(), $inSymbol)")
function get_expression(
::ComputeTaskS1,
inExprs::Vector{String},
outExpr::String,
)
return Meta.parse("$outExpr = compute(ComputeTaskS1(), $(inExprs[1]))")
end
"""
get_expression(::ComputeTaskSum, inSymbols::Vector{Symbol}, outSymbol::Symbol)
get_expression(::ComputeTaskSum, inExprs::Vector{String}, outExpr::String)
Generate code evaluating [`ComputeTaskSum`](@ref) on `inSymbols`, providing the output on `outSymbol`.
`inSymbols` should be of type [`Float64`], `outSymbol` will be of type [`Float64`].
Generate code evaluating [`ComputeTaskSum`](@ref) on `inExprs`, providing the output on `outExpr`.
`inExprs` should be of type [`Float64`], `outExpr` will be of type [`Float64`].
"""
function get_expression(
::ComputeTaskSum,
inSymbols::Vector{Symbol},
outSymbol::Symbol,
inExprs::Vector{String},
outExpr::String,
)
return quote
$outSymbol = compute(ComputeTaskSum(), [$(inSymbols...)])
end
return Meta.parse(
"$outExpr = compute(ComputeTaskSum(), [$(unroll_string_vector(inExprs))])",
)
end
"""
get_expression(t::FusedComputeTask, inSymbols::Vector{Symbol}, outSymbol::Symbol)
get_expression(t::FusedComputeTask, inExprs::Vector{String}, outExpr::String)
Generate code evaluating a [`FusedComputeTask`](@ref) on `inSymbols`, providing the output on `outSymbol`.
`inSymbols` should be of the correct types and may be heterogeneous. `outSymbol` will be of the type of the output of `T2` of t.
Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing the output on `outExpr`.
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
"""
function get_expression(
t::FusedComputeTask,
inSymbols::Vector{Symbol},
outSymbol::Symbol,
inExprs::Vector{String},
outExpr::String,
)
(T1, T2) = get_types(t)
c1 = children(T1())
c2 = children(T2())
c1 = length(t.t1_inputs)
c2 = length(t.t2_inputs) + 1
expr1 = nothing
expr2 = nothing
# TODO need to figure out how to know which inputs belong to which subtask
# since we order the vectors with the child nodes we can't just split
if (c1 == 1)
expr1 = get_expression(T1(), inSymbols[begin], :intermediate)
elseif (c1 == 2)
expr1 =
get_expression(T1(), inSymbols[begin], inSymbols[2], :intermediate)
else
expr1 = get_expression(T1(), inSymbols[begin:c1], :intermediate)
end
if (c2 == 1)
expr2 = get_expression(T2(), :intermediate, outSymbol)
elseif c2 == 2
expr2 =
get_expression(T2(), :intermediate, inSymbols[c1 + 1], outSymbol)
else
expr2 = get_expression(
T2(),
:intermediate * inSymbols[(c1 + 1):end],
outSymbol,
)
end
expr1 = get_expression(t.first_task, t.t1_inputs, t.t1_output)
expr2 =
get_expression(t.second_task, [t.t2_inputs..., t.t1_output], outExpr)
return Expr(:block, expr1, expr2)
end
@ -210,24 +198,27 @@ Generate and return code for a given [`ComputeTaskNode`](@ref).
"""
function get_expression(node::ComputeTaskNode)
t = typeof(node.task)
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
symbolIn = Symbol("data_$(to_var_name(node.children[1].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn, symbolOut)
symbolIn = "data_$(to_var_name(node.children[1].id))"
symbolOut = "data_$(to_var_name(node.id))"
return get_expression(node.task, [symbolIn], symbolOut)
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))")
symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))")
symbolOut = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
elseif (t <: ComputeTaskSum || t <: FusedComputeTask) # vector input
inSymbols = Vector{Symbol}()
symbolIn1 = "data_$(to_var_name(node.children[1].id))"
symbolIn2 = "data_$(to_var_name(node.children[2].id))"
symbolOut = "data_$(to_var_name(node.id))"
return get_expression(node.task, [symbolIn1, symbolIn2], symbolOut)
elseif (t <: ComputeTaskSum) # vector input
inExprs = Vector{String}()
for child in node.children
push!(inSymbols, Symbol("data_$(to_var_name(child.id))"))
push!(inExprs, "data_$(to_var_name(child.id))")
end
outSymbol = Symbol("data_$(to_var_name(node.id))")
return get_expression(t(), inSymbols, outSymbol)
outExpr = "data_$(to_var_name(node.id))"
return get_expression(node.task, inExprs, outExpr)
elseif t <: FusedComputeTask # fused compute task knows its inputs
outExpr = "data_$(to_var_name(node.id))"
return get_expression(node.task, Vector{String}(), outExpr)
else
error("Unknown compute task")
end
@ -242,15 +233,15 @@ function get_expression(node::DataTaskNode)
# TODO: do things to transport data from/to gpu, between numa nodes, etc.
@assert length(node.children) <= 1
inSymbol = nothing
inExpr = nothing
if (length(node.children) == 1)
inSymbol = Symbol("data_$(to_var_name(node.children[1].id))")
inExpr = "data_$(to_var_name(node.children[1].id))"
else
inSymbol = Symbol("data_$(to_var_name(node.id))_in")
inExpr = "data_$(to_var_name(node.id))_in"
end
outSymbol = Symbol("data_$(to_var_name(node.id))")
outExpr = "data_$(to_var_name(node.id))"
dataTransportExp = Meta.parse("$outSymbol = $inSymbol")
dataTransportExp = Meta.parse("$outExpr = $inExpr")
return dataTransportExp
end

View File

@ -63,10 +63,25 @@ function parse_abc(filename::String, verbose::Bool = false)
end
sizehint!(graph.nodes, estimate_no_nodes)
sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false)
global_data_out =
insert_node!(graph, make_node(DataTask(FLOAT_SIZE)), false, false)
insert_edge!(graph, sum_node, global_data_out, false, false)
sum_node = insert_node!(
graph,
make_node(ComputeTaskSum()),
track = false,
invalidate_cache = false,
)
global_data_out = insert_node!(
graph,
make_node(DataTask(FLOAT_SIZE)),
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
sum_node,
global_data_out,
track = false,
invalidate_cache = false,
)
# remember the data out nodes for connection
dataOutNodes = Dict()
@ -93,30 +108,62 @@ function parse_abc(filename::String, verbose::Bool = false)
data_in = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE), string(node)),
false,
false,
track = false,
invalidate_cache = false,
) # read particle data node
compute_P =
insert_node!(graph, make_node(ComputeTaskP()), false, false) # compute P node
compute_P = insert_node!(
graph,
make_node(ComputeTaskP()),
track = false,
invalidate_cache = false,
) # compute P node
data_Pu = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
) # transfer data from P to u (one ParticleValue object)
compute_u =
insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node
compute_u = insert_node!(
graph,
make_node(ComputeTaskU()),
track = false,
invalidate_cache = false,
) # compute U node
data_out = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
) # transfer data out from u (one ParticleValue object)
insert_edge!(graph, data_in, compute_P, false, false)
insert_edge!(graph, compute_P, data_Pu, false, false)
insert_edge!(graph, data_Pu, compute_u, false, false)
insert_edge!(graph, compute_u, data_out, false, false)
insert_edge!(
graph,
data_in,
compute_P,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_P,
data_Pu,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
data_Pu,
compute_u,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_u,
data_out,
track = false,
invalidate_cache = false,
)
# remember the data_out node for future edges
dataOutNodes[node] = data_out
@ -126,13 +173,17 @@ function parse_abc(filename::String, verbose::Bool = false)
in1 = capt.captures[1]
in2 = capt.captures[2]
compute_v =
insert_node!(graph, make_node(ComputeTaskV()), false, false)
compute_v = insert_node!(
graph,
make_node(ComputeTaskV()),
track = false,
invalidate_cache = false,
)
data_out = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
)
if (occursin(regex_c, in1))
@ -140,22 +191,46 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_S = insert_node!(
graph,
make_node(ComputeTaskS1()),
false,
false,
track = false,
invalidate_cache = false,
)
data_S_v = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, dataOutNodes[in1], compute_S, false, false)
insert_edge!(graph, compute_S, data_S_v, false, false)
insert_edge!(
graph,
dataOutNodes[in1],
compute_S,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_S,
data_S_v,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, data_S_v, compute_v, false, false)
insert_edge!(
graph,
data_S_v,
compute_v,
track = false,
invalidate_cache = false,
)
else
insert_edge!(graph, dataOutNodes[in1], compute_v, false, false)
insert_edge!(
graph,
dataOutNodes[in1],
compute_v,
track = false,
invalidate_cache = false,
)
end
if (occursin(regex_c, in2))
@ -164,25 +239,55 @@ function parse_abc(filename::String, verbose::Bool = false)
compute_S = insert_node!(
graph,
make_node(ComputeTaskS1()),
false,
false,
track = false,
invalidate_cache = false,
)
data_S_v = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, dataOutNodes[in2], compute_S, false, false)
insert_edge!(graph, compute_S, data_S_v, false, false)
insert_edge!(
graph,
dataOutNodes[in2],
compute_S,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_S,
data_S_v,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, data_S_v, compute_v, false, false)
insert_edge!(
graph,
data_S_v,
compute_v,
track = false,
invalidate_cache = false,
)
else
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
insert_edge!(
graph,
dataOutNodes[in2],
compute_v,
track = false,
invalidate_cache = false,
)
end
insert_edge!(graph, compute_v, data_out, false, false)
insert_edge!(
graph,
compute_v,
data_out,
track = false,
invalidate_cache = false,
)
dataOutNodes[node] = data_out
elseif occursin(regex_m, node)
@ -193,34 +298,84 @@ function parse_abc(filename::String, verbose::Bool = false)
in3 = capt.captures[3]
# in2 + in3 with a v
compute_v =
insert_node!(graph, make_node(ComputeTaskV()), false, false)
compute_v = insert_node!(
graph,
make_node(ComputeTaskV()),
track = false,
invalidate_cache = false,
)
data_v = insert_node!(
graph,
make_node(DataTask(PARTICLE_VALUE_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
insert_edge!(graph, dataOutNodes[in3], compute_v, false, false)
insert_edge!(graph, compute_v, data_v, false, false)
insert_edge!(
graph,
dataOutNodes[in2],
compute_v,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
dataOutNodes[in3],
compute_v,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_v,
data_v,
track = false,
invalidate_cache = false,
)
# combine with the v of the combined other input
compute_S2 =
insert_node!(graph, make_node(ComputeTaskS2()), false, false)
compute_S2 = insert_node!(
graph,
make_node(ComputeTaskS2()),
track = false,
invalidate_cache = false,
)
data_out = insert_node!(
graph,
make_node(DataTask(FLOAT_SIZE)),
false,
false,
track = false,
invalidate_cache = false,
) # output of a S2 task is only a float
insert_edge!(graph, data_v, compute_S2, false, false)
insert_edge!(graph, dataOutNodes[in1], compute_S2, false, false)
insert_edge!(graph, compute_S2, data_out, false, false)
insert_edge!(
graph,
data_v,
compute_S2,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
dataOutNodes[in1],
compute_S2,
track = false,
invalidate_cache = false,
)
insert_edge!(
graph,
compute_S2,
data_out,
track = false,
invalidate_cache = false,
)
insert_edge!(graph, data_out, sum_node, false, false)
insert_edge!(
graph,
data_out,
sum_node,
track = false,
invalidate_cache = false,
)
elseif occursin(regex_plus, node)
if (verbose)
println("\rReading Nodes Complete ")

View File

@ -21,6 +21,7 @@ A struct describing a particle of the ABC-Model. It has the 4 momentum parts P0.
`sizeof(Particle())` = 40 Byte
"""
struct Particle
# SFourMomentum
P0::Float64
P1::Float64
P2::Float64

View File

@ -157,9 +157,8 @@ children(::ComputeTaskSum) = -1
"""
children(t::FusedComputeTask)
Return the number of children of a FusedComputeTask. It's the sum of the children of both tasks minus one.
Return the number of children of a FusedComputeTask.
"""
function children(t::FusedComputeTask)
(T1, T2) = get_types(t)
return children(T1()) + children(T2()) - 1 # one of the inputs is the output of T1 and thus not a child of the node
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
end

View File

@ -124,17 +124,25 @@ function revert_diff!(graph::DAG, diff::Diff)
# add removed nodes, remove added nodes, same for edges
# note the order
for edge in diff.addedEdges
remove_edge!(graph, edge.edge[1], edge.edge[2], false)
remove_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for node in diff.addedNodes
remove_node!(graph, node, false)
remove_node!(graph, node, track = false)
end
for node in diff.removedNodes
insert_node!(graph, node, false)
insert_node!(graph, node, track = false)
end
for edge in diff.removedEdges
insert_edge!(graph, edge.edge[1], edge.edge[2], false)
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
end
for (node, before, after) in diff.updatedChildren
# node must be fused compute task at this point
@assert typeof(node.task) <: FusedComputeTask
replace!(node.task.t1_inputs, after => before)
replace!(node.task.t2_inputs, after => before)
end
graph.properties -= GraphProperties(diff)
@ -175,9 +183,27 @@ function node_fusion!(
n3_children = children(n3)
remove_node!(graph, n3)
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
n1_inputs = Vector{String}()
for child in n1_children
push!(n1_inputs, "data_$(to_var_name(child.id))")
end
n3_inputs = Vector{String}()
for child in n3_children
push!(n3_inputs, "data_$(to_var_name(child.id))")
end
# create new node with the fused compute task
new_node =
ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
new_node = ComputeTaskNode(
FusedComputeTask(
n1.task,
n3.task,
n1_inputs,
"data_$(to_var_name(n2.id))",
n3_inputs,
),
)
insert_node!(graph, new_node)
for child in n1_children
@ -195,6 +221,15 @@ function node_fusion!(
for parent in n3_parents
remove_edge!(graph, n3, parent)
insert_edge!(graph, new_node, parent)
# important! update the parent node's child names in case they are fused compute tasks
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
update_child!(
graph,
parent,
"data_$(to_var_name(n3.id))",
"data_$(to_var_name(new_node.id))",
)
end
return get_snapshot_diff(graph)
@ -217,7 +252,9 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
n1_children = children(n1)
n1_parents = Set(n1.parents)
new_parents = Set{Node}()
# set of the new parents of n1, together with the names of the previous children that n1 now replaces
new_parents = Set{Tuple{Node, String}}()
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
@ -230,7 +267,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
remove_edge!(graph, n, parent)
# collect all parents
push!(new_parents, parent)
push!(new_parents, (parent, "data_$(to_var_name(n.id))"))
end
remove_node!(graph, n)
@ -238,9 +275,11 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
setdiff!(new_parents, n1_parents)
for parent in new_parents
for (parent, prev_child) in new_parents
# now add parents of all input nodes to n1 without duplicates
insert_edge!(graph, n1, parent)
update_child!(graph, parent, prev_child, "data_$(to_var_name(n1.id))")
end
return get_snapshot_diff(graph)
@ -275,6 +314,13 @@ function node_split!(graph::DAG, n1::Node)
insert_node!(graph, n_copy)
insert_edge!(graph, n_copy, parent)
update_child!(
graph,
parent,
"data_$(to_var_name(n1.id))",
"data_$(to_var_name(n_copy.id))",
)
for child in n1_children
insert_edge!(graph, child, n_copy)
end

View File

@ -12,3 +12,25 @@ copy(t::AbstractDataTask) =
Return a copy of the given compute task.
"""
copy(t::AbstractComputeTask) = typeof(t)()
"""
copy(t::FusedComputeTask)
Return a copy of the given [`FusedComputeTask`](@ref).
"""
function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
    # The input-name vectors are mutated in place (via `replace!`) whenever a child of
    # the owning node is renamed. The copy therefore needs its own vectors; sharing
    # them would let a rename on one task silently change every copy of it.
    # `t1_output` is an immutable String and can be shared as-is.
    return FusedComputeTask{T1, T2}(
        copy(t.first_task),   # sub-tasks may themselves be fused tasks holding vectors
        copy(t.second_task),
        Base.copy(t.t1_inputs),
        t.t1_output,
        Base.copy(t.t2_inputs),
    )
end
"""
    FusedComputeTask{T1, T2}(t1_inputs::Vector{String}, t1_output::String, t2_inputs::Vector{String})

Convenience constructor for a [`FusedComputeTask`](@ref) that default-constructs both
component tasks from the type parameters (`T1()` and `T2()`) and forwards the given input
and output names unchanged.
"""
function FusedComputeTask{T1, T2}(
    t1_inputs::Vector{String},
    t1_output::String,
    t2_inputs::Vector{String},
) where {T1, T2}
    first_task = T1()
    second_task = T2()
    return FusedComputeTask{T1, T2}(
        first_task,
        second_task,
        t1_inputs,
        t1_output,
        t2_inputs,
    )
end

View File

@ -71,8 +71,7 @@ data(t::AbstractComputeTask) = 0
Return the compute effort of a fused compute task.
"""
function compute_effort(t::FusedComputeTask)
(T1, T2) = collect(typeof(t).parameters)
return compute_effort(T1()) + compute_effort(T2())
return compute_effort(t.first_task) + compute_effort(t.second_task)
end
"""
@ -81,30 +80,3 @@ end
Return a tuple of the fused compute task's components' types.
"""
get_types(::FusedComputeTask{T1, T2}) where {T1, T2} = (T1, T2)
"""
get_expression(t::AbstractTask)
Return an expression evaluating the given task on the :dataIn symbol
"""
function get_expression(t::AbstractTask)
return quote
dataOut = compute($t, dataIn)
end
end
"""
get_expression()
"""
function get_expression(
t::FusedComputeTask,
inSymbol::Symbol,
outSymbol::Symbol,
)
#TODO
computeExp = quote
$outSymbol = compute($t, $inSymbol)
end
return computeExp
end

View File

@ -27,4 +27,13 @@ A fused compute task made up of the computation of first `T1` and then `T2`.
Also see: [`get_types`](@ref).
"""
struct FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <:
AbstractComputeTask end
AbstractComputeTask
first_task::T1
second_task::T2
# the names of the inputs for T1
t1_inputs::Vector{String}
# output name of T1
t1_output::String
# t2_inputs doesn't include the output of t1, that's implicit
t2_inputs::Vector{String}
end

View File

@ -87,3 +87,19 @@ Return the memory footprint of the node in Byte. Used in [`mem(graph::DAG)`](@re
function mem(node::Node)
return Base.summarysize(node, exclude = Union{Node, Operation})
end
"""
unroll_string_vector(vec::Vector{String})
Return the given vector as single String without quotation marks or brackets.
"""
function unroll_string_vector(vec::Vector{String})
    # `join` assembles the result in a single pass (no repeated string reallocation)
    # and always places ", " between consecutive elements, even when the elements seen
    # so far are empty strings. An empty vector yields "".
    return join(vec, ", ")
end

View File

@ -5,51 +5,51 @@ import MetagraphOptimization.make_node
@testset "Unit Tests Node Reduction" begin
graph = MetagraphOptimization.DAG()
d_exit = insert_node!(graph, make_node(DataTask(10)), false)
d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
s0 = insert_node!(graph, make_node(ComputeTaskS2()), false)
s0 = insert_node!(graph, make_node(ComputeTaskS2()), track = false)
ED = insert_node!(graph, make_node(DataTask(3)), false)
FD = insert_node!(graph, make_node(DataTask(3)), false)
ED = insert_node!(graph, make_node(DataTask(3)), track = false)
FD = insert_node!(graph, make_node(DataTask(3)), track = false)
EC = insert_node!(graph, make_node(ComputeTaskV()), false)
FC = insert_node!(graph, make_node(ComputeTaskV()), false)
EC = insert_node!(graph, make_node(ComputeTaskV()), track = false)
FC = insert_node!(graph, make_node(ComputeTaskV()), track = false)
A1D = insert_node!(graph, make_node(DataTask(4)), false)
B1D_1 = insert_node!(graph, make_node(DataTask(4)), false)
B1D_2 = insert_node!(graph, make_node(DataTask(4)), false)
C1D = insert_node!(graph, make_node(DataTask(4)), false)
A1D = insert_node!(graph, make_node(DataTask(4)), track = false)
B1D_1 = insert_node!(graph, make_node(DataTask(4)), track = false)
B1D_2 = insert_node!(graph, make_node(DataTask(4)), track = false)
C1D = insert_node!(graph, make_node(DataTask(4)), track = false)
A1C = insert_node!(graph, make_node(ComputeTaskU()), false)
B1C_1 = insert_node!(graph, make_node(ComputeTaskU()), false)
B1C_2 = insert_node!(graph, make_node(ComputeTaskU()), false)
C1C = insert_node!(graph, make_node(ComputeTaskU()), false)
A1C = insert_node!(graph, make_node(ComputeTaskU()), track = false)
B1C_1 = insert_node!(graph, make_node(ComputeTaskU()), track = false)
B1C_2 = insert_node!(graph, make_node(ComputeTaskU()), track = false)
C1C = insert_node!(graph, make_node(ComputeTaskU()), track = false)
AD = insert_node!(graph, make_node(DataTask(5)), false)
BD = insert_node!(graph, make_node(DataTask(5)), false)
CD = insert_node!(graph, make_node(DataTask(5)), false)
AD = insert_node!(graph, make_node(DataTask(5)), track = false)
BD = insert_node!(graph, make_node(DataTask(5)), track = false)
CD = insert_node!(graph, make_node(DataTask(5)), track = false)
insert_edge!(graph, s0, d_exit, false)
insert_edge!(graph, ED, s0, false)
insert_edge!(graph, FD, s0, false)
insert_edge!(graph, EC, ED, false)
insert_edge!(graph, FC, FD, false)
insert_edge!(graph, s0, d_exit, track = false)
insert_edge!(graph, ED, s0, track = false)
insert_edge!(graph, FD, s0, track = false)
insert_edge!(graph, EC, ED, track = false)
insert_edge!(graph, FC, FD, track = false)
insert_edge!(graph, A1D, EC, false)
insert_edge!(graph, B1D_1, EC, false)
insert_edge!(graph, A1D, EC, track = false)
insert_edge!(graph, B1D_1, EC, track = false)
insert_edge!(graph, B1D_2, FC, false)
insert_edge!(graph, C1D, FC, false)
insert_edge!(graph, B1D_2, FC, track = false)
insert_edge!(graph, C1D, FC, track = false)
insert_edge!(graph, A1C, A1D, false)
insert_edge!(graph, B1C_1, B1D_1, false)
insert_edge!(graph, B1C_2, B1D_2, false)
insert_edge!(graph, C1C, C1D, false)
insert_edge!(graph, A1C, A1D, track = false)
insert_edge!(graph, B1C_1, B1D_1, track = false)
insert_edge!(graph, B1C_2, B1D_2, track = false)
insert_edge!(graph, C1C, C1D, track = false)
insert_edge!(graph, AD, A1C, false)
insert_edge!(graph, BD, B1C_1, false)
insert_edge!(graph, BD, B1C_2, false)
insert_edge!(graph, CD, C1C, false)
insert_edge!(graph, AD, A1C, track = false)
insert_edge!(graph, BD, B1C_1, track = false)
insert_edge!(graph, BD, B1C_2, track = false)
insert_edge!(graph, CD, C1C, track = false)
@test is_valid(graph)

View File

@ -2,7 +2,9 @@ import MetagraphOptimization.A
import MetagraphOptimization.B
import MetagraphOptimization.ParticleType
@testset "Unit Tests Graph" begin
include("../examples/profiling_utilities.jl")
@testset "Unit Tests Execution" begin
particles = Dict{ParticleType, Vector{Particle}}(
(
A => [
@ -20,12 +22,35 @@ import MetagraphOptimization.ParticleType
expected_result = 5.5320567694746876e-5
for _ in 1:10 # test in a loop because graph layout should not change the result
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
@test isapprox(execute(graph, particles), expected_result; rtol = 0.001)
@testset "AB->AB no optimization" begin
for _ in 1:10 # test in a loop because graph layout should not change the result
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
@test isapprox(
execute(graph, particles),
expected_result;
rtol = 0.001,
)
code = MetagraphOptimization.gen_code(graph)
@test isapprox(execute(code, particles), expected_result; rtol = 0.001)
code = MetagraphOptimization.gen_code(graph)
@test isapprox(
execute(code, particles),
expected_result;
rtol = 0.001,
)
end
end
@testset "AB->AB after random walk" begin
for _ in 1:20
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
random_walk!(graph, 40)
@test isapprox(
execute(graph, particles),
expected_result;
rtol = 0.001,
)
end
end
end
println("Execution Unit Tests Complete!")

View File

@ -17,91 +17,91 @@ import MetagraphOptimization.partners
(nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
# s to output (exit node)
d_exit = insert_node!(graph, make_node(DataTask(10)), false)
d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
@test length(graph.nodes) == 1
@test length(graph.dirtyNodes) == 1
# final s compute
s0 = insert_node!(graph, make_node(ComputeTaskS2()), false)
s0 = insert_node!(graph, make_node(ComputeTaskS2()), track = false)
@test length(graph.nodes) == 2
@test length(graph.dirtyNodes) == 2
# data from v0 and v1 to s0
d_v0_s0 = insert_node!(graph, make_node(DataTask(5)), false)
d_v1_s0 = insert_node!(graph, make_node(DataTask(5)), false)
d_v0_s0 = insert_node!(graph, make_node(DataTask(5)), track = false)
d_v1_s0 = insert_node!(graph, make_node(DataTask(5)), track = false)
# v0 and v1 compute
v0 = insert_node!(graph, make_node(ComputeTaskV()), false)
v1 = insert_node!(graph, make_node(ComputeTaskV()), false)
v0 = insert_node!(graph, make_node(ComputeTaskV()), track = false)
v1 = insert_node!(graph, make_node(ComputeTaskV()), track = false)
# data from uB, uA, uBp and uAp to v0 and v1
d_uB_v0 = insert_node!(graph, make_node(DataTask(3)), false)
d_uA_v0 = insert_node!(graph, make_node(DataTask(3)), false)
d_uBp_v1 = insert_node!(graph, make_node(DataTask(3)), false)
d_uAp_v1 = insert_node!(graph, make_node(DataTask(3)), false)
d_uB_v0 = insert_node!(graph, make_node(DataTask(3)), track = false)
d_uA_v0 = insert_node!(graph, make_node(DataTask(3)), track = false)
d_uBp_v1 = insert_node!(graph, make_node(DataTask(3)), track = false)
d_uAp_v1 = insert_node!(graph, make_node(DataTask(3)), track = false)
# uB, uA, uBp and uAp computes
uB = insert_node!(graph, make_node(ComputeTaskU()), false)
uA = insert_node!(graph, make_node(ComputeTaskU()), false)
uBp = insert_node!(graph, make_node(ComputeTaskU()), false)
uAp = insert_node!(graph, make_node(ComputeTaskU()), false)
uB = insert_node!(graph, make_node(ComputeTaskU()), track = false)
uA = insert_node!(graph, make_node(ComputeTaskU()), track = false)
uBp = insert_node!(graph, make_node(ComputeTaskU()), track = false)
uAp = insert_node!(graph, make_node(ComputeTaskU()), track = false)
# data from PB, PA, PBp and PAp to uB, uA, uBp and uAp
d_PB_uB = insert_node!(graph, make_node(DataTask(6)), false)
d_PA_uA = insert_node!(graph, make_node(DataTask(6)), false)
d_PBp_uBp = insert_node!(graph, make_node(DataTask(6)), false)
d_PAp_uAp = insert_node!(graph, make_node(DataTask(6)), false)
d_PB_uB = insert_node!(graph, make_node(DataTask(6)), track = false)
d_PA_uA = insert_node!(graph, make_node(DataTask(6)), track = false)
d_PBp_uBp = insert_node!(graph, make_node(DataTask(6)), track = false)
d_PAp_uAp = insert_node!(graph, make_node(DataTask(6)), track = false)
# P computes PB, PA, PBp and PAp
PB = insert_node!(graph, make_node(ComputeTaskP()), false)
PA = insert_node!(graph, make_node(ComputeTaskP()), false)
PBp = insert_node!(graph, make_node(ComputeTaskP()), false)
PAp = insert_node!(graph, make_node(ComputeTaskP()), false)
PB = insert_node!(graph, make_node(ComputeTaskP()), track = false)
PA = insert_node!(graph, make_node(ComputeTaskP()), track = false)
PBp = insert_node!(graph, make_node(ComputeTaskP()), track = false)
PAp = insert_node!(graph, make_node(ComputeTaskP()), track = false)
# entry nodes getting data for P computes
d_PB = insert_node!(graph, make_node(DataTask(4)), false)
d_PA = insert_node!(graph, make_node(DataTask(4)), false)
d_PBp = insert_node!(graph, make_node(DataTask(4)), false)
d_PAp = insert_node!(graph, make_node(DataTask(4)), false)
d_PB = insert_node!(graph, make_node(DataTask(4)), track = false)
d_PA = insert_node!(graph, make_node(DataTask(4)), track = false)
d_PBp = insert_node!(graph, make_node(DataTask(4)), track = false)
d_PAp = insert_node!(graph, make_node(DataTask(4)), track = false)
@test length(graph.nodes) == 26
@test length(graph.dirtyNodes) == 26
# now for all the edges
insert_edge!(graph, d_PB, PB, false)
insert_edge!(graph, d_PA, PA, false)
insert_edge!(graph, d_PBp, PBp, false)
insert_edge!(graph, d_PAp, PAp, false)
insert_edge!(graph, d_PB, PB, track = false)
insert_edge!(graph, d_PA, PA, track = false)
insert_edge!(graph, d_PBp, PBp, track = false)
insert_edge!(graph, d_PAp, PAp, track = false)
insert_edge!(graph, PB, d_PB_uB, false)
insert_edge!(graph, PA, d_PA_uA, false)
insert_edge!(graph, PBp, d_PBp_uBp, false)
insert_edge!(graph, PAp, d_PAp_uAp, false)
insert_edge!(graph, PB, d_PB_uB, track = false)
insert_edge!(graph, PA, d_PA_uA, track = false)
insert_edge!(graph, PBp, d_PBp_uBp, track = false)
insert_edge!(graph, PAp, d_PAp_uAp, track = false)
insert_edge!(graph, d_PB_uB, uB, false)
insert_edge!(graph, d_PA_uA, uA, false)
insert_edge!(graph, d_PBp_uBp, uBp, false)
insert_edge!(graph, d_PAp_uAp, uAp, false)
insert_edge!(graph, d_PB_uB, uB, track = false)
insert_edge!(graph, d_PA_uA, uA, track = false)
insert_edge!(graph, d_PBp_uBp, uBp, track = false)
insert_edge!(graph, d_PAp_uAp, uAp, track = false)
insert_edge!(graph, uB, d_uB_v0, false)
insert_edge!(graph, uA, d_uA_v0, false)
insert_edge!(graph, uBp, d_uBp_v1, false)
insert_edge!(graph, uAp, d_uAp_v1, false)
insert_edge!(graph, uB, d_uB_v0, track = false)
insert_edge!(graph, uA, d_uA_v0, track = false)
insert_edge!(graph, uBp, d_uBp_v1, track = false)
insert_edge!(graph, uAp, d_uAp_v1, track = false)
insert_edge!(graph, d_uB_v0, v0, false)
insert_edge!(graph, d_uA_v0, v0, false)
insert_edge!(graph, d_uBp_v1, v1, false)
insert_edge!(graph, d_uAp_v1, v1, false)
insert_edge!(graph, d_uB_v0, v0, track = false)
insert_edge!(graph, d_uA_v0, v0, track = false)
insert_edge!(graph, d_uBp_v1, v1, track = false)
insert_edge!(graph, d_uAp_v1, v1, track = false)
insert_edge!(graph, v0, d_v0_s0, false)
insert_edge!(graph, v1, d_v1_s0, false)
insert_edge!(graph, v0, d_v0_s0, track = false)
insert_edge!(graph, v1, d_v1_s0, track = false)
insert_edge!(graph, d_v0_s0, s0, false)
insert_edge!(graph, d_v1_s0, s0, false)
insert_edge!(graph, d_v0_s0, s0, track = false)
insert_edge!(graph, d_v1_s0, s0, track = false)
insert_edge!(graph, s0, d_exit, false)
insert_edge!(graph, s0, d_exit, track = false)
@test length(graph.nodes) == 26
@test length(graph.appliedOperations) == 0