Fix execution with fusion
This commit is contained in:
parent
cc05cae1cd
commit
95f92f080c
@ -6,7 +6,7 @@ using DataStructures
|
||||
Generate the code for a given graph. The return value is a tuple of:
|
||||
|
||||
- `code::Expr`: The julia expression containing the code for the whole graph.
|
||||
- `inputSymbols::Dict{String, Symbol}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
|
||||
- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on.
|
||||
- `outputSymbol::Symbol`: The symbol of the final calculated value
|
||||
|
||||
See also: [`execute`](@ref)
|
||||
@ -16,12 +16,16 @@ function gen_code(graph::DAG)
|
||||
sizehint!(code, length(graph.nodes))
|
||||
|
||||
nodeQueue = PriorityQueue{Node, Int}()
|
||||
inputSyms = Dict{String, Symbol}()
|
||||
inputSyms = Dict{String, Vector{Symbol}}()
|
||||
|
||||
# use a priority equal to the number of unseen children -> 0 are nodes that can be added
|
||||
for node in get_entry_nodes(graph)
|
||||
enqueue!(nodeQueue, node => 0)
|
||||
push!(inputSyms, node.name => Symbol("data_$(to_var_name(node.id))_in"))
|
||||
if !haskey(inputSyms, node.name)
|
||||
inputSyms[node.name] = Vector{Symbol}()
|
||||
end
|
||||
|
||||
push!(inputSyms[node.name], Symbol("data_$(to_var_name(node.id))_in"))
|
||||
end
|
||||
|
||||
node = nothing
|
||||
@ -51,12 +55,12 @@ function gen_code(graph::DAG)
|
||||
end
|
||||
|
||||
function gen_input_assignment_code(
|
||||
inputSymbols::Dict{String, Symbol},
|
||||
inputSymbols::Dict{String, Vector{Symbol}},
|
||||
particles::Tuple{Vector{Particle}, Vector{Particle}},
|
||||
)
|
||||
@assert !isempty(particles[1]) "Can't have 0 input particles!"
|
||||
@assert !isempty(particles[2]) "Can't have 0 output particles!"
|
||||
@assert length(inputSymbols) == length(particles[1]) + length(particles[2])
|
||||
@assert length(inputSymbols) >= length(particles[1]) + length(particles[2])
|
||||
|
||||
# TODO none of this is very pretty
|
||||
in_out_count = Dict{ParticleType, Tuple{Int, Int}}()
|
||||
@ -73,7 +77,7 @@ function gen_input_assignment_code(
|
||||
end
|
||||
|
||||
assignInputs = Vector{Expr}()
|
||||
for (name, symbol) in inputSymbols
|
||||
for (name, symbols) in inputSymbols
|
||||
type = nothing
|
||||
if startswith(name, "A")
|
||||
type = A
|
||||
@ -97,12 +101,14 @@ function gen_input_assignment_code(
|
||||
p = particles[1][findall(condition, particles[1])[index]]
|
||||
end
|
||||
|
||||
push!(
|
||||
assignInputs,
|
||||
Meta.parse(
|
||||
"$(symbol) = ParticleValue(Particle($(p.momentum), $(p.type)), 1.0)",
|
||||
),
|
||||
)
|
||||
for symbol in symbols
|
||||
push!(
|
||||
assignInputs,
|
||||
Meta.parse(
|
||||
"$(symbol) = ParticleValue(Particle($(p.momentum), $(p.type)), 1.0)",
|
||||
),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
return Expr(:block, assignInputs...)
|
||||
@ -140,11 +146,17 @@ function execute(graph::DAG, input::Tuple{Vector{Particle}, Vector{Particle}})
|
||||
|
||||
assignInputs = gen_input_assignment_code(inputSymbols, input)
|
||||
|
||||
println(code)
|
||||
eval(assignInputs)
|
||||
eval(code)
|
||||
try
|
||||
eval(assignInputs)
|
||||
eval(code)
|
||||
|
||||
eval(Meta.parse("result = $outputSymbol"))
|
||||
eval(Meta.parse("result = $outputSymbol"))
|
||||
catch e
|
||||
println("Error while evaluating: $e")
|
||||
println("Assign Input Code:\n$assignInputs\n")
|
||||
println("Code:\n$code")
|
||||
@assert false
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
@ -57,7 +57,7 @@ function insert_edge!(
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
# @assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
|
||||
@assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
|
||||
|
||||
# 1: mute
|
||||
# edge points from child to parent
|
||||
@ -101,7 +101,7 @@ function remove_node!(
|
||||
track = true,
|
||||
invalidate_cache = true,
|
||||
)
|
||||
# @assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
@assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
|
||||
# 1: mute
|
||||
delete!(graph.nodes, node)
|
||||
@ -149,15 +149,15 @@ function remove_edge!(
|
||||
filter!(x -> x != node2, node1.parents)
|
||||
filter!(x -> x != node1, node2.children)
|
||||
|
||||
#=@assert begin
|
||||
removed = pre_length1 - length(node1.parents)
|
||||
removed <= 1
|
||||
end "removed more than one node from node1's parents"=#
|
||||
@assert begin
|
||||
removed = pre_length1 - length(node1.parents)
|
||||
removed <= 1
|
||||
end "removed more than one node from node1's parents"
|
||||
|
||||
#=@assert begin
|
||||
removed = pre_length2 - length(node2.children)
|
||||
removed <= 1
|
||||
end "removed more than one node from node2's children"=#
|
||||
@assert begin
|
||||
removed = pre_length2 - length(node2.children)
|
||||
removed <= 1
|
||||
end "removed more than one node from node2's children"
|
||||
|
||||
# 2: keep track
|
||||
if (track)
|
||||
@ -181,6 +181,22 @@ function remove_edge!(
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::FusedComputeTask, before, after)
|
||||
|
||||
replace!(task.t1_inputs, before => after)
|
||||
replace!(task.t2_inputs, before => after)
|
||||
|
||||
# recursively descend down the tree
|
||||
replace_children!(task.first_task, before, after)
|
||||
replace_children!(task.second_task, before, after)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::AbstractTask, before, after)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function update_child!(
|
||||
graph::DAG,
|
||||
n::Node,
|
||||
@ -193,8 +209,21 @@ function update_child!(
|
||||
return nothing
|
||||
end
|
||||
|
||||
replace!(n.task.t1_inputs, child_before => child_after)
|
||||
replace!(n.task.t2_inputs, child_before => child_after)
|
||||
replace_children!(n.task, child_before, child_after)
|
||||
|
||||
if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
|
||||
println(
|
||||
"------------------ Did not replace anything!! ------------------",
|
||||
)
|
||||
child_ids = Vector{String}()
|
||||
for child in n.children
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println(
|
||||
"From $(child_before) to $(child_after) in $n with children $(child_ids)",
|
||||
)
|
||||
@assert false
|
||||
end
|
||||
|
||||
# keep track
|
||||
if (track)
|
||||
|
@ -184,12 +184,13 @@ function get_expression(
|
||||
expr1 = nothing
|
||||
expr2 = nothing
|
||||
|
||||
expr0 = Meta.parse("# fused compute task $(t.first_task), $(t.second_task)")
|
||||
expr1 = get_expression(t.first_task, t.t1_inputs, t.t1_output)
|
||||
expr2 =
|
||||
get_expression(t.second_task, [t.t2_inputs..., t.t1_output], outExpr)
|
||||
|
||||
return Expr(:block, expr1, expr2)
|
||||
full_expr = Expr(:block, expr1, expr2)
|
||||
|
||||
return full_expr
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -57,42 +57,42 @@ end
|
||||
|
||||
Print the S1 task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskS1) = print("ComputeS1")
|
||||
show(io::IO, t::ComputeTaskS1) = print(io, "ComputeS1")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskS2)
|
||||
|
||||
Print the S2 task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskS2) = print("ComputeS2")
|
||||
show(io::IO, t::ComputeTaskS2) = print(io, "ComputeS2")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskP)
|
||||
|
||||
Print the P task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskP) = print("ComputeP")
|
||||
show(io::IO, t::ComputeTaskP) = print(io, "ComputeP")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskU)
|
||||
|
||||
Print the U task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskU) = print("ComputeU")
|
||||
show(io::IO, t::ComputeTaskU) = print(io, "ComputeU")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskV)
|
||||
|
||||
Print the V task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskV) = print("ComputeV")
|
||||
show(io::IO, t::ComputeTaskV) = print(io, "ComputeV")
|
||||
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskSum)
|
||||
|
||||
Print the sum task to io.
|
||||
"""
|
||||
show(io::IO, t::ComputeTaskSum) = print("ComputeSum")
|
||||
show(io::IO, t::ComputeTaskSum) = print(io, "ComputeSum")
|
||||
|
||||
"""
|
||||
copy(t::DataTask)
|
||||
|
@ -20,25 +20,8 @@ ComputeTaskNode(t::AbstractComputeTask) = ComputeTaskNode(
|
||||
)
|
||||
|
||||
copy(m::Missing) = missing
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(
|
||||
copy(n.task),
|
||||
copy(n.parents),
|
||||
copy(n.children),
|
||||
UUIDs.uuid1(rng[threadid()]),
|
||||
copy(n.nodeReduction),
|
||||
copy(n.nodeSplit),
|
||||
copy(n.nodeFusions),
|
||||
)
|
||||
copy(n::DataTaskNode) = DataTaskNode(
|
||||
copy(n.task),
|
||||
copy(n.parents),
|
||||
copy(n.children),
|
||||
UUIDs.uuid1(rng[threadid()]),
|
||||
copy(n.nodeReduction),
|
||||
copy(n.nodeSplit),
|
||||
copy(n.nodeFusion),
|
||||
n.name,
|
||||
)
|
||||
copy(n::ComputeTaskNode) = ComputeTaskNode(copy(n.task))
|
||||
copy(n::DataTaskNode) = DataTaskNode(copy(n.task), n.name)
|
||||
|
||||
"""
|
||||
make_node(t::AbstractTask)
|
||||
|
@ -28,6 +28,18 @@ function is_valid_node(graph::DAG, node::Node)
|
||||
if !ismissing(node.nodeSplit)
|
||||
@assert is_valid(graph, node.nodeSplit)
|
||||
end
|
||||
|
||||
if !(typeof(node.task) <: FusedComputeTask)
|
||||
# the remaining checks are only necessary for fused compute tasks
|
||||
return true
|
||||
end
|
||||
|
||||
# every child must be in some input of the task
|
||||
for child in node.children
|
||||
str = "data_$(to_var_name(child.id))"
|
||||
@assert (str in node.task.t1_inputs) || (str in node.task.t2_inputs) "$str was not in any of the tasks' inputs\nt1_inputs: $(node.task.t1_inputs)\nt2_inputs: $(node.task.t2_inputs)"
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
|
@ -141,8 +141,7 @@ function revert_diff!(graph::DAG, diff::Diff)
|
||||
# node must be fused compute task at this point
|
||||
@assert typeof(node.task) <: FusedComputeTask
|
||||
|
||||
replace!(node.task.t1_inputs, after => before)
|
||||
replace!(node.task.t2_inputs, after => before)
|
||||
update_child!(graph, node, after, before, track = false)
|
||||
end
|
||||
|
||||
graph.properties -= GraphProperties(diff)
|
||||
@ -171,7 +170,6 @@ function node_fusion!(
|
||||
# save children and parents
|
||||
n1_children = children(n1)
|
||||
n3_parents = parents(n3)
|
||||
n3_children = children(n3)
|
||||
|
||||
# remove the edges and nodes that will be replaced by the fused node
|
||||
remove_edge!(graph, n1, n2)
|
||||
@ -253,8 +251,16 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
n1_parents = Set(n1.parents)
|
||||
|
||||
# set of the new parents of n1, together with the names of the previous children that n1 now replaces
|
||||
new_parents = Set{Tuple{Node, String}}()
|
||||
# set of the new parents of n1
|
||||
new_parents = Set{Node}()
|
||||
# names of the previous children that n1 now replaces per parent
|
||||
new_parents_child_names = Dict{Node, String}()
|
||||
|
||||
str = Vector{String}()
|
||||
for n in nodes
|
||||
push!(str, "$(n.id)")
|
||||
end
|
||||
#println("Reducing $(nodes[1].task) Nodes $(str)")
|
||||
|
||||
# remove all of the nodes' parents and children and the nodes themselves (except for first node)
|
||||
for i in 2:length(nodes)
|
||||
@ -267,7 +273,8 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
remove_edge!(graph, n, parent)
|
||||
|
||||
# collect all parents
|
||||
push!(new_parents, (parent, "data_$(to_var_name(n.id))"))
|
||||
push!(new_parents, parent)
|
||||
new_parents_child_names[parent] = "data_$(to_var_name(n.id))"
|
||||
end
|
||||
|
||||
remove_node!(graph, n)
|
||||
@ -275,10 +282,11 @@ function node_reduction!(graph::DAG, nodes::Vector{Node})
|
||||
|
||||
setdiff!(new_parents, n1_parents)
|
||||
|
||||
for (parent, prev_child) in new_parents
|
||||
for parent in new_parents
|
||||
# now add parents of all input nodes to n1 without duplicates
|
||||
insert_edge!(graph, n1, parent)
|
||||
|
||||
prev_child = new_parents_child_names[parent]
|
||||
update_child!(graph, parent, prev_child, "data_$(to_var_name(n1.id))")
|
||||
end
|
||||
|
||||
@ -311,19 +319,20 @@ function node_split!(graph::DAG, n1::Node)
|
||||
|
||||
for parent in n1_parents
|
||||
n_copy = copy(n1)
|
||||
|
||||
insert_node!(graph, n_copy)
|
||||
insert_edge!(graph, n_copy, parent)
|
||||
|
||||
for child in n1_children
|
||||
insert_edge!(graph, child, n_copy)
|
||||
end
|
||||
|
||||
update_child!(
|
||||
graph,
|
||||
parent,
|
||||
"data_$(to_var_name(n1.id))",
|
||||
"data_$(to_var_name(n_copy.id))",
|
||||
)
|
||||
|
||||
for child in n1_children
|
||||
insert_edge!(graph, child, n_copy)
|
||||
end
|
||||
end
|
||||
|
||||
return get_snapshot_diff(graph)
|
||||
|
@ -86,6 +86,16 @@ function is_valid_node_reduction_input(graph::DAG, nodes::Vector{Node})
|
||||
),
|
||||
)
|
||||
end
|
||||
|
||||
if (typeof(n) <: DataTaskNode)
|
||||
if (n.name != nodes[1].name)
|
||||
throw(
|
||||
AssertionError(
|
||||
"[Node Reduction] The given nodes do not have the same name",
|
||||
),
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
n1_children = nodes[1].children
|
||||
|
@ -20,11 +20,11 @@ Return a copy of th egiven [`FusedComputeTask`](@ref).
|
||||
"""
|
||||
function copy(t::FusedComputeTask{T1, T2}) where {T1, T2}
|
||||
return FusedComputeTask{T1, T2}(
|
||||
t.first_task,
|
||||
t.second_task,
|
||||
t.t1_inputs,
|
||||
copy(t.first_task),
|
||||
copy(t.second_task),
|
||||
copy(t.t1_inputs),
|
||||
t.t1_output,
|
||||
t.t2_inputs,
|
||||
copy(t.t2_inputs),
|
||||
)
|
||||
end
|
||||
|
||||
|
@ -4,5 +4,5 @@
|
||||
Print a string representation of the fused compute task to io.
|
||||
"""
|
||||
function show(io::IO, t::FusedComputeTask)
|
||||
return print(io, "ComputeFuse(", t.first_task, ", ", t.second_task, ")")
|
||||
return print(io, "ComputeFuse($(t.first_task), $(t.second_task))")
|
||||
end
|
||||
|
@ -42,9 +42,9 @@ include("../examples/profiling_utilities.jl")
|
||||
end
|
||||
|
||||
@testset "AB->AB after random walk" begin
|
||||
for _ in 1:50
|
||||
for i in 1:100
|
||||
graph = parse_abc(joinpath(@__DIR__, "..", "input", "AB->AB.txt"))
|
||||
random_walk!(graph, 40)
|
||||
random_walk!(graph, 50)
|
||||
|
||||
@test isapprox(
|
||||
execute(graph, particles),
|
||||
|
Loading…
x
Reference in New Issue
Block a user