Make FusedComputeTasks usable in execution
This commit is contained in:
parent
f8a591991c
commit
c428613c80
@ -4,7 +4,6 @@ authors = ["Anton Reinhard <anton.reinhard@proton.me>"]
|
||||
version = "0.1.0"
|
||||
|
||||
[deps]
|
||||
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
|
||||
AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753"
|
||||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
|
||||
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||
|
@ -1,6 +1,6 @@
|
||||
|
||||
function test_random_walk(g::DAG, n::Int64)
|
||||
# the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
|
||||
function random_walk!(g::DAG, n::Int64)
|
||||
# the purpose here is to do "random" operations on the graph to simulate an optimizer
|
||||
reset_graph!(g)
|
||||
|
||||
properties = get_properties(g)
|
||||
@ -32,7 +32,7 @@ function test_random_walk(g::DAG, n::Int64)
|
||||
end
|
||||
end
|
||||
|
||||
return reset_graph!(g)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function reduce_all!(g::DAG)
|
||||
|
@ -4,8 +4,14 @@
|
||||
A named tuple representing a difference of added and removed nodes and edges on a [`DAG`](@ref).
|
||||
"""
|
||||
const Diff = NamedTuple{
|
||||
(:addedNodes, :removedNodes, :addedEdges, :removedEdges),
|
||||
Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}},
|
||||
(:addedNodes, :removedNodes, :addedEdges, :removedEdges, :updatedChildren),
|
||||
Tuple{
|
||||
Vector{Node},
|
||||
Vector{Node},
|
||||
Vector{Edge},
|
||||
Vector{Edge},
|
||||
Vector{Tuple{Node, String, String}},
|
||||
},
|
||||
}
|
||||
|
||||
function Diff()
|
||||
@ -14,5 +20,8 @@ function Diff()
|
||||
removedNodes = Vector{Node}(),
|
||||
addedEdges = Vector{Edge}(),
|
||||
removedEdges = Vector{Edge}(),
|
||||
|
||||
# children were updated from updatedChildren[2] to updatedChildren[3] in node updatedChildren[1]
|
||||
updatedChildren = Vector{Tuple{Node, String, String}}(),
|
||||
)::Diff
|
||||
end
|
||||
|
@ -17,7 +17,7 @@ See also: [`remove_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function insert_node!(
|
||||
graph::DAG,
|
||||
node::Node,
|
||||
node::Node;
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
@ -53,7 +53,7 @@ See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
|
||||
function insert_edge!(
|
||||
graph::DAG,
|
||||
node1::Node,
|
||||
node2::Node,
|
||||
node2::Node;
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
@ -97,7 +97,7 @@ See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function remove_node!(
|
||||
graph::DAG,
|
||||
node::Node,
|
||||
node::Node;
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
@ -137,7 +137,7 @@ See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`insert_edge!`](@ref)
|
||||
function remove_edge!(
|
||||
graph::DAG,
|
||||
node1::Node,
|
||||
node2::Node,
|
||||
node2::Node;
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
@ -181,6 +181,27 @@ function remove_edge!(
|
||||
return nothing
|
||||
end
|
||||
|
||||
function update_child!(
|
||||
graph::DAG,
|
||||
n::Node,
|
||||
child_before::String,
|
||||
child_after::String;
|
||||
track = true,
|
||||
)
|
||||
# only need to update fused compute tasks
|
||||
if !(typeof(n.task) <: FusedComputeTask)
|
||||
return nothing
|
||||
end
|
||||
|
||||
replace!(n.task.t1_inputs, child_before => child_after)
|
||||
replace!(n.task.t2_inputs, child_before => child_after)
|
||||
|
||||
# keep track
|
||||
if (track)
|
||||
push!(graph.diff.updatedChildren, (n, child_before, child_after))
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
get_snapshot_diff(graph::DAG)
|
||||
|
||||
|
@ -80,125 +80,113 @@ function compute(t::FusedComputeTask, data)
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
|
||||
get_expression(::ComputeTaskP, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate and return code evaluating [`ComputeTaskP`](@ref) on `inSymbol`, providing the output on `outSymbol`.
|
||||
Generate and return code evaluating [`ComputeTaskP`](@ref) on `inExpr`, providing the output on `outExpr`.
|
||||
"""
|
||||
function get_expression(::ComputeTaskP, inSymbol::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskP(), $inSymbol)")
|
||||
function get_expression(
|
||||
::ComputeTaskP,
|
||||
inExprs::Vector{String},
|
||||
outExpr::String,
|
||||
)
|
||||
return Meta.parse("$outExpr = compute(ComputeTaskP(), $(inExprs[1]))")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
|
||||
get_expression(::ComputeTaskU, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate code evaluating [`ComputeTaskU`](@ref) on `inSymbol`, providing the output on `outSymbol`.
|
||||
`inSymbol` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
|
||||
Generate code evaluating [`ComputeTaskU`](@ref) on `inExpr`, providing the output on `outExpr`.
|
||||
`inExpr` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
|
||||
"""
|
||||
function get_expression(::ComputeTaskU, inSymbol::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskU(), $inSymbol)")
|
||||
function get_expression(
|
||||
::ComputeTaskU,
|
||||
inExprs::Vector{String},
|
||||
outExpr::String,
|
||||
)
|
||||
return Meta.parse("$outExpr = compute(ComputeTaskU(), $(inExprs[1]))")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskV, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
|
||||
get_expression(::ComputeTaskV, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate code evaluating [`ComputeTaskV`](@ref) on `inSymbol1` and `inSymbol2`, providing the output on `outSymbol`.
|
||||
`inSymbol1` and `inSymbol2` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
|
||||
Generate code evaluating [`ComputeTaskV`](@ref) on `inExpr1` and `inExpr2`, providing the output on `outExpr`.
|
||||
`inExpr1` and `inExpr2` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
|
||||
"""
|
||||
function get_expression(
|
||||
::ComputeTaskV,
|
||||
inSymbol1::Symbol,
|
||||
inSymbol2::Symbol,
|
||||
outSymbol::Symbol,
|
||||
inExprs::Vector{String},
|
||||
outExpr::String,
|
||||
)
|
||||
return Meta.parse(
|
||||
"$outSymbol = compute(ComputeTaskV(), $inSymbol1, $inSymbol2)",
|
||||
"$outExpr = compute(ComputeTaskV(), $(inExprs[1]), $(inExprs[2]))",
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskS2, inSymbol1::Symbol, inSymbol2::Symbol, outSymbol::Symbol)
|
||||
get_expression(::ComputeTaskS2, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate code evaluating [`ComputeTaskS2`](@ref) on `inSymbol1` and `inSymbol2`, providing the output on `outSymbol`.
|
||||
`inSymbol1` and `inSymbol2` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type `Float64`.
|
||||
Generate code evaluating [`ComputeTaskS2`](@ref) on `inExpr1` and `inExpr2`, providing the output on `outExpr`.
|
||||
`inExpr1` and `inExpr2` should be of type [`ParticleValue`](@ref), `outExpr` will be of type `Float64`.
|
||||
"""
|
||||
function get_expression(
|
||||
::ComputeTaskS2,
|
||||
inSymbol1::Symbol,
|
||||
inSymbol2::Symbol,
|
||||
outSymbol::Symbol,
|
||||
inExprs::Vector{String},
|
||||
outExpr::String,
|
||||
)
|
||||
return Meta.parse(
|
||||
"$outSymbol = compute(ComputeTaskS2(), $inSymbol1, $inSymbol2)",
|
||||
"$outExpr = compute(ComputeTaskS2(), $(inExprs[1]), $(inExprs[2]))",
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
|
||||
get_expression(::ComputeTaskS1, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate code evaluating [`ComputeTaskS1`](@ref) on `inSymbol`, providing the output on `outSymbol`.
|
||||
`inSymbol` should be of type [`ParticleValue`](@ref), `outSymbol` will be of type [`ParticleValue`](@ref).
|
||||
Generate code evaluating [`ComputeTaskS1`](@ref) on `inExpr`, providing the output on `outExpr`.
|
||||
`inExpr` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
|
||||
"""
|
||||
function get_expression(::ComputeTaskS1, inSymbol::Symbol, outSymbol::Symbol)
|
||||
return Meta.parse("$outSymbol = compute(ComputeTaskS1(), $inSymbol)")
|
||||
function get_expression(
|
||||
::ComputeTaskS1,
|
||||
inExprs::Vector{String},
|
||||
outExpr::String,
|
||||
)
|
||||
return Meta.parse("$outExpr = compute(ComputeTaskS1(), $(inExprs[1]))")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskSum, inSymbols::Vector{Symbol}, outSymbol::Symbol)
|
||||
get_expression(::ComputeTaskSum, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate code evaluating [`ComputeTaskSum`](@ref) on `inSymbols`, providing the output on `outSymbol`.
|
||||
`inSymbols` should be of type [`Float64`], `outSymbol` will be of type [`Float64`].
|
||||
Generate code evaluating [`ComputeTaskSum`](@ref) on `inExprs`, providing the output on `outExpr`.
|
||||
`inExprs` should be of type [`Float64`], `outExpr` will be of type [`Float64`].
|
||||
"""
|
||||
function get_expression(
|
||||
::ComputeTaskSum,
|
||||
inSymbols::Vector{Symbol},
|
||||
outSymbol::Symbol,
|
||||
inExprs::Vector{String},
|
||||
outExpr::String,
|
||||
)
|
||||
return quote
|
||||
$outSymbol = compute(ComputeTaskSum(), [$(inSymbols...)])
|
||||
end
|
||||
return Meta.parse(
|
||||
"$outExpr = compute(ComputeTaskSum(), [$(unroll_string_vector(inExprs))])",
|
||||
)
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(t::FusedComputeTask, inSymbols::Vector{Symbol}, outSymbol::Symbol)
|
||||
get_expression(t::FusedComputeTask, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate code evaluating a [`FusedComputeTask`](@ref) on `inSymbols`, providing the output on `outSymbol`.
|
||||
`inSymbols` should be of the correct types and may be heterogeneous. `outSymbol` will be of the type of the output of `T2` of t.
|
||||
Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing the output on `outExpr`.
|
||||
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
|
||||
"""
|
||||
function get_expression(
|
||||
t::FusedComputeTask,
|
||||
inSymbols::Vector{Symbol},
|
||||
outSymbol::Symbol,
|
||||
inExprs::Vector{String},
|
||||
outExpr::String,
|
||||
)
|
||||
(T1, T2) = get_types(t)
|
||||
c1 = children(T1())
|
||||
c2 = children(T2())
|
||||
|
||||
c1 = length(t.t1_inputs)
|
||||
c2 = length(t.t2_inputs) + 1
|
||||
expr1 = nothing
|
||||
expr2 = nothing
|
||||
|
||||
# TODO need to figure out how to know which inputs belong to which subtask
|
||||
# since we order the vectors with the child nodes we can't just split
|
||||
if (c1 == 1)
|
||||
expr1 = get_expression(T1(), inSymbols[begin], :intermediate)
|
||||
elseif (c1 == 2)
|
||||
expr1 =
|
||||
get_expression(T1(), inSymbols[begin], inSymbols[2], :intermediate)
|
||||
else
|
||||
expr1 = get_expression(T1(), inSymbols[begin:c1], :intermediate)
|
||||
end
|
||||
|
||||
if (c2 == 1)
|
||||
expr2 = get_expression(T2(), :intermediate, outSymbol)
|
||||
elseif c2 == 2
|
||||
expr2 =
|
||||
get_expression(T2(), :intermediate, inSymbols[c1 + 1], outSymbol)
|
||||
else
|
||||
expr2 = get_expression(
|
||||
T2(),
|
||||
:intermediate * inSymbols[(c1 + 1):end],
|
||||
outSymbol,
|
||||
)
|
||||
end
|
||||
expr1 = get_expression(t.first_task, t.t1_inputs, t.t1_output)
|
||||
expr2 =
|
||||
get_expression(t.second_task, [t.t2_inputs..., t.t1_output], outExpr)
|
||||
|
||||
return Expr(:block, expr1, expr2)
|
||||
end
|
||||
@ -210,24 +198,27 @@ Generate and return code for a given [`ComputeTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
t = typeof(node.task)
|
||||
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum
|
||||
@assert length(node.children) == children(node.task) || t <: ComputeTaskSum "Node $(node) has inconsistent number of children"
|
||||
|
||||
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
|
||||
symbolIn = Symbol("data_$(to_var_name(node.children[1].id))")
|
||||
symbolOut = Symbol("data_$(to_var_name(node.id))")
|
||||
return get_expression(t(), symbolIn, symbolOut)
|
||||
symbolIn = "data_$(to_var_name(node.children[1].id))"
|
||||
symbolOut = "data_$(to_var_name(node.id))"
|
||||
return get_expression(node.task, [symbolIn], symbolOut)
|
||||
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
|
||||
symbolIn1 = Symbol("data_$(to_var_name(node.children[1].id))")
|
||||
symbolIn2 = Symbol("data_$(to_var_name(node.children[2].id))")
|
||||
symbolOut = Symbol("data_$(to_var_name(node.id))")
|
||||
return get_expression(t(), symbolIn1, symbolIn2, symbolOut)
|
||||
elseif (t <: ComputeTaskSum || t <: FusedComputeTask) # vector input
|
||||
inSymbols = Vector{Symbol}()
|
||||
symbolIn1 = "data_$(to_var_name(node.children[1].id))"
|
||||
symbolIn2 = "data_$(to_var_name(node.children[2].id))"
|
||||
symbolOut = "data_$(to_var_name(node.id))"
|
||||
return get_expression(node.task, [symbolIn1, symbolIn2], symbolOut)
|
||||
elseif (t <: ComputeTaskSum) # vector input
|
||||
inExprs = Vector{String}()
|
||||
for child in node.children
|
||||
push!(inSymbols, Symbol("data_$(to_var_name(child.id))"))
|
||||
push!(inExprs, "data_$(to_var_name(child.id))")
|
||||
end
|
||||
outSymbol = Symbol("data_$(to_var_name(node.id))")
|
||||
return get_expression(t(), inSymbols, outSymbol)
|
||||
outExpr = "data_$(to_var_name(node.id))"
|
||||
return get_expression(node.task, inExprs, outExpr)
|
||||
elseif t <: FusedComputeTask # fused compute task knows its inputs
|
||||
outExpr = "data_$(to_var_name(node.id))"
|
||||
return get_expression(node.task, Vector{String}(), outExpr)
|
||||
else
|
||||
error("Unknown compute task")
|
||||
end
|
||||
@ -242,15 +233,15 @@ function get_expression(node::DataTaskNode)
|
||||
# TODO: do things to transport data from/to gpu, between numa nodes, etc.
|
||||
@assert length(node.children) <= 1
|
||||
|
||||
inSymbol = nothing
|
||||
inExpr = nothing
|
||||
if (length(node.children) == 1)
|
||||
inSymbol = Symbol("data_$(to_var_name(node.children[1].id))")
|
||||
inExpr = "data_$(to_var_name(node.children[1].id))"
|
||||
else
|
||||
inSymbol = Symbol("data_$(to_var_name(node.id))_in")
|
||||
inExpr = "data_$(to_var_name(node.id))_in"
|
||||
end
|
||||
outSymbol = Symbol("data_$(to_var_name(node.id))")
|
||||
outExpr = "data_$(to_var_name(node.id))"
|
||||
|
||||
dataTransportExp = Meta.parse("$outSymbol = $inSymbol")
|
||||
dataTransportExp = Meta.parse("$outExpr = $inExpr")
|
||||
|
||||
return dataTransportExp
|
||||
end
|
||||
|
@ -63,10 +63,25 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
end
|
||||
sizehint!(graph.nodes, estimate_no_nodes)
|
||||
|
||||
sum_node = insert_node!(graph, make_node(ComputeTaskSum()), false, false)
|
||||
global_data_out =
|
||||
insert_node!(graph, make_node(DataTask(FLOAT_SIZE)), false, false)
|
||||
insert_edge!(graph, sum_node, global_data_out, false, false)
|
||||
sum_node = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskSum()),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
global_data_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(FLOAT_SIZE)),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
sum_node,
|
||||
global_data_out,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
# remember the data out nodes for connection
|
||||
dataOutNodes = Dict()
|
||||
@ -93,30 +108,62 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
data_in = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE), string(node)),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
) # read particle data node
|
||||
compute_P =
|
||||
insert_node!(graph, make_node(ComputeTaskP()), false, false) # compute P node
|
||||
compute_P = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskP()),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
) # compute P node
|
||||
data_Pu = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE)),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
) # transfer data from P to u (one ParticleValue object)
|
||||
compute_u =
|
||||
insert_node!(graph, make_node(ComputeTaskU()), false, false) # compute U node
|
||||
compute_u = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskU()),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
) # compute U node
|
||||
data_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE)),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
) # transfer data out from u (one ParticleValue object)
|
||||
|
||||
insert_edge!(graph, data_in, compute_P, false, false)
|
||||
insert_edge!(graph, compute_P, data_Pu, false, false)
|
||||
insert_edge!(graph, data_Pu, compute_u, false, false)
|
||||
insert_edge!(graph, compute_u, data_out, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
data_in,
|
||||
compute_P,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
compute_P,
|
||||
data_Pu,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
data_Pu,
|
||||
compute_u,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
compute_u,
|
||||
data_out,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
# remember the data_out node for future edges
|
||||
dataOutNodes[node] = data_out
|
||||
@ -126,13 +173,17 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
in1 = capt.captures[1]
|
||||
in2 = capt.captures[2]
|
||||
|
||||
compute_v =
|
||||
insert_node!(graph, make_node(ComputeTaskV()), false, false)
|
||||
compute_v = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskV()),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
data_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE)),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
if (occursin(regex_c, in1))
|
||||
@ -140,22 +191,46 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
compute_S = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskS1()),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
data_S_v = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE)),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, dataOutNodes[in1], compute_S, false, false)
|
||||
insert_edge!(graph, compute_S, data_S_v, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
dataOutNodes[in1],
|
||||
compute_S,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
compute_S,
|
||||
data_S_v,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, data_S_v, compute_v, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
data_S_v,
|
||||
compute_v,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
else
|
||||
insert_edge!(graph, dataOutNodes[in1], compute_v, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
dataOutNodes[in1],
|
||||
compute_v,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
end
|
||||
|
||||
if (occursin(regex_c, in2))
|
||||
@ -164,25 +239,55 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
compute_S = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskS1()),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
data_S_v = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE)),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, dataOutNodes[in2], compute_S, false, false)
|
||||
insert_edge!(graph, compute_S, data_S_v, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
dataOutNodes[in2],
|
||||
compute_S,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
compute_S,
|
||||
data_S_v,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, data_S_v, compute_v, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
data_S_v,
|
||||
compute_v,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
else
|
||||
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
dataOutNodes[in2],
|
||||
compute_v,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
end
|
||||
|
||||
insert_edge!(graph, compute_v, data_out, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
compute_v,
|
||||
data_out,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
dataOutNodes[node] = data_out
|
||||
|
||||
elseif occursin(regex_m, node)
|
||||
@ -193,34 +298,84 @@ function parse_abc(filename::String, verbose::Bool = false)
|
||||
in3 = capt.captures[3]
|
||||
|
||||
# in2 + in3 with a v
|
||||
compute_v =
|
||||
insert_node!(graph, make_node(ComputeTaskV()), false, false)
|
||||
compute_v = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskV()),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
data_v = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(PARTICLE_VALUE_SIZE)),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, dataOutNodes[in2], compute_v, false, false)
|
||||
insert_edge!(graph, dataOutNodes[in3], compute_v, false, false)
|
||||
insert_edge!(graph, compute_v, data_v, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
dataOutNodes[in2],
|
||||
compute_v,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
dataOutNodes[in3],
|
||||
compute_v,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
compute_v,
|
||||
data_v,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
# combine with the v of the combined other input
|
||||
compute_S2 =
|
||||
insert_node!(graph, make_node(ComputeTaskS2()), false, false)
|
||||
compute_S2 = insert_node!(
|
||||
graph,
|
||||
make_node(ComputeTaskS2()),
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
data_out = insert_node!(
|
||||
graph,
|
||||
make_node(DataTask(FLOAT_SIZE)),
|
||||
false,
|
||||
false,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
) # output of a S2 task is only a float
|
||||
|
||||
insert_edge!(graph, data_v, compute_S2, false, false)
|
||||
insert_edge!(graph, dataOutNodes[in1], compute_S2, false, false)
|
||||
insert_edge!(graph, compute_S2, data_out, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
data_v,
|
||||
compute_S2,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
dataOutNodes[in1],
|
||||
compute_S2,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
insert_edge!(
|
||||
graph,
|
||||
compute_S2,
|
||||
data_out,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
|
||||
insert_edge!(graph, data_out, sum_node, false, false)
|
||||
insert_edge!(
|
||||
graph,
|
||||
data_out,
|
||||
sum_node,
|
||||
track = false,
|
||||
invalidate_cache = false,
|
||||
)
|
||||
elseif occursin(regex_plus, node)
|
||||
if (verbose)
|
||||
println("\rReading Nodes Complete ")
|
||||
|
@ -21,6 +21,7 @@ A struct describing a particle of the ABC-Model. It has the 4 momentum parts P0.
|
||||
`sizeof(Particle())` = 40 Byte
|
||||
"""
|
||||
struct Particle
|
||||
# SFourMomentum
|
||||
P0::Float64
|
||||
P1::Float64
|
||||
P2::Float64
|
||||
|
@ -157,9 +157,8 @@ children(::ComputeTaskSum) = -1
|
||||
"""
|
||||
children(t::FusedComputeTask)
|
||||
|
||||
Return the number of children of a FusedComputeTask. It's the sum of the children of both tasks minus one.
|
||||
Return the number of children of a FusedComputeTask.
|
||||
"""
|
||||
function children(t::FusedComputeTask)
|
||||
(T1, T2) = get_types(t)
|
||||
return children(T1()) + children(T2()) - 1 # one of the inputs is the output of T1 and thus not a child of the node
|
||||
return length(union(Set(t.t1_inputs), Set(t.t2_inputs)))
|
||||
end
|
||||
|
@ -124,17 +124,25 @@ function revert_diff!(graph::DAG, diff::Diff)
|
||||
# add removed nodes, remove added nodes, same for edges
|
||||
# note the order
|
||||
for edge in diff.addedEdges
|
||||
remove_edge!(graph, edge.edge[1], edge.edge[2], false)
|
||||
remove_edge!(graph, edge.edge[1], edge.edge[2], track = false)
|
||||
end
|
||||
for node in diff.addedNodes
|
||||
remove_node!(graph, node, false)
|
||||
remove_node!(graph, node, track = false)
|
||||
end
|
||||
|
||||
for node in diff.removedNodes
|
||||
insert_node!(graph, node, false)
|
||||
insert_node!(graph, node, track = false)
|
||||
end
|
||||
for edge in diff.removedEdges
|
||||
insert_edge!(graph, edge.edge[1], edge.edge[2], false)
|
||||
insert_edge!(graph, edge.edge[1], edge.edge[2], track = false)
|
||||
end
|
||||
|
||||
for (node, before, after) in diff.updatedChildren
|
||||
# node must be fused compute task at this point
|
||||
@assert typeof(node.task) <: FusedComputeTask
|
||||
|
||||
replace!(node.task.t1_inputs, after => before)
|
||||
replace!(node.task.t2_inputs, after => before)
|
||||
end
|
||||
|
||||
graph.properties -= GraphProperties(diff)
|
||||
@ -175,9 +183,27 @@ function node_fusion!(
|
||||
n3_children = children(n3)
|
||||
remove_node!(graph, n3)
|
||||
|
||||
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
|
||||
n1_inputs = Vector{String}()
|
||||
for child in n1_children
|
||||
push!(n1_inputs, "data_$(to_var_name(child.id))")
|
||||
end
|
||||
|
||||
n3_inputs = Vector{String}()
|
||||
for child in n3_children
|
||||
push!(n3_inputs, "data_$(to_var_name(child.id))")
|
||||
end
|
||||
|
||||
# create new node with the fused compute task
|
||||
new_node =
|
||||
ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
|
||||
new_node = ComputeTaskNode(
|
||||
FusedComputeTask(
|
||||
n1.task,
|
||||
n3.task,
|
||||
n1_inputs,
|
||||
"data_$(to_var_name(n2.id))",
|
||||
n3_inputs,
|
||||
),
|
||||
)
|
||||
insert_node!(graph, new_node)
|
||||
|
||||
for child in n1_children
|
||||
@ -195,6 +221,15 @@ function node_fusion!(
|
||||
for parent in n3_parents
|
||||
remove_edge!(graph, n3, parent)
|
||||
insert_edge!(graph, new_node, parent)
|
||||
|
||||
# important! update the parent node's child names in case they are fused compute tasks
|
||||
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
|
||||
update_child!(
|
||||
graph,
|
||||
parent,
|
||||
"data_$(to_var_name(n3.id))",
|
||||
"data_$(to_var_name(new_node.id))",
|
||||
)
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
@ -217,7 +252,9 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
n1_children = children(n1)
|
||||
|
||||
n1_parents = Set(n1.parents)
|
||||
new_parents = Set{Node}()
|
||||
|
||||
# set of the new parents of n1, together with the names of the previous children that n1 now replaces
|
||||
new_parents = Set{Tuple{Node, String}}()
|
||||
|
||||
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
|
||||
for i in 2:length(nodes)
|
||||
@ -230,7 +267,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
remove_edge!(graph, n, parent)
|
||||
|
||||
# collect all parents
|
||||
push!(new_parents, parent)
|
||||
push!(new_parents, (parent, "data_$(to_var_name(n.id))"))
|
||||
end
|
||||
|
||||
remove_node!(graph, n)
|
||||
@ -238,9 +275,11 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
setdiff!(new_parents, n1_parents)
|
||||
|
||||
for parent in new_parents
|
||||
for (parent, prev_child) in new_parents
|
||||
# now add parents of all input nodes to n1 without duplicates
|
||||
insert_edge!(graph, n1, parent)
|
||||
|
||||
update_child!(graph, parent, prev_child, "data_$(to_var_name(n1.id))")
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
@ -275,6 +314,13 @@ function node_split!(graph::DAG, n1::Node)
|
||||
insert_node!(graph, n_copy)
|
||||
insert_edge!(graph, n_copy, parent)
|
||||
|
||||
update_child!(
|
||||
graph,
|
||||
parent,
|
||||
"data_$(to_var_name(n1.id))",
|
||||
"data_$(to_var_name(n_copy.id))",
|
||||
)
|
||||
|
||||
for child in n1_children
|
||||
insert_edge!(graph, child, n_copy)
|
||||
end
|
||||
|
@ -12,3 +12,25 @@ copy(t::AbstractDataTask) =
|
||||
Return a copy of the given compute task.
|
||||
"""
|
||||
copy(t::AbstractComputeTask) = typeof(t)()
|
||||
|
||||
"""
|
||||
copy(t::FusedComputeTask)
|
||||
|
||||
Return a copy of th egiven [`FusedComputeTask`](@ref).
|
||||
"""
|
||||
function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
|
||||
return FusedComputeTask{T1, T2}(
|
||||
t.first_task,
|
||||
t.second_task,
|
||||
t.t1_inputs,
|
||||
t.t1_output,
|
||||
t.t2_inputs,
|
||||
)
|
||||
end
|
||||
|
||||
FusedComputeTask{T1, T2}(
|
||||
t1_inputs::Vector{String},
|
||||
t1_output::String,
|
||||
t2_inputs::Vector{String},
|
||||
) where {T1, T2} =
|
||||
FusedComputeTask{T1, T2}(T1(), T2(), t1_inputs, t1_output, t2_inputs)
|
||||
|
@ -71,8 +71,7 @@ data(t::AbstractComputeTask) = 0
|
||||
Return the compute effort of a fused compute task.
|
||||
"""
|
||||
function compute_effort(t::FusedComputeTask)
|
||||
(T1, T2) = collect(typeof(t).parameters)
|
||||
return compute_effort(T1()) + compute_effort(T2())
|
||||
return compute_effort(t.first_task) + compute_effort(t.second_task)
|
||||
end
|
||||
|
||||
"""
|
||||
@ -81,30 +80,3 @@ end
|
||||
Return a tuple of a the fused compute task's components' types.
|
||||
"""
|
||||
get_types(::FusedComputeTask{T1, T2}) where {T1, T2} = (T1, T2)
|
||||
|
||||
"""
|
||||
get_expression(t::AbstractTask)
|
||||
|
||||
Return an expression evaluating the given task on the :dataIn symbol
|
||||
"""
|
||||
function get_expression(t::AbstractTask)
|
||||
return quote
|
||||
dataOut = compute($t, dataIn)
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression()
|
||||
"""
|
||||
function get_expression(
|
||||
t::FusedComputeTask,
|
||||
inSymbol::Symbol,
|
||||
outSymbol::Symbol,
|
||||
)
|
||||
#TODO
|
||||
computeExp = quote
|
||||
$outSymbol = compute($t, $inSymbol)
|
||||
end
|
||||
|
||||
return computeExp
|
||||
end
|
||||
|
@ -27,4 +27,13 @@ A fused compute task made up of the computation of first `T1` and then `T2`.
|
||||
Also see: [`get_types`](@ref).
|
||||
"""
|
||||
struct FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <:
|
||||
AbstractComputeTask end
|
||||
AbstractComputeTask
|
||||
first_task::T1
|
||||
second_task::T2
|
||||
# the names of the inputs for T1
|
||||
t1_inputs::Vector{String}
|
||||
# output name of T1
|
||||
t1_output::String
|
||||
# t2_inputs doesn't include the output of t1, that's implicit
|
||||
t2_inputs::Vector{String}
|
||||
end
|
||||
|
@ -87,3 +87,19 @@ Return the memory footprint of the node in Byte. Used in [`mem(graph::DAG)`](@re
|
||||
function mem(node::Node)
|
||||
return Base.summarysize(node, exclude = Union{Node, Operation})
|
||||
end
|
||||
|
||||
"""
|
||||
unroll_string_vector(vec::Vector{String})
|
||||
|
||||
Return the given vector as single String without quotation marks or brackets.
|
||||
"""
|
||||
function unroll_string_vector(vec::Vector{String})
|
||||
result = ""
|
||||
for s in vec
|
||||
if (result != "")
|
||||
result *= ", "
|
||||
end
|
||||
result *= s
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
@ -5,51 +5,51 @@ import MetagraphOptimization.make_node
|
||||
@testset "Unit Tests Node Reduction" begin
|
||||
graph = MetagraphOptimization.DAG()
|
||||
|
||||
d_exit = insert_node!(graph, make_node(DataTask(10)), false)
|
||||
d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
|
||||
|
||||
s0 = insert_node!(graph, make_node(ComputeTaskS2()), false)
|
||||
s0 = insert_node!(graph, make_node(ComputeTaskS2()), track = false)
|
||||
|
||||
ED = insert_node!(graph, make_node(DataTask(3)), false)
|
||||
FD = insert_node!(graph, make_node(DataTask(3)), false)
|
||||
ED = insert_node!(graph, make_node(DataTask(3)), track = false)
|
||||
FD = insert_node!(graph, make_node(DataTask(3)), track = false)
|
||||
|
||||
EC = insert_node!(graph, make_node(ComputeTaskV()), false)
|
||||
FC = insert_node!(graph, make_node(ComputeTaskV()), false)
|
||||
EC = insert_node!(graph, make_node(ComputeTaskV()), track = false)
|
||||
FC = insert_node!(graph, make_node(ComputeTaskV()), track = false)
|
||||
|
||||
A1D = insert_node!(graph, make_node(DataTask(4)), false)
|
||||
B1D_1 = insert_node!(graph, make_node(DataTask(4)), false)
|
||||
B1D_2 = insert_node!(graph, make_node(DataTask(4)), false)
|
||||
C1D = insert_node!(graph, make_node(DataTask(4)), false)
|
||||
A1D = insert_node!(graph, make_node(DataTask(4)), track = false)
|
||||
B1D_1 = insert_node!(graph, make_node(DataTask(4)), track = false)
|
||||
B1D_2 = insert_node!(graph, make_node(DataTask(4)), track = false)
|
||||
C1D = insert_node!(graph, make_node(DataTask(4)), track = false)
|
||||
|
||||
A1C = insert_node!(graph, make_node(ComputeTaskU()), false)
|
||||
B1C_1 = insert_node!(graph, make_node(ComputeTaskU()), false)
|
||||
B1C_2 = insert_node!(graph, make_node(ComputeTaskU()), false)
|
||||
C1C = insert_node!(graph, make_node(ComputeTaskU()), false)
|
||||
A1C = insert_node!(graph, make_node(ComputeTaskU()), track = false)
|
||||
B1C_1 = insert_node!(graph, make_node(ComputeTaskU()), track = false)
|
||||
B1C_2 = insert_node!(graph, make_node(ComputeTaskU()), track = false)
|
||||
C1C = insert_node!(graph, make_node(ComputeTaskU()), track = false)
|
||||
|
||||
AD = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
BD = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
CD = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
AD = insert_node!(graph, make_node(DataTask(5)), track = false)
|
||||
BD = insert_node!(graph, make_node(DataTask(5)), track = false)
|
||||
CD = insert_node!(graph, make_node(DataTask(5)), track = false)
|
||||
|
||||
insert_edge!(graph, s0, d_exit, false)
|
||||
insert_edge!(graph, ED, s0, false)
|
||||
insert_edge!(graph, FD, s0, false)
|
||||
insert_edge!(graph, EC, ED, false)
|
||||
insert_edge!(graph, FC, FD, false)
|
||||
insert_edge!(graph, s0, d_exit, track = false)
|
||||
insert_edge!(graph, ED, s0, track = false)
|
||||
insert_edge!(graph, FD, s0, track = false)
|
||||
insert_edge!(graph, EC, ED, track = false)
|
||||
insert_edge!(graph, FC, FD, track = false)
|
||||
|
||||
insert_edge!(graph, A1D, EC, false)
|
||||
insert_edge!(graph, B1D_1, EC, false)
|
||||
insert_edge!(graph, A1D, EC, track = false)
|
||||
insert_edge!(graph, B1D_1, EC, track = false)
|
||||
|
||||
insert_edge!(graph, B1D_2, FC, false)
|
||||
insert_edge!(graph, C1D, FC, false)
|
||||
insert_edge!(graph, B1D_2, FC, track = false)
|
||||
insert_edge!(graph, C1D, FC, track = false)
|
||||
|
||||
insert_edge!(graph, A1C, A1D, false)
|
||||
insert_edge!(graph, B1C_1, B1D_1, false)
|
||||
insert_edge!(graph, B1C_2, B1D_2, false)
|
||||
insert_edge!(graph, C1C, C1D, false)
|
||||
insert_edge!(graph, A1C, A1D, track = false)
|
||||
insert_edge!(graph, B1C_1, B1D_1, track = false)
|
||||
insert_edge!(graph, B1C_2, B1D_2, track = false)
|
||||
insert_edge!(graph, C1C, C1D, track = false)
|
||||
|
||||
insert_edge!(graph, AD, A1C, false)
|
||||
insert_edge!(graph, BD, B1C_1, false)
|
||||
insert_edge!(graph, BD, B1C_2, false)
|
||||
insert_edge!(graph, CD, C1C, false)
|
||||
insert_edge!(graph, AD, A1C, track = false)
|
||||
insert_edge!(graph, BD, B1C_1, track = false)
|
||||
insert_edge!(graph, BD, B1C_2, track = false)
|
||||
insert_edge!(graph, CD, C1C, track = false)
|
||||
|
||||
@test is_valid(graph)
|
||||
|
||||
|
@ -2,7 +2,9 @@ import MetagraphOptimization.A
|
||||
import MetagraphOptimization.B
|
||||
import MetagraphOptimization.ParticleType
|
||||
|
||||
@testset "Unit Tests Graph" begin
|
||||
include("../examples/profiling_utilities.jl")
|
||||
|
||||
@testset "Unit Tests Execution" begin
|
||||
particles = Dict{ParticleType, Vector{Particle}}(
|
||||
(
|
||||
A => [
|
||||
@ -20,12 +22,35 @@ import MetagraphOptimization.ParticleType
|
||||
|
||||
expected_result = 5.5320567694746876e-5
|
||||
|
||||
for _ in 1:10 # test in a loop because graph layout should not change the result
|
||||
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
|
||||
@test isapprox(execute(graph, particles), expected_result; rtol = 0.001)
|
||||
@testset "AB->AB no optimization" begin
|
||||
for _ in 1:10 # test in a loop because graph layout should not change the result
|
||||
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
|
||||
@test isapprox(
|
||||
execute(graph, particles),
|
||||
expected_result;
|
||||
rtol = 0.001,
|
||||
)
|
||||
|
||||
code = MetagraphOptimization.gen_code(graph)
|
||||
@test isapprox(execute(code, particles), expected_result; rtol = 0.001)
|
||||
code = MetagraphOptimization.gen_code(graph)
|
||||
@test isapprox(
|
||||
execute(code, particles),
|
||||
expected_result;
|
||||
rtol = 0.001,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "AB->AB after random walk" begin
|
||||
for _ in 1:20
|
||||
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
|
||||
random_walk!(graph, 40)
|
||||
|
||||
@test isapprox(
|
||||
execute(graph, particles),
|
||||
expected_result;
|
||||
rtol = 0.001,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
println("Execution Unit Tests Complete!")
|
||||
|
@ -17,91 +17,91 @@ import MetagraphOptimization.partners
|
||||
(nodeFusions = 0, nodeReductions = 0, nodeSplits = 0)
|
||||
|
||||
# s to output (exit node)
|
||||
d_exit = insert_node!(graph, make_node(DataTask(10)), false)
|
||||
d_exit = insert_node!(graph, make_node(DataTask(10)), track = false)
|
||||
|
||||
@test length(graph.nodes) == 1
|
||||
@test length(graph.dirtyNodes) == 1
|
||||
|
||||
# final s compute
|
||||
s0 = insert_node!(graph, make_node(ComputeTaskS2()), false)
|
||||
s0 = insert_node!(graph, make_node(ComputeTaskS2()), track = false)
|
||||
|
||||
@test length(graph.nodes) == 2
|
||||
@test length(graph.dirtyNodes) == 2
|
||||
|
||||
# data from v0 and v1 to s0
|
||||
d_v0_s0 = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
d_v1_s0 = insert_node!(graph, make_node(DataTask(5)), false)
|
||||
d_v0_s0 = insert_node!(graph, make_node(DataTask(5)), track = false)
|
||||
d_v1_s0 = insert_node!(graph, make_node(DataTask(5)), track = false)
|
||||
|
||||
# v0 and v1 compute
|
||||
v0 = insert_node!(graph, make_node(ComputeTaskV()), false)
|
||||
v1 = insert_node!(graph, make_node(ComputeTaskV()), false)
|
||||
v0 = insert_node!(graph, make_node(ComputeTaskV()), track = false)
|
||||
v1 = insert_node!(graph, make_node(ComputeTaskV()), track = false)
|
||||
|
||||
# data from uB, uA, uBp and uAp to v0 and v1
|
||||
d_uB_v0 = insert_node!(graph, make_node(DataTask(3)), false)
|
||||
d_uA_v0 = insert_node!(graph, make_node(DataTask(3)), false)
|
||||
d_uBp_v1 = insert_node!(graph, make_node(DataTask(3)), false)
|
||||
d_uAp_v1 = insert_node!(graph, make_node(DataTask(3)), false)
|
||||
d_uB_v0 = insert_node!(graph, make_node(DataTask(3)), track = false)
|
||||
d_uA_v0 = insert_node!(graph, make_node(DataTask(3)), track = false)
|
||||
d_uBp_v1 = insert_node!(graph, make_node(DataTask(3)), track = false)
|
||||
d_uAp_v1 = insert_node!(graph, make_node(DataTask(3)), track = false)
|
||||
|
||||
# uB, uA, uBp and uAp computes
|
||||
uB = insert_node!(graph, make_node(ComputeTaskU()), false)
|
||||
uA = insert_node!(graph, make_node(ComputeTaskU()), false)
|
||||
uBp = insert_node!(graph, make_node(ComputeTaskU()), false)
|
||||
uAp = insert_node!(graph, make_node(ComputeTaskU()), false)
|
||||
uB = insert_node!(graph, make_node(ComputeTaskU()), track = false)
|
||||
uA = insert_node!(graph, make_node(ComputeTaskU()), track = false)
|
||||
uBp = insert_node!(graph, make_node(ComputeTaskU()), track = false)
|
||||
uAp = insert_node!(graph, make_node(ComputeTaskU()), track = false)
|
||||
|
||||
# data from PB, PA, PBp and PAp to uB, uA, uBp and uAp
|
||||
d_PB_uB = insert_node!(graph, make_node(DataTask(6)), false)
|
||||
d_PA_uA = insert_node!(graph, make_node(DataTask(6)), false)
|
||||
d_PBp_uBp = insert_node!(graph, make_node(DataTask(6)), false)
|
||||
d_PAp_uAp = insert_node!(graph, make_node(DataTask(6)), false)
|
||||
d_PB_uB = insert_node!(graph, make_node(DataTask(6)), track = false)
|
||||
d_PA_uA = insert_node!(graph, make_node(DataTask(6)), track = false)
|
||||
d_PBp_uBp = insert_node!(graph, make_node(DataTask(6)), track = false)
|
||||
d_PAp_uAp = insert_node!(graph, make_node(DataTask(6)), track = false)
|
||||
|
||||
# P computes PB, PA, PBp and PAp
|
||||
PB = insert_node!(graph, make_node(ComputeTaskP()), false)
|
||||
PA = insert_node!(graph, make_node(ComputeTaskP()), false)
|
||||
PBp = insert_node!(graph, make_node(ComputeTaskP()), false)
|
||||
PAp = insert_node!(graph, make_node(ComputeTaskP()), false)
|
||||
PB = insert_node!(graph, make_node(ComputeTaskP()), track = false)
|
||||
PA = insert_node!(graph, make_node(ComputeTaskP()), track = false)
|
||||
PBp = insert_node!(graph, make_node(ComputeTaskP()), track = false)
|
||||
PAp = insert_node!(graph, make_node(ComputeTaskP()), track = false)
|
||||
|
||||
# entry nodes getting data for P computes
|
||||
d_PB = insert_node!(graph, make_node(DataTask(4)), false)
|
||||
d_PA = insert_node!(graph, make_node(DataTask(4)), false)
|
||||
d_PBp = insert_node!(graph, make_node(DataTask(4)), false)
|
||||
d_PAp = insert_node!(graph, make_node(DataTask(4)), false)
|
||||
d_PB = insert_node!(graph, make_node(DataTask(4)), track = false)
|
||||
d_PA = insert_node!(graph, make_node(DataTask(4)), track = false)
|
||||
d_PBp = insert_node!(graph, make_node(DataTask(4)), track = false)
|
||||
d_PAp = insert_node!(graph, make_node(DataTask(4)), track = false)
|
||||
|
||||
@test length(graph.nodes) == 26
|
||||
@test length(graph.dirtyNodes) == 26
|
||||
|
||||
# now for all the edges
|
||||
insert_edge!(graph, d_PB, PB, false)
|
||||
insert_edge!(graph, d_PA, PA, false)
|
||||
insert_edge!(graph, d_PBp, PBp, false)
|
||||
insert_edge!(graph, d_PAp, PAp, false)
|
||||
insert_edge!(graph, d_PB, PB, track = false)
|
||||
insert_edge!(graph, d_PA, PA, track = false)
|
||||
insert_edge!(graph, d_PBp, PBp, track = false)
|
||||
insert_edge!(graph, d_PAp, PAp, track = false)
|
||||
|
||||
insert_edge!(graph, PB, d_PB_uB, false)
|
||||
insert_edge!(graph, PA, d_PA_uA, false)
|
||||
insert_edge!(graph, PBp, d_PBp_uBp, false)
|
||||
insert_edge!(graph, PAp, d_PAp_uAp, false)
|
||||
insert_edge!(graph, PB, d_PB_uB, track = false)
|
||||
insert_edge!(graph, PA, d_PA_uA, track = false)
|
||||
insert_edge!(graph, PBp, d_PBp_uBp, track = false)
|
||||
insert_edge!(graph, PAp, d_PAp_uAp, track = false)
|
||||
|
||||
insert_edge!(graph, d_PB_uB, uB, false)
|
||||
insert_edge!(graph, d_PA_uA, uA, false)
|
||||
insert_edge!(graph, d_PBp_uBp, uBp, false)
|
||||
insert_edge!(graph, d_PAp_uAp, uAp, false)
|
||||
insert_edge!(graph, d_PB_uB, uB, track = false)
|
||||
insert_edge!(graph, d_PA_uA, uA, track = false)
|
||||
insert_edge!(graph, d_PBp_uBp, uBp, track = false)
|
||||
insert_edge!(graph, d_PAp_uAp, uAp, track = false)
|
||||
|
||||
insert_edge!(graph, uB, d_uB_v0, false)
|
||||
insert_edge!(graph, uA, d_uA_v0, false)
|
||||
insert_edge!(graph, uBp, d_uBp_v1, false)
|
||||
insert_edge!(graph, uAp, d_uAp_v1, false)
|
||||
insert_edge!(graph, uB, d_uB_v0, track = false)
|
||||
insert_edge!(graph, uA, d_uA_v0, track = false)
|
||||
insert_edge!(graph, uBp, d_uBp_v1, track = false)
|
||||
insert_edge!(graph, uAp, d_uAp_v1, track = false)
|
||||
|
||||
insert_edge!(graph, d_uB_v0, v0, false)
|
||||
insert_edge!(graph, d_uA_v0, v0, false)
|
||||
insert_edge!(graph, d_uBp_v1, v1, false)
|
||||
insert_edge!(graph, d_uAp_v1, v1, false)
|
||||
insert_edge!(graph, d_uB_v0, v0, track = false)
|
||||
insert_edge!(graph, d_uA_v0, v0, track = false)
|
||||
insert_edge!(graph, d_uBp_v1, v1, track = false)
|
||||
insert_edge!(graph, d_uAp_v1, v1, track = false)
|
||||
|
||||
insert_edge!(graph, v0, d_v0_s0, false)
|
||||
insert_edge!(graph, v1, d_v1_s0, false)
|
||||
insert_edge!(graph, v0, d_v0_s0, track = false)
|
||||
insert_edge!(graph, v1, d_v1_s0, track = false)
|
||||
|
||||
insert_edge!(graph, d_v0_s0, s0, false)
|
||||
insert_edge!(graph, d_v1_s0, s0, false)
|
||||
insert_edge!(graph, d_v0_s0, s0, track = false)
|
||||
insert_edge!(graph, d_v1_s0, s0, track = false)
|
||||
|
||||
insert_edge!(graph, s0, d_exit, false)
|
||||
insert_edge!(graph, s0, d_exit, track = false)
|
||||
|
||||
@test length(graph.nodes) == 26
|
||||
@test length(graph.appliedOperations) == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user