Optimizer interface and sample implementation (#19)
Reviewed-on: Rubydragon/MetagraphOptimization.jl#19 Co-authored-by: Anton Reinhard <anton.reinhard@proton.me> Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
@@ -17,21 +17,5 @@ function in(edge::Edge, graph::DAG)
|
||||
return false
|
||||
end
|
||||
|
||||
return n1 in n2.children
|
||||
end
|
||||
|
||||
"""
|
||||
==(n1::Node, n2::Node, g::DAG)
|
||||
|
||||
Check equality of two nodes in a graph.
|
||||
"""
|
||||
function ==(n1::Node, n2::Node, g::DAG)
|
||||
if typeof(n1) != typeof(n2)
|
||||
return false
|
||||
end
|
||||
if !(n1 in g) || !(n2 in g)
|
||||
return false
|
||||
end
|
||||
|
||||
return n1.task == n2.task && children(n1) == children(n2)
|
||||
return n1 in children(n2)
|
||||
end
|
||||
|
@@ -46,7 +46,7 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
|
||||
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
|
||||
@assert (node2 ∉ node1.parents) && (node1 ∉ node2.children) "Edge to insert already exists"
|
||||
#@assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists"
|
||||
|
||||
# 1: mute
|
||||
# edge points from child to parent
|
||||
@@ -85,7 +85,7 @@ Remove the node from the graph.
|
||||
See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
|
||||
"""
|
||||
function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
|
||||
@assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
#@assert node in graph.nodes "Trying to remove a node that's not in the graph"
|
||||
|
||||
# 1: mute
|
||||
delete!(graph.nodes, node)
|
||||
@@ -124,18 +124,29 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
|
||||
pre_length2 = length(node2.children)
|
||||
|
||||
#TODO: filter is very slow
|
||||
filter!(x -> x != node2, node1.parents)
|
||||
filter!(x -> x != node1, node2.children)
|
||||
for i in eachindex(node1.parents)
|
||||
if (node1.parents[i] == node2)
|
||||
splice!(node1.parents, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
@assert begin
|
||||
for i in eachindex(node2.children)
|
||||
if (node2.children[i] == node1)
|
||||
splice!(node2.children, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
#=@assert begin
|
||||
removed = pre_length1 - length(node1.parents)
|
||||
removed <= 1
|
||||
end "removed more than one node from node1's parents"
|
||||
end "removed more than one node from node1's parents"=#
|
||||
|
||||
@assert begin
|
||||
removed = pre_length2 - length(node2.children)
|
||||
#=@assert begin
|
||||
removed = pre_length2 - length(children(node2))
|
||||
removed <= 1
|
||||
end "removed more than one node from node2's children"
|
||||
end "removed more than one node from node2's children"=#
|
||||
|
||||
# 2: keep track
|
||||
if (track)
|
||||
@@ -163,7 +174,7 @@ function replace_children!(task::FusedComputeTask, before, after)
|
||||
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
|
||||
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
|
||||
|
||||
@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
|
||||
replace!(task.t1_inputs, before => after)
|
||||
replace!(task.t2_inputs, before => after)
|
||||
@@ -185,33 +196,33 @@ end
|
||||
|
||||
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
|
||||
# only need to update fused compute tasks
|
||||
if !(typeof(n.task) <: FusedComputeTask)
|
||||
if !(typeof(task(n)) <: FusedComputeTask)
|
||||
return nothing
|
||||
end
|
||||
|
||||
taskBefore = copy(n.task)
|
||||
taskBefore = copy(task(n))
|
||||
|
||||
if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
|
||||
#=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
|
||||
println("------------------ Nothing to replace!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in n.children
|
||||
for child in children(n)
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end
|
||||
end=#
|
||||
|
||||
replace_children!(n.task, child_before, child_after)
|
||||
replace_children!(task(n), child_before, child_after)
|
||||
|
||||
if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
|
||||
#=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
|
||||
println("------------------ Did not replace anything!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in n.children
|
||||
for child in children(n)
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end
|
||||
end=#
|
||||
|
||||
# keep track
|
||||
if (track)
|
||||
@@ -242,8 +253,14 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
|
||||
|
||||
# delete the operation from all caches of nodes involved in the operation
|
||||
# TODO: filter is very slow
|
||||
filter!(!=(operation), operation.input[1].nodeFusions)
|
||||
filter!(!=(operation), operation.input[3].nodeFusions)
|
||||
for n in [1, 3]
|
||||
for i in eachindex(operation.input[n].nodeFusions)
|
||||
if operation == operation.input[n].nodeFusions[i]
|
||||
splice!(operation.input[n].nodeFusions, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
operation.input[2].nodeFusion = missing
|
||||
|
||||
|
@@ -30,10 +30,10 @@ function show(io::IO, graph::DAG)
|
||||
nodeDict = Dict{Type, Int64}()
|
||||
noEdges = 0
|
||||
for node in graph.nodes
|
||||
if haskey(nodeDict, typeof(node.task))
|
||||
nodeDict[typeof(node.task)] = nodeDict[typeof(node.task)] + 1
|
||||
if haskey(nodeDict, typeof(task(node)))
|
||||
nodeDict[typeof(task(node))] = nodeDict[typeof(task(node))] + 1
|
||||
else
|
||||
nodeDict[typeof(node.task)] = 1
|
||||
nodeDict[typeof(task(node))] = 1
|
||||
end
|
||||
noEdges += length(parents(node))
|
||||
end
|
||||
|
@@ -43,3 +43,12 @@ function get_entry_nodes(graph::DAG)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
"""
|
||||
operation_stack_length(graph::DAG)
|
||||
|
||||
Return the number of operations applied to the graph.
|
||||
"""
|
||||
function operation_stack_length(graph::DAG)
|
||||
return length(graph.appliedOperations) + length(graph.operationsToApply)
|
||||
end
|
||||
|
@@ -24,7 +24,7 @@ To get the set of possible operations, use [`get_operations`](@ref).
|
||||
The members of the object should not be manually accessed, instead always use the provided interface functions.
|
||||
"""
|
||||
mutable struct DAG
|
||||
nodes::Set{Node}
|
||||
nodes::Set{Union{DataTaskNode, ComputeTaskNode}}
|
||||
|
||||
# The operations currently applied to the set of nodes
|
||||
appliedOperations::Stack{AppliedOperation}
|
||||
@@ -36,7 +36,7 @@ mutable struct DAG
|
||||
possibleOperations::PossibleOperations
|
||||
|
||||
# The set of nodes whose possible operations need to be reevaluated
|
||||
dirtyNodes::Set{Node}
|
||||
dirtyNodes::Set{Union{DataTaskNode, ComputeTaskNode}}
|
||||
|
||||
# "snapshot" system: keep track of added/removed nodes/edges since last snapshot
|
||||
# these are muted in insert_node! etc.
|
||||
|
Reference in New Issue
Block a user