Fix execution with fusion

This commit is contained in:
Anton Reinhard 2023-09-26 16:52:50 +02:00
parent cc05cae1cd
commit 95f92f080c
11 changed files with 129 additions and 73 deletions

View File

@ -6,7 +6,7 @@ using DataStructures
Generate the code for a given graph. The return value is a tuple of:
- `code::Expr`: The julia expression containing the code for the whole graph.
- `inputSymbols::Dict{String, Symbol}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
- `outputSymbol::Symbol`: The symbol of the final calculated value
See also: [`execute`](@ref)
@ -16,12 +16,16 @@ function gen_code(graph::DAG)
sizehint!(code, length(graph.nodes))
nodeQueue = PriorityQueue{Node, Int}()
inputSyms = Dict{String, Symbol}()
inputSyms = Dict{String, Vector{Symbol}}()
# use a priority equal to the number of unseen children -> 0 are nodes that can be added
for node in get_entry_nodes(graph)
enqueue!(nodeQueue, node => 0)
push!(inputSyms, node.name => Symbol("data_$(to_var_name(node.id))_in"))
if !haskey(inputSyms, node.name)
inputSyms[node.name] = Vector{Symbol}()
end
push!(inputSyms[node.name], Symbol("data_$(to_var_name(node.id))_in"))
end
node = nothing
@ -51,12 +55,12 @@ function gen_code(graph::DAG)
end
function gen_input_assignment_code(
inputSymbols::Dict{String, Symbol},
inputSymbols::Dict{String, Vector{Symbol}},
particles::Tuple{Vector{Particle}, Vector{Particle}},
)
@assert !isempty(particles[1]) "Can't have 0 input particles!"
@assert !isempty(particles[2]) "Can't have 0 output particles!"
@assert length(inputSymbols) == length(particles[1]) + length(particles[2])
@assert length(inputSymbols) >= length(particles[1]) + length(particles[2])
# TODO none of this is very pretty
in_out_count = Dict{ParticleType, Tuple{Int, Int}}()
@ -73,7 +77,7 @@ function gen_input_assignment_code(
end
assignInputs = Vector{Expr}()
for (name, symbol) in inputSymbols
for (name, symbols) in inputSymbols
type = nothing
if startswith(name, "A")
type = A
@ -97,12 +101,14 @@ function gen_input_assignment_code(
p = particles[1][findall(condition, particles[1])[index]]
end
push!(
assignInputs,
Meta.parse(
"$(symbol) = ParticleValue(Particle($(p.momentum), $(p.type)), 1.0)",
),
)
for symbol in symbols
push!(
assignInputs,
Meta.parse(
"$(symbol) = ParticleValue(Particle($(p.momentum), $(p.type)), 1.0)",
),
)
end
end
return Expr(:block, assignInputs...)
@ -140,11 +146,17 @@ function execute(graph::DAG, input::Tuple{Vector{Particle}, Vector{Particle}})
assignInputs = gen_input_assignment_code(inputSymbols, input)
println(code)
eval(assignInputs)
eval(code)
try
eval(assignInputs)
eval(code)
eval(Meta.parse("result = $outputSymbol"))
eval(Meta.parse("result = $outputSymbol"))
catch e
println("Error while evaluating: $e")
println("Assign Input Code:\n$assignInputs\n")
println("Code:\n$code")
@assert false
end
return result
end

View File

@ -57,7 +57,7 @@ function insert_edge!(
track = true,
invalidate_cache = true,
)
# @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
@assert (node2 node1.parents) && (node1 node2.children) "Edge to insert already exists"
# 1: mute
# edge points from child to parent
@ -101,7 +101,7 @@ function remove_node!(
track = true,
invalidate_cache = true,
)
# @assert node in graph.nodes "Trying to remove a node that's not in the graph"
@assert node in graph.nodes "Trying to remove a node that's not in the graph"
# 1: mute
delete!(graph.nodes, node)
@ -149,15 +149,15 @@ function remove_edge!(
filter!(x -> x != node2, node1.parents)
filter!(x -> x != node1, node2.children)
#=@assert begin
removed = pre_length1 - length(node1.parents)
removed <= 1
end "removed more than one node from node1's parents"=#
@assert begin
removed = pre_length1 - length(node1.parents)
removed <= 1
end "removed more than one node from node1's parents"
#=@assert begin
removed = pre_length2 - length(node2.children)
removed <= 1
end "removed more than one node from node2's children"=#
@assert begin
removed = pre_length2 - length(node2.children)
removed <= 1
end "removed more than one node from node2's children"
# 2: keep track
if (track)
@ -181,6 +181,22 @@ function remove_edge!(
return nothing
end
function replace_children!(task::FusedComputeTask, before, after)
replace!(task.t1_inputs, before => after)
replace!(task.t2_inputs, before => after)
# recursively descend down the tree
replace_children!(task.first_task, before, after)
replace_children!(task.second_task, before, after)
return nothing
end
function replace_children!(task::AbstractTask, before, after)
return nothing
end
function update_child!(
graph::DAG,
n::Node,
@ -193,8 +209,21 @@ function update_child!(
return nothing
end
replace!(n.task.t1_inputs, child_before => child_after)
replace!(n.task.t2_inputs, child_before => child_after)
replace_children!(n.task, child_before, child_after)
if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
println(
"------------------ Did not replace anything!! ------------------",
)
child_ids = Vector{String}()
for child in n.children
push!(child_ids, "$(child.id)")
end
println(
"From $(child_before) to $(child_after) in $n with children $(child_ids)",
)
@assert false
end
# keep track
if (track)

View File

@ -184,12 +184,13 @@ function get_expression(
expr1 = nothing
expr2 = nothing
expr0 = Meta.parse("# fused compute task $(t.first_task), $(t.second_task)")
expr1 = get_expression(t.first_task, t.t1_inputs, t.t1_output)
expr2 =
get_expression(t.second_task, [t.t2_inputs..., t.t1_output], outExpr)
return Expr(:block, expr1, expr2)
full_expr = Expr(:block, expr1, expr2)
return full_expr
end
"""

View File

@ -57,42 +57,42 @@ end
Print the S1 task to io.
"""
show(io::IO, t::ComputeTaskS1) = print("ComputeS1")
show(io::IO, t::ComputeTaskS1) = print(io, "ComputeS1")
"""
show(io::IO, t::ComputeTaskS2)
Print the S2 task to io.
"""
show(io::IO, t::ComputeTaskS2) = print("ComputeS2")
show(io::IO, t::ComputeTaskS2) = print(io, "ComputeS2")
"""
show(io::IO, t::ComputeTaskP)
Print the P task to io.
"""
show(io::IO, t::ComputeTaskP) = print("ComputeP")
show(io::IO, t::ComputeTaskP) = print(io, "ComputeP")
"""
show(io::IO, t::ComputeTaskU)
Print the U task to io.
"""
show(io::IO, t::ComputeTaskU) = print("ComputeU")
show(io::IO, t::ComputeTaskU) = print(io, "ComputeU")
"""
show(io::IO, t::ComputeTaskV)
Print the V task to io.
"""
show(io::IO, t::ComputeTaskV) = print("ComputeV")
show(io::IO, t::ComputeTaskV) = print(io, "ComputeV")
"""
show(io::IO, t::ComputeTaskSum)
Print the sum task to io.
"""
show(io::IO, t::ComputeTaskSum) = print("ComputeSum")
show(io::IO, t::ComputeTaskSum) = print(io, "ComputeSum")
"""
copy(t::DataTask)

View File

@ -20,25 +20,8 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
)
copy(m::Missing) = missing
copy(n::ComputeTaskNode) = ComputeTaskNode(
copy(n.task),
copy(n.parents),
copy(n.children),
UUIDs.uuid1(rng[threadid()]),
copy(n.nodeReduction),
copy(n.nodeSplit),
copy(n.nodeFusions),
)
copy(n::DataTaskNode) = DataTaskNode(
copy(n.task),
copy(n.parents),
copy(n.children),
UUIDs.uuid1(rng[threadid()]),
copy(n.nodeReduction),
copy(n.nodeSplit),
copy(n.nodeFusion),
n.name,
)
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task))
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), n.name)
"""
make_node(t::AbstractTask)

View File

@ -28,6 +28,18 @@ function is_valid_node(graph::DAG, node::Node)
if !ismissing(node.nodeSplit)
@assert is_valid(graph, node.nodeSplit)
end
if !(typeof(node.task) <: FusedComputeTask)
# the remaining checks are only necessary for fused compute tasks
return true
end
# every child must be in some input of the task
for child in node.children
str = "data_$(to_var_name(child.id))"
@assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
end
return true
end

View File

@ -141,8 +141,7 @@ function revert_diff!(graph::DAG, diff::Diff)
# node must be fused compute task at this point
@assert typeof(node.task) <: FusedComputeTask
replace!(node.task.t1_inputs, after => before)
replace!(node.task.t2_inputs, after => before)
update_child!(graph, node, after, before, track = false)
end
graph.properties -= GraphProperties(diff)
@ -171,7 +170,6 @@ function node_fusion!(
# save children and parents
n1_children = children(n1)
n3_parents = parents(n3)
n3_children = children(n3)
# remove the edges and nodes that will be replaced by the fused node
remove_edge!(graph, n1, n2)
@ -253,8 +251,16 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
n1_parents = Set(n1.parents)
# set of the new parents of n1, together with the names of the previous children that n1 now replaces
new_parents = Set{Tuple{Node, String}}()
# set of the new parents of n1
new_parents = Set{Node}()
# names of the previous children that n1 now replaces per parent
new_parents_child_names = Dict{Node, String}()
str = Vector{String}()
for n in nodes
push!(str, "$(n.id)")
end
#println("Reducing $(nodes[1].task) Nodes $(str)")
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
for i in 2:length(nodes)
@ -267,7 +273,8 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
remove_edge!(graph, n, parent)
# collect all parents
push!(new_parents, (parent, "data_$(to_var_name(n.id))"))
push!(new_parents, parent)
new_parents_child_names[parent] = "data_$(to_var_name(n.id))"
end
remove_node!(graph, n)
@ -275,10 +282,11 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
setdiff!(new_parents, n1_parents)
for (parent, prev_child) in new_parents
for parent in new_parents
# now add parents of all input nodes to n1 without duplicates
insert_edge!(graph, n1, parent)
prev_child = new_parents_child_names[parent]
update_child!(graph, parent, prev_child, "data_$(to_var_name(n1.id))")
end
@ -311,19 +319,20 @@ function node_split!(graph::DAG, n1::Node)
for parent in n1_parents
n_copy = copy(n1)
insert_node!(graph, n_copy)
insert_edge!(graph, n_copy, parent)
for child in n1_children
insert_edge!(graph, child, n_copy)
end
update_child!(
graph,
parent,
"data_$(to_var_name(n1.id))",
"data_$(to_var_name(n_copy.id))",
)
for child in n1_children
insert_edge!(graph, child, n_copy)
end
end
return get_snapshot_diff(graph)

View File

@ -86,6 +86,16 @@ function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
),
)
end
if (typeof(n) <: DataTaskNode)
if (n.name != nodes[1].name)
throw(
AssertionError(
"[Node Reduction] The given nodes do not have the same name",
),
)
end
end
end
n1_children = nodes[1].children

View File

@ -20,11 +20,11 @@ Return a copy of th egiven [`FusedComputeTask`](@ref).
"""
function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
return FusedComputeTask{T1, T2}(
t.first_task,
t.second_task,
t.t1_inputs,
copy(t.first_task),
copy(t.second_task),
copy(t.t1_inputs),
t.t1_output,
t.t2_inputs,
copy(t.t2_inputs),
)
end

View File

@ -4,5 +4,5 @@
Print a string representation of the fused compute task to io.
"""
function show(io::IO, t::FusedComputeTask)
return print(io, "ComputeFuse(", t.first_task, ", ", t.second_task, ")")
return print(io, "ComputeFuse($(t.first_task), $(t.second_task))")
end

View File

@ -42,9 +42,9 @@ include("../examples/profiling_utilities.jl")
end
@testset "AB->AB after random walk" begin
for _ in 1:50
for i in 1:100
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
random_walk!(graph, 40)
random_walk!(graph, 50)
@test isapprox(
execute(graph, particles),