Remove double edge insertions

This commit is contained in:
Anton Reinhard 2023-08-18 11:47:12 +02:00
parent ab38f618c3
commit 6ee444b46f
4 changed files with 71 additions and 21 deletions

View File

@ -2,3 +2,4 @@
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
MetagraphOptimization = "3e869610-d48d-4942-ba70-c1b702a33ca4"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"

View File

@ -0,0 +1,36 @@
function test_random_walk(g::DAG, n::Int64)
# the purpose here is to do "random" operations and reverse them again and validate that the graph stays the same and doesn't diverge
reset_graph!(g)
properties = graph_properties(g)
for i = 1:n
# choose push or pop
if rand(Bool)
# push
opt = get_operations(g)
# choose one of fuse/split/reduce
option = rand(1:3)
if option == 1 && !isempty(opt.nodeFusions)
push_operation!(g, rand(collect(opt.nodeFusions)))
elseif option == 2 && !isempty(opt.nodeReductions)
push_operation!(g, rand(collect(opt.nodeReductions)))
elseif option == 3 && !isempty(opt.nodeSplits)
push_operation!(g, rand(collect(opt.nodeSplits)))
else
i = i - 1
end
else
# pop
if (can_pop(g))
pop_operation!(g)
else
i = i - 1
end
end
end
reset_graph!(g)
end

View File

@ -135,13 +135,12 @@ function insert_edge!(graph::DAG, edge::Edge, track=true)
node2 = edge.edge[2]
# 1: mute
if (node2 in node1.parents) || (node1 in node2.children)
#=if !(node2 in node1.parents && node1 in node2.children)
#=if (node2 in node1.parents) || (node1 in node2.children)
if !(node2 in node1.parents && node1 in node2.children)
error("One-sided edge")
end=#
return edge
end
end
error("Edge to insert already exists")
end=#
# edge points from child to parent
push!(node1.parents, node2)
@ -276,7 +275,7 @@ function can_reduce(n1::Node, n2::Node)
if (n1.task != n2.task)
return false
end
return Set(children(n1)) == Set(children(n2))
return Set(n1.children) == Set(n2.children)
end
function can_split(n::Node)
@ -293,8 +292,7 @@ function is_valid(graph::DAG)
current = pop!(nodeQueue)
push!(seenNodes, current)
childrenNodes = children(current)
for child in childrenNodes
for child in current.chlidren
push!(nodeQueue, child)
end
end

View File

@ -155,15 +155,22 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
new_node = ComputeTaskNode(FusedComputeTask{typeof(n1.task), typeof(n3.task)}())
insert_node!(graph, new_node)
# "repoint" children of n1 to the new node
# use a set for combined children of n1 and n3 to not get duplicates
n1and3_children = Set{Node}()
# remove edges from n1 children to n1
for child in n1_children
remove_edge!(graph, make_edge(child, n1))
insert_edge!(graph, make_edge(child, new_node))
push!(n1and3_children, child)
end
# "repoint" children of n3 to the new node
# remove edges from n3 children to n3
for child in n3_children
remove_edge!(graph, make_edge(child, n3))
push!(n1and3_children, child)
end
for child in n1and3_children
insert_edge!(graph, make_edge(child, new_node))
end
@ -190,9 +197,9 @@ function node_reduction!(graph::DAG, n1::Node, n2::Node)
# save n2 parents and children
n2_children = children(n2)
n2_parents = parents(n2)
n2_parents = Set(n2.parents)
#=if Set(n2_children) != Set(children(n1))
#=if Set(n2_children) != Set(n1.children)
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
end=#
@ -200,12 +207,22 @@ function node_reduction!(graph::DAG, n1::Node, n2::Node)
for child in n2_children
remove_edge!(graph, make_edge(child, n2))
end
for parent in n2_parents
remove_edge!(graph, make_edge(n2, parent))
end
# add parents of n2 to n1
for parent in n1.parents
# delete parents in n1 that already exist in n2
delete!(n2_parents, parent)
end
for parent in n2_parents
# now add parents of n2 to n1 without duplicates
insert_edge!(graph, make_edge(n1, parent))
end
remove_node!(graph, n2)
return get_snapshot_diff(graph)
@ -355,12 +372,10 @@ function find_reductions!(graph::DAG, node::Node)
end
function find_splits!(graph::DAG, node::Node)
for node in graph.nodes
if (can_split(node))
ns = NodeSplit(node)
push!(graph.possibleOperations.nodeSplits, ns)
push!(node.operations, ns)
end
if (can_split(node))
ns = NodeSplit(node)
push!(graph.possibleOperations.nodeSplits, ns)
push!(node.operations, ns)
end
return nothing