Reimplement same code generation through new cache strategy interface
This commit is contained in:
parent
37d645cb4e
commit
dd01a5e691
@ -83,6 +83,7 @@ include("diff/type.jl")
|
||||
include("properties/type.jl")
|
||||
include("operation/type.jl")
|
||||
include("graph/type.jl")
|
||||
include("devices/interface.jl")
|
||||
|
||||
include("trie.jl")
|
||||
include("utility.jl")
|
||||
@ -116,6 +117,7 @@ include("properties/utility.jl")
|
||||
|
||||
include("task/create.jl")
|
||||
include("task/compare.jl")
|
||||
include("task/compute.jl")
|
||||
include("task/print.jl")
|
||||
include("task/properties.jl")
|
||||
|
||||
@ -130,7 +132,6 @@ include("models/abc/properties.jl")
|
||||
include("models/abc/parse.jl")
|
||||
include("models/abc/print.jl")
|
||||
|
||||
include("devices/interface.jl")
|
||||
include("devices/measure.jl")
|
||||
include("devices/detect.jl")
|
||||
include("devices/impl.jl")
|
||||
|
@ -11,7 +11,7 @@ Generate the code for a given graph. The return value is a named tuple of:
|
||||
|
||||
See also: [`execute`](@ref)
|
||||
"""
|
||||
function gen_code(graph::DAG)
|
||||
function gen_code(graph::DAG, machine::Machine)
|
||||
code = Vector{Expr}()
|
||||
sizehint!(code, length(graph.nodes))
|
||||
|
||||
@ -33,7 +33,7 @@ function gen_code(graph::DAG)
|
||||
@assert peek(nodeQueue)[2] == 0
|
||||
node = dequeue!(nodeQueue)
|
||||
|
||||
push!(code, get_expression(node))
|
||||
push!(code, get_expression(node, machine.devices[1]))
|
||||
for parent in node.parents
|
||||
# reduce the priority of all parents by one
|
||||
if (!haskey(nodeQueue, parent))
|
||||
@ -45,7 +45,7 @@ function gen_code(graph::DAG)
|
||||
end
|
||||
|
||||
# node is now the last node we looked at -> the output node
|
||||
outSym = Symbol("$(to_var_name(node.id))")
|
||||
outSym = Symbol(to_var_name(node.id))
|
||||
|
||||
return (code = Expr(:block, code...), inputSymbols = inputSyms, outputSymbol = outSym)
|
||||
end
|
||||
@ -93,20 +93,20 @@ end
|
||||
Return a function of signature `compute_<id>(input::AbstractProcessInput)`, which will return the result of the DAG computation on the given input.
|
||||
"""
|
||||
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
|
||||
(code, inputSymbols, outputSymbol) = gen_code(graph)
|
||||
(code, inputSymbols, outputSymbol) = gen_code(graph, machine)
|
||||
|
||||
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
|
||||
|
||||
# TODO generate correct access expression
|
||||
# TODO how to define cahce strategies?
|
||||
# TODO how to define cache strategies?
|
||||
device = machine.devices[1]
|
||||
|
||||
functionId = to_var_name(UUIDs.uuid1(rng[1]))
|
||||
func = eval(
|
||||
Meta.parse(
|
||||
"function compute_$(functionId)(input::AbstractProcessInput) $assignInputs; $code; return $(eval(gen_access_expr(device, default_strategy(device), outputSymbol))); end",
|
||||
),
|
||||
resSym = eval(gen_access_expr(device, default_strategy(device), outputSymbol))
|
||||
expr = Meta.parse(
|
||||
"function compute_$(functionId)(input::AbstractProcessInput) $assignInputs; $code; return $resSym; end",
|
||||
)
|
||||
func = eval(expr)
|
||||
|
||||
return func
|
||||
end
|
||||
|
@ -82,6 +82,6 @@ function gen_cache_init_code end
|
||||
gen_access_expr(device::AbstractDevice, strategy::CacheStrategy, symbol::Symbol)
|
||||
|
||||
Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one [`CacheStrategy`](@ref).
|
||||
Return an `Expr` accessing the variable identified by [`symbol`].
|
||||
Return an `Expr` or `QuoteNode` accessing the variable identified by [`symbol`].
|
||||
"""
|
||||
function gen_access_expr end
|
||||
|
@ -57,5 +57,6 @@ Generate code to access the variable designated by `symbol` using the [`LocalVar
|
||||
"""
|
||||
function gen_access_expr(::NumaNode, ::LocalVariables, symbol::Symbol)
|
||||
s = Symbol("data_$symbol")
|
||||
return Meta.parse(":($s)")
|
||||
quoteNode = Meta.parse(":($s)")
|
||||
return quoteNode
|
||||
end
|
||||
|
@ -5,7 +5,7 @@ A named tuple representing a difference of added and removed nodes and edges on
|
||||
"""
|
||||
const Diff = NamedTuple{
|
||||
(:addedNodes, :removedNodes, :addedEdges, :removedEdges, :updatedChildren),
|
||||
Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, String, String}}},
|
||||
Tuple{Vector{Node}, Vector{Node}, Vector{Edge}, Vector{Edge}, Vector{Tuple{Node, Symbol, Symbol}}},
|
||||
}
|
||||
|
||||
function Diff()
|
||||
@ -16,6 +16,6 @@ function Diff()
|
||||
removedEdges = Vector{Edge}(),
|
||||
|
||||
# children were updated from updatedChildren[2] to updatedChildren[3] in node updatedChildren[1]
|
||||
updatedChildren = Vector{Tuple{Node, String, String}}(),
|
||||
updatedChildren = Vector{Tuple{Node, Symbol, Symbol}}(),
|
||||
)::Diff
|
||||
end
|
||||
|
@ -175,7 +175,7 @@ function replace_children!(task::AbstractTask, before, after)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function update_child!(graph::DAG, n::Node, child_before::String, child_after::String; track = true)
|
||||
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
|
||||
# only need to update fused compute tasks
|
||||
if !(typeof(n.task) <: FusedComputeTask)
|
||||
return nothing
|
||||
|
@ -77,144 +77,78 @@ function compute(::ComputeTaskSum, data::Vector{Float64})
|
||||
end
|
||||
|
||||
"""
|
||||
compute(t::FusedComputeTask, data)
|
||||
get_expression(::ComputeTaskP, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
|
||||
|
||||
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
|
||||
Generate and return code evaluating [`ComputeTaskP`](@ref) on `inSyms`, providing the output on `outSym`.
|
||||
"""
|
||||
function compute(t::FusedComputeTask, data)
|
||||
@assert false "This is not implemented and should never be called"
|
||||
function get_expression(::ComputeTaskP, device::AbstractDevice, inExprs::Vector, outExpr)
|
||||
in = [eval(inExprs[1])]
|
||||
out = eval(outExpr)
|
||||
|
||||
return Meta.parse("$out = compute(ComputeTaskP(), $(in[1]))")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskP, inExprs::Vector{String}, outExpr::String)
|
||||
get_expression(::ComputeTaskU, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
|
||||
|
||||
Generate and return code evaluating [`ComputeTaskP`](@ref) on `inExpr`, providing the output on `outExpr`.
|
||||
Generate code evaluating [`ComputeTaskU`](@ref) on `inSyms`, providing the output on `outSym`.
|
||||
`inSyms` should be of type [`ParticleValue`](@ref), `outSym` will be of type [`ParticleValue`](@ref).
|
||||
"""
|
||||
function get_expression(::ComputeTaskP, inExprs::Vector{String}, outExpr::String)
|
||||
return Meta.parse("$outExpr = compute(ComputeTaskP(), $(inExprs[1]))")
|
||||
function get_expression(::ComputeTaskU, device::AbstractDevice, inExprs::Vector, outExpr)
|
||||
in = [eval(inExprs[1])]
|
||||
out = eval(outExpr)
|
||||
|
||||
return Meta.parse("$out = compute(ComputeTaskU(), $(in[1]))")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskU, inExprs::Vector{String}, outExpr::String)
|
||||
get_expression(::ComputeTaskV, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
|
||||
|
||||
Generate code evaluating [`ComputeTaskU`](@ref) on `inExpr`, providing the output on `outExpr`.
|
||||
`inExpr` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
|
||||
Generate code evaluating [`ComputeTaskV`](@ref) on `inSyms`, providing the output on `outSym`.
|
||||
`inSym[1]` and `inSym[2]` should be of type [`ParticleValue`](@ref), `outSym` will be of type [`ParticleValue`](@ref).
|
||||
"""
|
||||
function get_expression(::ComputeTaskU, inExprs::Vector{String}, outExpr::String)
|
||||
return Meta.parse("$outExpr = compute(ComputeTaskU(), $(inExprs[1]))")
|
||||
function get_expression(::ComputeTaskV, device::AbstractDevice, inExprs::Vector, outExpr)
|
||||
in = [eval(inExprs[1]), eval(inExprs[2])]
|
||||
out = eval(outExpr)
|
||||
|
||||
return Meta.parse("$out = compute(ComputeTaskV(), $(in[1]), $(in[2]))")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskV, inExprs::Vector{String}, outExpr::String)
|
||||
get_expression(::ComputeTaskS2, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
|
||||
|
||||
Generate code evaluating [`ComputeTaskV`](@ref) on `inExpr1` and `inExpr2`, providing the output on `outExpr`.
|
||||
`inExpr1` and `inExpr2` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
|
||||
Generate code evaluating [`ComputeTaskS2`](@ref) on `inSyms`, providing the output on `outSym`.
|
||||
`inSyms[1]` and `inSyms[2]` should be of type [`ParticleValue`](@ref), `outSym` will be of type `Float64`.
|
||||
"""
|
||||
function get_expression(::ComputeTaskV, inExprs::Vector{String}, outExpr::String)
|
||||
return Meta.parse("$outExpr = compute(ComputeTaskV(), $(inExprs[1]), $(inExprs[2]))")
|
||||
function get_expression(::ComputeTaskS2, device::AbstractDevice, inExprs::Vector, outExpr)
|
||||
in = [eval(inExprs[1]), eval(inExprs[2])]
|
||||
out = eval(outExpr)
|
||||
|
||||
return Meta.parse("$out = compute(ComputeTaskS2(), $(in[1]), $(in[2]))")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskS2, inExprs::Vector{String}, outExpr::String)
|
||||
get_expression(::ComputeTaskS1, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
|
||||
|
||||
Generate code evaluating [`ComputeTaskS2`](@ref) on `inExpr1` and `inExpr2`, providing the output on `outExpr`.
|
||||
`inExpr1` and `inExpr2` should be of type [`ParticleValue`](@ref), `outExpr` will be of type `Float64`.
|
||||
Generate code evaluating [`ComputeTaskS1`](@ref) on `inSyms`, providing the output on `outSym`.
|
||||
`inSyms` should be of type [`ParticleValue`](@ref), `outSym` will be of type [`ParticleValue`](@ref).
|
||||
"""
|
||||
function get_expression(::ComputeTaskS2, inExprs::Vector{String}, outExpr::String)
|
||||
return Meta.parse("$outExpr = compute(ComputeTaskS2(), $(inExprs[1]), $(inExprs[2]))")
|
||||
function get_expression(::ComputeTaskS1, device::AbstractDevice, inExprs::Vector, outExpr)
|
||||
in = [eval(inExprs[1])]
|
||||
out = eval(outExpr)
|
||||
|
||||
return Meta.parse("$out = compute(ComputeTaskS1(), $(in[1]))")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskS1, inExprs::Vector{String}, outExpr::String)
|
||||
get_expression(::ComputeTaskSum, device::AbstractDevice, inExprs::Vector{Expr}, outExpr::Expr)
|
||||
|
||||
Generate code evaluating [`ComputeTaskS1`](@ref) on `inExpr`, providing the output on `outExpr`.
|
||||
`inExpr` should be of type [`ParticleValue`](@ref), `outExpr` will be of type [`ParticleValue`](@ref).
|
||||
Generate code evaluating [`ComputeTaskSum`](@ref) on `inSyms`, providing the output on `outSym`.
|
||||
`inSyms` should be of type [`Float64`], `outSym` will be of type [`Float64`].
|
||||
"""
|
||||
function get_expression(::ComputeTaskS1, inExprs::Vector{String}, outExpr::String)
|
||||
return Meta.parse("$outExpr = compute(ComputeTaskS1(), $(inExprs[1]))")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(::ComputeTaskSum, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate code evaluating [`ComputeTaskSum`](@ref) on `inExprs`, providing the output on `outExpr`.
|
||||
`inExprs` should be of type [`Float64`], `outExpr` will be of type [`Float64`].
|
||||
"""
|
||||
function get_expression(::ComputeTaskSum, inExprs::Vector{String}, outExpr::String)
|
||||
return Meta.parse("$outExpr = compute(ComputeTaskSum(), [$(unroll_string_vector(inExprs))])")
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(t::FusedComputeTask, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing the output on `outExpr`.
|
||||
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
|
||||
"""
|
||||
function get_expression(t::FusedComputeTask, inExprs::Vector{String}, outExpr::String)
|
||||
c1 = length(t.t1_inputs)
|
||||
c2 = length(t.t2_inputs) + 1
|
||||
expr1 = nothing
|
||||
expr2 = nothing
|
||||
|
||||
expr1 = get_expression(t.first_task, t.t1_inputs, t.t1_output)
|
||||
expr2 = get_expression(t.second_task, [t.t2_inputs..., t.t1_output], outExpr)
|
||||
|
||||
full_expr = Expr(:block, expr1, expr2)
|
||||
|
||||
return full_expr
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(node::ComputeTaskNode)
|
||||
|
||||
Generate and return code for a given [`ComputeTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::ComputeTaskNode)
|
||||
t = typeof(node.task)
|
||||
@assert length(node.children) >= children(node.task) "Node $(node) has too few children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
|
||||
|
||||
if (t <: ComputeTaskU || t <: ComputeTaskP || t <: ComputeTaskS1) # single input
|
||||
symbolIn = "data_$(to_var_name(node.children[1].id))"
|
||||
symbolOut = "data_$(to_var_name(node.id))"
|
||||
return get_expression(node.task, [symbolIn], symbolOut)
|
||||
elseif (t <: ComputeTaskS2 || t <: ComputeTaskV) # double input
|
||||
symbolIn1 = "data_$(to_var_name(node.children[1].id))"
|
||||
symbolIn2 = "data_$(to_var_name(node.children[2].id))"
|
||||
symbolOut = "data_$(to_var_name(node.id))"
|
||||
return get_expression(node.task, [symbolIn1, symbolIn2], symbolOut)
|
||||
elseif (t <: ComputeTaskSum) # vector input
|
||||
inExprs = Vector{String}()
|
||||
for child in node.children
|
||||
push!(inExprs, "data_$(to_var_name(child.id))")
|
||||
end
|
||||
outExpr = "data_$(to_var_name(node.id))"
|
||||
return get_expression(node.task, inExprs, outExpr)
|
||||
elseif t <: FusedComputeTask # fused compute task knows its inputs
|
||||
outExpr = "data_$(to_var_name(node.id))"
|
||||
return get_expression(node.task, Vector{String}(), outExpr)
|
||||
else
|
||||
error("Unknown compute task")
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(node::DataTaskNode)
|
||||
|
||||
Generate and return code for a given [`DataTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::DataTaskNode)
|
||||
# TODO: do things to transport data from/to gpu, between numa nodes, etc.
|
||||
@assert length(node.children) <= 1
|
||||
|
||||
inExpr = nothing
|
||||
if (length(node.children) == 1)
|
||||
inExpr = "data_$(to_var_name(node.children[1].id))"
|
||||
else
|
||||
inExpr = "data_$(to_var_name(node.id))_in"
|
||||
end
|
||||
outExpr = "data_$(to_var_name(node.id))"
|
||||
|
||||
dataTransportExp = Meta.parse("$outExpr = $inExpr")
|
||||
|
||||
return dataTransportExp
|
||||
function get_expression(::ComputeTaskSum, device::AbstractDevice, inExprs::Vector, outExpr)
|
||||
in = eval.(inExprs)
|
||||
out = eval(outExpr)
|
||||
|
||||
return Meta.parse("$out = compute(ComputeTaskSum(), [$(unroll_symbol_vector(in))])")
|
||||
end
|
||||
|
@ -36,7 +36,7 @@ function is_valid_node(graph::DAG, node::Node)
|
||||
|
||||
# every child must be in some input of the task
|
||||
for child in node.children
|
||||
str = "data_$(to_var_name(child.id))"
|
||||
str = Symbol(to_var_name(child.id))
|
||||
@assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
|
||||
end
|
||||
|
||||
|
@ -172,18 +172,18 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
|
||||
remove_node!(graph, n3)
|
||||
|
||||
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
|
||||
n1_inputs = Vector{String}()
|
||||
n1_inputs = Vector{Symbol}()
|
||||
for child in n1_children
|
||||
push!(n1_inputs, "data_$(to_var_name(child.id))")
|
||||
push!(n1_inputs, Symbol(to_var_name(child.id)))
|
||||
end
|
||||
|
||||
n3_inputs = Vector{String}()
|
||||
n3_inputs = Vector{Symbol}()
|
||||
for child in n3_children
|
||||
push!(n3_inputs, "data_$(to_var_name(child.id))")
|
||||
push!(n3_inputs, Symbol(to_var_name(child.id)))
|
||||
end
|
||||
|
||||
# create new node with the fused compute task
|
||||
new_node = ComputeTaskNode(FusedComputeTask(n1.task, n3.task, n1_inputs, "data_$(to_var_name(n2.id))", n3_inputs))
|
||||
new_node = ComputeTaskNode(FusedComputeTask(n1.task, n3.task, n1_inputs, Symbol(to_var_name(n2.id)), n3_inputs))
|
||||
insert_node!(graph, new_node)
|
||||
|
||||
for child in n1_children
|
||||
@ -204,7 +204,7 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
|
||||
|
||||
# important! update the parent node's child names in case they are fused compute tasks
|
||||
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
|
||||
update_child!(graph, parent, "data_$(to_var_name(n3.id))", "data_$(to_var_name(new_node.id))")
|
||||
update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(new_node.id)))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
@ -231,7 +231,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
# set of the new parents of n1
|
||||
new_parents = Set{Node}()
|
||||
# names of the previous children that n1 now replaces per parent
|
||||
new_parents_child_names = Dict{Node, String}()
|
||||
new_parents_child_names = Dict{Node, Symbol}()
|
||||
|
||||
str = Vector{String}()
|
||||
for n in nodes
|
||||
@ -251,7 +251,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
# collect all parents
|
||||
push!(new_parents, parent)
|
||||
new_parents_child_names[parent] = "data_$(to_var_name(n.id))"
|
||||
new_parents_child_names[parent] = Symbol(to_var_name(n.id))
|
||||
end
|
||||
|
||||
remove_node!(graph, n)
|
||||
@ -264,7 +264,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
insert_edge!(graph, n1, parent)
|
||||
|
||||
prev_child = new_parents_child_names[parent]
|
||||
update_child!(graph, parent, prev_child, "data_$(to_var_name(n1.id))")
|
||||
update_child!(graph, parent, prev_child, Symbol(to_var_name(n1.id)))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
@ -304,7 +304,7 @@ function node_split!(graph::DAG, n1::Node)
|
||||
insert_edge!(graph, child, n_copy)
|
||||
end
|
||||
|
||||
update_child!(graph, parent, "data_$(to_var_name(n1.id))", "data_$(to_var_name(n_copy.id))")
|
||||
update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(n_copy.id)))
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
|
89
src/task/compute.jl
Normal file
89
src/task/compute.jl
Normal file
@ -0,0 +1,89 @@
|
||||
|
||||
"""
|
||||
compute(t::FusedComputeTask, data)
|
||||
|
||||
Compute a [`FusedComputeTask`](@ref). This simply asserts false and should not be called. Fused Compute Tasks generate their expressions directly through the other tasks instead.
|
||||
"""
|
||||
function compute(t::FusedComputeTask, data)
|
||||
@assert false "This is not implemented and should never be called"
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector{String}, outExpr::String)
|
||||
|
||||
Generate code evaluating a [`FusedComputeTask`](@ref) on `inExprs`, providing the output on `outExpr`.
|
||||
`inExprs` should be of the correct types and may be heterogeneous. `outExpr` will be of the type of the output of `T2` of t.
|
||||
"""
|
||||
function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Vector, outExpr)
|
||||
c1 = length(t.t1_inputs)
|
||||
c2 = length(t.t2_inputs) + 1
|
||||
expr1 = nothing
|
||||
expr2 = nothing
|
||||
|
||||
cacheStrategy = default_strategy(device)
|
||||
|
||||
inExprs1 = Vector()
|
||||
for sym in t.t1_inputs
|
||||
push!(inExprs1, gen_access_expr(device, cacheStrategy, sym))
|
||||
end
|
||||
|
||||
outExpr1 = gen_access_expr(device, cacheStrategy, t.t1_output)
|
||||
|
||||
inExprs2 = Vector()
|
||||
for sym in t.t2_inputs
|
||||
push!(inExprs2, gen_access_expr(device, cacheStrategy, sym))
|
||||
end
|
||||
|
||||
expr1 = get_expression(t.first_task, device, inExprs1, outExpr1)
|
||||
expr2 = get_expression(t.second_task, device, [inExprs2..., outExpr1], outExpr)
|
||||
|
||||
full_expr = Expr(:block, expr1, expr2)
|
||||
|
||||
return full_expr
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(node::ComputeTaskNode, device::AbstractDevice)
|
||||
|
||||
Generate and return code for a given [`ComputeTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::ComputeTaskNode, device::AbstractDevice)
|
||||
t = typeof(node.task)
|
||||
@assert length(node.children) >= children(node.task) "Node $(node) has too few children for its task: node has $(length(node.children)) versus task has $(children(node.task))\nNode's children: $(getfield.(node.children, :children))"
|
||||
|
||||
# TODO get device from the node
|
||||
cacheStrategy = default_strategy(device)
|
||||
|
||||
inExprs = Vector()
|
||||
for id in getfield.(node.children, :id)
|
||||
push!(inExprs, gen_access_expr(device, cacheStrategy, Symbol(to_var_name(id))))
|
||||
end
|
||||
outExpr = gen_access_expr(device, cacheStrategy, Symbol(to_var_name(node.id)))
|
||||
|
||||
return get_expression(node.task, device, inExprs, outExpr)
|
||||
end
|
||||
|
||||
"""
|
||||
get_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
|
||||
Generate and return code for a given [`DataTaskNode`](@ref).
|
||||
"""
|
||||
function get_expression(node::DataTaskNode, device::AbstractDevice)
|
||||
@assert length(node.children) <= 1
|
||||
|
||||
# TODO: do things to transport data from/to gpu, between numa nodes, etc.
|
||||
# TODO get device from the node
|
||||
|
||||
cacheStrategy = default_strategy(device)
|
||||
inExpr = nothing
|
||||
if (length(node.children) == 1)
|
||||
inExpr = eval(gen_access_expr(device, cacheStrategy, Symbol(to_var_name(node.children[1].id))))
|
||||
else
|
||||
inExpr = eval(gen_access_expr(device, cacheStrategy, Symbol("$(to_var_name(node.id))_in")))
|
||||
end
|
||||
outExpr = eval(gen_access_expr(device, cacheStrategy, Symbol(to_var_name(node.id))))
|
||||
|
||||
dataTransportExp = Meta.parse("$outExpr = $inExpr")
|
||||
|
||||
return dataTransportExp
|
||||
end
|
@ -30,9 +30,9 @@ struct FusedComputeTask{T1 <: AbstractComputeTask, T2 <: AbstractComputeTask} <:
|
||||
first_task::T1
|
||||
second_task::T2
|
||||
# the names of the inputs for T1
|
||||
t1_inputs::Vector{String}
|
||||
t1_inputs::Vector{Symbol}
|
||||
# output name of T1
|
||||
t1_output::String
|
||||
t1_output::Symbol
|
||||
# t2_inputs doesn't include the output of t1, that's implicit
|
||||
t2_inputs::Vector{String}
|
||||
t2_inputs::Vector{Symbol}
|
||||
end
|
||||
|
@ -89,17 +89,17 @@ function mem(node::Node)
|
||||
end
|
||||
|
||||
"""
|
||||
unroll_string_vector(vec::Vector{String})
|
||||
unroll_symbol_vector(vec::Vector{Symbol})
|
||||
|
||||
Return the given vector as single String without quotation marks or brackets.
|
||||
"""
|
||||
function unroll_string_vector(vec::Vector{String})
|
||||
function unroll_symbol_vector(vec::Vector)
|
||||
result = ""
|
||||
for s in vec
|
||||
if (result != "")
|
||||
result *= ", "
|
||||
end
|
||||
result *= s
|
||||
result *= "$s"
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user