Actually fix tests now

This commit is contained in:
Anton Reinhard 2023-08-17 21:53:55 +02:00
parent 78f7fb2f05
commit ef6184b8ea
3 changed files with 61 additions and 26 deletions

View File

@ -135,6 +135,14 @@ function insert_edge!(graph::DAG, edge::Edge, track=true)
node2 = edge.edge[2]
# 1: mute
if (node2 in node1.parents) || (node1 in node2.children)
#=if !(node2 in node1.parents && node1 in node2.children)
error("One-sided edge")
end=#
return edge
end
# edge points from child to parent
push!(node1.parents, node2)
push!(node2.children, node1)
@ -157,6 +165,9 @@ end
function remove_node!(graph::DAG, node::Node, track=true)
# 1: mute
#=if !(node in graph.nodes)
error("Trying to remove a node that's not in the graph")
end=#
delete!(graph.nodes, node)
# 2: keep track
@ -176,9 +187,21 @@ function remove_edge!(graph::DAG, edge::Edge, track=true)
node2 = edge.edge[2]
# 1: mute
pre_length1 = length(node1.parents)
pre_length2 = length(node2.children)
filter!(x -> x != node2, node1.parents)
filter!(x -> x != node1, node2.children)
#=removed = pre_length1 - length(node1.parents)
if (removed > 1)
error("removed $removed from node1's parents")
end
removed = pre_length2 - length(node2.children)
if (removed > 1)
error("removed $removed from node2's children")
end=#
# 2: keep track
if (track) push!(graph.diff.removedEdges, edge) end
@ -242,7 +265,7 @@ function can_fuse(n1::ComputeTaskNode, n2::DataTaskNode, n3::ComputeTaskNode)
return false
end
if length(parents(n2)) != 1 || length(children(n2)) != 1
if length(n2.parents) != 1 || length(n2.children) != 1 || length(n1.parents) != 1
return false
end

View File

@ -124,13 +124,22 @@ function node_fusion!(graph::DAG, n1::ComputeTaskNode, n2::DataTaskNode, n3::Com
end
# save children and parents
n1_parents = parents(n1)
n1_children = children(n1)
n2_parents = parents(n2)
n2_children = children(n2)
n3_parents = parents(n3)
n3_children = children(n3)
if length(n2_parents) > 1
error("[Node Fusion] The given data node has more than one parent")
end
if length(n2_children) > 1
error("[Node Fusion] The given data node has more than one child")
end
if length(n1_parents) > 1
error("[Node Fusion] The given n1 has more than one parent")
end
required_edge1 = make_edge(n1, n2)
required_edge2 = make_edge(n2, n3)
@ -174,21 +183,21 @@ function node_reduction!(graph::DAG, n1::Node, n2::Node)
# clear snapshot
get_snapshot_diff(graph)
if !(n1 in graph) || !(n2 in graph)
#=if !(n1 in graph) || !(n2 in graph)
error("[Node Reduction] The given nodes are not part of the given graph")
end
end=#
if typeof(n1) != typeof(n2)
#=if typeof(n1) != typeof(n2)
error("[Node Reduction] The given nodes are not of the same type")
end
end=#
# save n2 parents and children
n2_children = children(n2)
n2_parents = parents(n2)
if Set(n2_children) != Set(children(n1))
#=if Set(n2_children) != Set(children(n1))
error("[Node Reduction] The given nodes do not have equal prerequisite nodes which is required for node reduction")
end
end=#
# remove n2 and all its parents and children
for child in n2_children
@ -209,16 +218,16 @@ function node_split!(graph::DAG, n1::Node)
# clear snapshot
get_snapshot_diff(graph)
if !(n1 in graph)
#=if !(n1 in graph)
error("[Node Split] The given node is not part of the given graph")
end
end=#
n1_parents = parents(n1)
n1_children = children(n1)
if length(n1_parents) <= 1
#=if length(n1_parents) <= 1
error("[Node Split] The given node does not have multiple parents which is required for node split")
end
end=#
for parent in n1_parents
remove_edge!(graph, make_edge(n1, parent))
@ -244,9 +253,6 @@ end
# function to find node fusions involving the given node if it's a data node
# pushes the found fusion everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::DataTaskNode)
if !(node in graph)
error("wot")
end
if length(parents(node)) != 1 || length(children(node)) != 1
return nothing
end
@ -254,8 +260,12 @@ function find_fusions!(graph::DAG, node::DataTaskNode)
child_node = first(children(node))
parent_node = first(parents(node))
if !(child_node in graph) || !(parent_node in graph)
#=if !(child_node in graph) || !(parent_node in graph)
error("Parents/Children that are not in the graph!!!")
end=#
if length(child_node.parents) != 1
return nothing
end
nf = NodeFusion((child_node, node, parent_node))
@ -270,13 +280,10 @@ end
# function to find node fusions involving the given node if it's a compute node
# pushes the found fusion(s) everywhere it needs to be and returns nothing
function find_fusions!(graph::DAG, node::ComputeTaskNode)
if !(node in graph)
error("wot")
end
# for loop that always runs once for a scoped block we can break out of
for _ in 1:1
# assume this node as child of the chain
if length(parents(node)) < 1
if length(parents(node)) != 1
break
end
node2 = first(parents(node))
@ -285,10 +292,9 @@ function find_fusions!(graph::DAG, node::ComputeTaskNode)
end
node3 = first(parents(node2))
if !(node2 in graph) || !(node3 in graph)
#=if !(node2 in graph) || !(node3 in graph)
error("Parents/Children that are not in the graph!!!")
end
end=#
nf = NodeFusion((node, node2, node3))
push!(graph.possibleOperations.nodeFusions, nf)
@ -307,11 +313,14 @@ function find_fusions!(graph::DAG, node::ComputeTaskNode)
break
end
node1 = first(children(node2))
if !(node2 in graph) || !(node1 in graph)
error("Parents/Children that are not in the graph!!!")
if (length(node1.parents) > 1)
break
end
#=if !(node2 in graph) || !(node1 in graph)
error("Parents/Children that are not in the graph!!!")
end=#
nf = NodeFusion((node1, node2, node))
push!(graph.possibleOperations.nodeFusions, nf)
push!(node1.operations, nf)
@ -392,6 +401,9 @@ function generate_options(graph::DAG)
continue
end
child_node = pop!(node_children)
if (length(child_node.parents) != 1)
continue
end
nf = NodeFusion((child_node, node, parent_node))
push!(options.nodeFusions, nf)

View File

@ -61,7 +61,7 @@ function test_random_walk(g::DAG, n::Int64)
push_operation!(g, rand(collect(opt.nodeFusions)))
elseif option == 2 && !isempty(opt.nodeReductions)
push_operation!(g, rand(collect(opt.nodeReductions)))
elseif option == 3 && !isempty(opt.nodeSplits) && false
elseif option == 3 && !isempty(opt.nodeSplits)
push_operation!(g, rand(collect(opt.nodeSplits)))
else
i = i - 1