This commit is contained in:
2023-10-04 11:05:49 +02:00
parent f9e60a7b5e
commit cbfed20b82
13 changed files with 801 additions and 494 deletions

View File

@@ -12,9 +12,6 @@ Generate the code for a given graph. The return value is a named tuple of:
See also: [`execute`](@ref)
"""
function gen_code(graph::DAG, machine::Machine)
code = Vector{Expr}()
sizehint!(code, length(graph.nodes))
nodeQueue = PriorityQueue{Node, Int}()
inputSyms = Dict{String, Vector{Symbol}}()
@@ -28,12 +25,16 @@ function gen_code(graph::DAG, machine::Machine)
push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
end
schedule = Vector{Node}()
sizehint!(schedule, length(graph.nodes))
# "scheduling"
node = nothing
while !isempty(nodeQueue)
@assert peek(nodeQueue)[2] == 0
node = dequeue!(nodeQueue)
push!(code, get_expression(node, machine.devices[1]))
push!(schedule, node)
for parent in node.parents
# reduce the priority of all parents by one
if (!haskey(nodeQueue, parent))
@@ -44,16 +45,27 @@ function gen_code(graph::DAG, machine::Machine)
end
end
codeAcc = Vector{Expr}()
sizehint!(codeAcc, length(graph.nodes))
for node in schedule
push!(codeAcc, get_expression(node, machine.devices[1]))
end
# node is now the last node we looked at -> the output node
outSym = Symbol(to_var_name(node.id))
return (code = Expr(:block, code...), inputSymbols = inputSyms, outputSymbol = outSym)
return (code = Expr(:block, codeAcc...), inputSymbols = inputSyms, outputSymbol = outSym)
end
function gen_cache_init_code(machine::Machine)
initializeCaches = Vector{Expr}()
return initializeCaches
for device in machine.devices
push!(initializeCaches, gen_cache_init_code(device))
end
return Expr(:block, initializeCaches...)
end
function gen_input_assignment_code(
@@ -85,7 +97,7 @@ function gen_input_assignment_code(
# TODO generate correct access expression
# TODO how to define cahce strategies?
device = machine.devices[1]
evalExpr = eval(gen_access_expr(device, cache_strategy(device), symbol))
evalExpr = eval(gen_access_expr(device, symbol))
push!(assignInputs, Meta.parse("$(evalExpr) = ParticleValue($p, 1.0)"))
end
end
@@ -101,6 +113,7 @@ Return a function of signature `compute_<id>(input::AbstractProcessInput)`, whic
function get_compute_function(graph::DAG, process::AbstractProcessDescription, machine::Machine)
(code, inputSymbols, outputSymbol) = gen_code(graph, machine)
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
# TODO generate correct access expression
@@ -108,9 +121,9 @@ function get_compute_function(graph::DAG, process::AbstractProcessDescription, m
device = machine.devices[1]
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(device, cache_strategy(device), outputSymbol))
resSym = eval(gen_access_expr(device, outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(input::AbstractProcessInput) $assignInputs; $code; return $resSym; end",
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
)
func = eval(expr)
@@ -131,7 +144,21 @@ This is essentially shorthand for
See also: [`parse_dag`](@ref), [`parse_process`](@ref), [`gen_process_input`](@ref)
"""
function execute(graph::DAG, process::AbstractProcessDescription, machine::Machine, input::AbstractProcessInput)
func = get_compute_function(graph, process, machine)
(code, inputSymbols, outputSymbol) = gen_code(graph, machine)
initCaches = gen_cache_init_code(machine)
assignInputs = gen_input_assignment_code(inputSymbols, process, machine, :input)
# TODO generate correct access expression
# TODO how to define cache strategies?
device = machine.devices[1]
functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(device, outputSymbol))
expr = Meta.parse(
"function compute_$(functionId)(input::AbstractProcessInput) $initCaches; $assignInputs; $code; return $resSym; end",
)
func = eval(expr)
result = 0
try
@@ -139,7 +166,7 @@ function execute(graph::DAG, process::AbstractProcessDescription, machine::Machi
catch e
println("Error while evaluating: $e")
println("Function: $func")
println("Function:\n$expr")
@assert false
end

View File

@@ -92,7 +92,7 @@ Interface function that must be implemented for every subtype of [`AbstractDevic
function measure_device! end
"""
gen_cache_init_code(device::AbstractDevice, strategy::CacheStrategy)
gen_cache_init_code(device::AbstractDevice)
Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one [`CacheStrategy`](@ref). Returns an `Expr` initializing this device's variable cache.
@@ -101,7 +101,7 @@ The strategy is a symbol
function gen_cache_init_code end
"""
gen_access_expr(device::AbstractDevice, strategy::CacheStrategy, symbol::Symbol)
gen_access_expr(device::AbstractDevice, symbol::Symbol)
Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one [`CacheStrategy`](@ref).
Return an `Expr` or `QuoteNode` accessing the variable identified by [`symbol`].

View File

@@ -10,6 +10,7 @@ mutable struct NumaNode <: AbstractCPU
threads::UInt16
cacheStrategy::CacheStrategy
FLOPS::Float64
id::UUID
end
push!(DEVICE_TYPES, NumaNode)
@@ -40,29 +41,44 @@ function get_devices(deviceType::Type{T}; verbose::Bool = false) where {T <: Num
println("Found $(noNumaNodes + 1) NUMA nodes")
end
for i in 0:noNumaNodes
push!(devices, NumaNode(i, 1, default_strategy(NumaNode), -1))
push!(devices, NumaNode(i, 1, default_strategy(NumaNode), -1, UUIDs.uuid1(rng[1])))
end
return devices
end
"""
gen_cache_init_code(device::NumaNode, strategy::LocalVariables)
gen_cache_init_code(device::NumaNode)
Generate code for initializing the [`LocalVariables`](@ref) strategy on a [`NumaNode`](@ref).
"""
function gen_cache_init_code(::NumaNode, ::LocalVariables)
# don't need to initialize anything
return Expr()
function gen_cache_init_code(device::NumaNode)
if typeof(device.cacheStrategy) <: LocalVariables
# don't need to initialize anything
return Expr(:block)
elseif typeof(device.cacheStrategy) <: Dictionary
return Meta.parse("cache_$(to_var_name(device.id)) = Dict{Symbol, Any}()")
# TODO: sizehint?
end
return error("Unimplemented cache strategy \"$(device.cacheStrategy)\" for device \"$(device)\"")
end
"""
gen_access_expr(device::NumaNode, strategy::LocalVariables, symbol::Symbol)
gen_access_expr(device::NumaNode, symbol::Symbol)
Generate code to access the variable designated by `symbol` using the [`LocalVariables`](@ref) [`CacheStrategy`](@ref) on a [`NumaNode`](@ref).
"""
function gen_access_expr(::NumaNode, ::LocalVariables, symbol::Symbol)
s = Symbol("data_$symbol")
quoteNode = Meta.parse(":($s)")
return quoteNode
function gen_access_expr(device::NumaNode, symbol::Symbol)
if typeof(device.cacheStrategy) <: LocalVariables
s = Symbol("data_$symbol")
quoteNode = Meta.parse(":($s)")
return quoteNode
elseif typeof(device.cacheStrategy) <: Dictionary
accessStr = ":(cache_$(to_var_name(device.id))[:$symbol])"
quoteNode = Meta.parse(accessStr)
return quoteNode
end
return error("Unimplemented cache strategy \"$(device.cacheStrategy)\" for device \"$(device)\"")
end

View File

@@ -6,6 +6,6 @@ Pretty-print a [`Diff`](@ref). Called via print, println and co.
function show(io::IO, diff::Diff)
print(io, "Nodes: ")
print(io, length(diff.addedNodes) + length(diff.removedNodes))
print(io, " Edges: ")
print(io, ", Edges: ")
return print(io, length(diff.addedEdges) + length(diff.removedEdges))
end

View File

@@ -160,14 +160,17 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
end
function replace_children!(task::FusedComputeTask, before, after)
replace!(task.t1_inputs, before => after)
replace!(task.t2_inputs, before => after)
# TODO: this assert fails sometimes and really shouldn't
@assert length(findall(x -> x == before, task.t1_inputs)) >= 1 ||
length(findall(x -> x == before, task.t2_inputs)) >= 1 "Replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
# recursively descend down the tree
replace_children!(task.first_task, before, after)
replace_children!(task.second_task, before, after)
replace!(task.t1_inputs, before => after)
replace!(task.t2_inputs, before => after)
return nothing
end
@@ -181,6 +184,16 @@ function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::S
return nothing
end
if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
println("------------------ Nothing to replace!! ------------------")
child_ids = Vector{String}()
for child in n.children
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end
replace_children!(n.task, child_before, child_after)
if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))

View File

@@ -22,5 +22,6 @@ end
Return the uuid as a string usable as a variable name in code generation.
"""
function to_var_name(id::UUID)
return replace(string(id), "-" => "_")
str = "_" * replace(string(id), "-" => "_")
return str
end

View File

@@ -22,12 +22,12 @@ function is_valid_node(graph::DAG, node::Node)
@assert node in child.parents "Node is not a parent of its child!"
end
if !ismissing(node.nodeReduction)
#=if !ismissing(node.nodeReduction)
@assert is_valid(graph, node.nodeReduction)
end
if !ismissing(node.nodeSplit)
@assert is_valid(graph, node.nodeSplit)
end
end=#
if !(typeof(node.task) <: FusedComputeTask)
# the remaining checks are only necessary for fused compute tasks
@@ -53,9 +53,9 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::ComputeTaskNode)
@assert is_valid_node(graph, node)
for nf in node.nodeFusions
#=for nf in node.nodeFusions
@assert is_valid(graph, nf)
end
end=#
return true
end
@@ -69,8 +69,8 @@ This also calls [`is_valid_node(graph::DAG, node::Node)`](@ref).
function is_valid(graph::DAG, node::DataTaskNode)
@assert is_valid_node(graph, node)
if !ismissing(node.nodeFusion)
#=if !ismissing(node.nodeFusion)
@assert is_valid(graph, node.nodeFusion)
end
end=#
return true
end

View File

@@ -9,6 +9,7 @@ function apply_all!(graph::DAG)
op = popfirst!(graph.operationsToApply)
# apply it
println(graph.appliedOperations)
appliedOp = apply_operation!(graph, op)
# push to the end of the appliedOperations deque
@@ -158,8 +159,17 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
get_snapshot_diff(graph)
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
n1Children = children(n1)
n3Parents = parents(n3)
n1Task = copy(n1.task)
n3Task = copy(n3.task)
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
n1Inputs = Vector{Symbol}()
for child in n1Children
push!(n1Inputs, Symbol(to_var_name(child.id)))
end
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, n1, n2)
@@ -168,43 +178,38 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
remove_node!(graph, n2)
# get n3's children now so it automatically excludes n2
n3_children = children(n3)
n3Children = children(n3)
n3Inputs = Vector{Symbol}()
for child in n3Children
push!(n3Inputs, Symbol(to_var_name(child.id)))
end
remove_node!(graph, n3)
# assemble the input node vectors of n1 and n3 to save into the FusedComputeTask
n1_inputs = Vector{Symbol}()
for child in n1_children
push!(n1_inputs, Symbol(to_var_name(child.id)))
end
n3_inputs = Vector{Symbol}()
for child in n3_children
push!(n3_inputs, Symbol(to_var_name(child.id)))
end
# create new node with the fused compute task
new_node = ComputeTaskNode(FusedComputeTask(n1.task, n3.task, n1_inputs, Symbol(to_var_name(n2.id)), n3_inputs))
insert_node!(graph, new_node)
newNode = ComputeTaskNode(FusedComputeTask(n1Task, n3Task, n1Inputs, Symbol(to_var_name(n2.id)), n3Inputs))
insert_node!(graph, newNode)
for child in n1_children
for child in n1Children
remove_edge!(graph, child, n1)
insert_edge!(graph, child, new_node)
insert_edge!(graph, child, newNode)
end
for child in n3_children
for child in n3Children
remove_edge!(graph, child, n3)
if !(child in n1_children)
insert_edge!(graph, child, new_node)
if !(child in n1Children)
insert_edge!(graph, child, newNode)
end
end
for parent in n3_parents
for parent in n3Parents
remove_edge!(graph, n3, parent)
insert_edge!(graph, new_node, parent)
insert_edge!(graph, newNode, parent)
# important! update the parent node's child names in case they are fused compute tasks
# needed for compute generation so the fused compute task can correctly match inputs to its component tasks
update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(new_node.id)))
update_child!(graph, parent, Symbol(to_var_name(n3.id)), Symbol(to_var_name(newNode.id)))
end
return get_snapshot_diff(graph)
@@ -224,14 +229,14 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
get_snapshot_diff(graph)
n1 = nodes[1]
n1_children = children(n1)
n1Children = children(n1)
n1_parents = Set(n1.parents)
n1Parents = Set(n1.parents)
# set of the new parents of n1
new_parents = Set{Node}()
newParents = Set{Node}()
# names of the previous children that n1 now replaces per parent
new_parents_child_names = Dict{Node, Symbol}()
newParentsChildNames = Dict{Node, Symbol}()
str = Vector{String}()
for n in nodes
@@ -242,7 +247,7 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
n = nodes[i]
for child in n1_children
for child in n1Children
remove_edge!(graph, child, n)
end
@@ -250,21 +255,21 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
remove_edge!(graph, n, parent)
# collect all parents
push!(new_parents, parent)
new_parents_child_names[parent] = Symbol(to_var_name(n.id))
push!(newParents, parent)
newParentsChildNames[parent] = Symbol(to_var_name(n.id))
end
remove_node!(graph, n)
end
setdiff!(new_parents, n1_parents)
setdiff!(newParents, n1Parents)
for parent in new_parents
for parent in newParents
# now add parents of all input nodes to n1 without duplicates
insert_edge!(graph, n1, parent)
prev_child = new_parents_child_names[parent]
update_child!(graph, parent, prev_child, Symbol(to_var_name(n1.id)))
prevChild = newParentsChildNames[parent]
update_child!(graph, parent, prevChild, Symbol(to_var_name(n1.id)))
end
return get_snapshot_diff(graph)
@@ -283,28 +288,28 @@ function node_split!(graph::DAG, n1::Node)
# clear snapshot
get_snapshot_diff(graph)
n1_parents = parents(n1)
n1_children = children(n1)
n1Parents = parents(n1)
n1Children = children(n1)
for parent in n1_parents
for parent in n1Parents
remove_edge!(graph, n1, parent)
end
for child in n1_children
for child in n1Children
remove_edge!(graph, child, n1)
end
remove_node!(graph, n1)
for parent in n1_parents
n_copy = copy(n1)
for parent in n1Parents
nCopy = copy(n1)
insert_node!(graph, n_copy)
insert_edge!(graph, n_copy, parent)
insert_node!(graph, nCopy)
insert_edge!(graph, nCopy, parent)
for child in n1_children
insert_edge!(graph, child, n_copy)
for child in n1Children
insert_edge!(graph, child, nCopy)
end
update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(n_copy.id)))
update_child!(graph, parent, Symbol(to_var_name(n1.id)), Symbol(to_var_name(nCopy.id)))
end
return get_snapshot_diff(graph)

View File

@@ -32,6 +32,10 @@ function is_valid_node_fusion_input(graph::DAG, n1::ComputeTaskNode, n2::DataTas
throw(AssertionError("[Node Fusion] The given n1 has more than one parent"))
end
@assert is_valid(graph, n1)
@assert is_valid(graph, n2)
@assert is_valid(graph, n3)
return true
end
@@ -47,6 +51,7 @@ function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
if n graph
throw(AssertionError("[Node Reduction] The given nodes are not part of the given graph"))
end
@assert is_valid(graph, n)
end
t = typeof(nodes[1].task)
@@ -96,6 +101,8 @@ function is_valid_node_split_input(graph::DAG, n1::Node)
)
end
@assert is_valid(graph, n1)
return true
end

View File

@@ -24,14 +24,14 @@ function get_expression(t::FusedComputeTask, device::AbstractDevice, inExprs::Ve
inExprs1 = Vector()
for sym in t.t1_inputs
push!(inExprs1, gen_access_expr(device, cacheStrategy, sym))
push!(inExprs1, gen_access_expr(device, sym))
end
outExpr1 = gen_access_expr(device, cacheStrategy, t.t1_output)
outExpr1 = gen_access_expr(device, t.t1_output)
inExprs2 = Vector()
for sym in t.t2_inputs
push!(inExprs2, gen_access_expr(device, cacheStrategy, sym))
push!(inExprs2, gen_access_expr(device, sym))
end
expr1 = get_expression(t.first_task, device, inExprs1, outExpr1)
@@ -56,9 +56,9 @@ function get_expression(node::ComputeTaskNode, device::AbstractDevice)
inExprs = Vector()
for id in getfield.(node.children, :id)
push!(inExprs, gen_access_expr(device, cacheStrategy, Symbol(to_var_name(id))))
push!(inExprs, gen_access_expr(device, Symbol(to_var_name(id))))
end
outExpr = gen_access_expr(device, cacheStrategy, Symbol(to_var_name(node.id)))
outExpr = gen_access_expr(device, Symbol(to_var_name(node.id)))
return get_expression(node.task, device, inExprs, outExpr)
end
@@ -77,11 +77,11 @@ function get_expression(node::DataTaskNode, device::AbstractDevice)
cacheStrategy = cache_strategy(device)
inExpr = nothing
if (length(node.children) == 1)
inExpr = eval(gen_access_expr(device, cacheStrategy, Symbol(to_var_name(node.children[1].id))))
inExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.children[1].id))))
else
inExpr = eval(gen_access_expr(device, cacheStrategy, Symbol("$(to_var_name(node.id))_in")))
inExpr = eval(gen_access_expr(device, Symbol("$(to_var_name(node.id))_in")))
end
outExpr = eval(gen_access_expr(device, cacheStrategy, Symbol(to_var_name(node.id))))
outExpr = eval(gen_access_expr(device, Symbol(to_var_name(node.id))))
dataTransportExp = Meta.parse("$outExpr = $inExpr")