Optimizer interface and sample implementation (#19)

Reviewed-on: Rubydragon/MetagraphOptimization.jl#19
Co-authored-by: Anton Reinhard <anton.reinhard@proton.me>
Co-committed-by: Anton Reinhard <anton.reinhard@proton.me>
This commit is contained in:
2023-11-22 13:51:54 +01:00
committed by Anton Reinhard
parent 16274919e4
commit b7560685d4
53 changed files with 639 additions and 331 deletions

View File

@@ -17,21 +17,5 @@ function in(edge::Edge, graph::DAG)
return false
end
return n1 in n2.children
end
"""
==(n1::Node, n2::Node, g::DAG)
Check equality of two nodes in a graph.
"""
function ==(n1::Node, n2::Node, g::DAG)
if typeof(n1) != typeof(n2)
return false
end
if !(n1 in g) || !(n2 in g)
return false
end
return n1.task == n2.task && children(n1) == children(n2)
return n1 in children(n2)
end

View File

@@ -46,7 +46,7 @@ Insert the edge between node1 (child) and node2 (parent) into the graph.
See also: [`insert_node!`](@ref), [`remove_node!`](@ref), [`remove_edge!`](@ref)
"""
function insert_edge!(graph::DAG, node1::Node, node2::Node; track = true, invalidate_cache = true)
@assert (node2 node1.parents) && (node1 node2.children) "Edge to insert already exists"
#@assert (node2 ∉ parents(node1)) && (node1 ∉ children(node2)) "Edge to insert already exists"
# 1: mute
# edge points from child to parent
@@ -85,7 +85,7 @@ Remove the node from the graph.
See also: [`insert_node!`](@ref), [`insert_edge!`](@ref), [`remove_edge!`](@ref)
"""
function remove_node!(graph::DAG, node::Node; track = true, invalidate_cache = true)
@assert node in graph.nodes "Trying to remove a node that's not in the graph"
#@assert node in graph.nodes "Trying to remove a node that's not in the graph"
# 1: mute
delete!(graph.nodes, node)
@@ -124,18 +124,29 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
pre_length2 = length(node2.children)
#TODO: filter is very slow
filter!(x -> x != node2, node1.parents)
filter!(x -> x != node1, node2.children)
for i in eachindex(node1.parents)
if (node1.parents[i] == node2)
splice!(node1.parents, i)
break
end
end
@assert begin
for i in eachindex(node2.children)
if (node2.children[i] == node1)
splice!(node2.children, i)
break
end
end
#=@assert begin
removed = pre_length1 - length(node1.parents)
removed <= 1
end "removed more than one node from node1's parents"
end "removed more than one node from node1's parents"=#
@assert begin
removed = pre_length2 - length(node2.children)
#=@assert begin
removed = pre_length2 - length(children(node2))
removed <= 1
end "removed more than one node from node2's children"
end "removed more than one node from node2's children"=#
# 2: keep track
if (track)
@@ -163,7 +174,7 @@ function replace_children!(task::FusedComputeTask, before, after)
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
replace!(task.t1_inputs, before => after)
replace!(task.t2_inputs, before => after)
@@ -185,33 +196,33 @@ end
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
# only need to update fused compute tasks
if !(typeof(n.task) <: FusedComputeTask)
if !(typeof(task(n)) <: FusedComputeTask)
return nothing
end
taskBefore = copy(n.task)
taskBefore = copy(task(n))
if !((child_before in n.task.t1_inputs) || (child_before in n.task.t2_inputs))
#=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
println("------------------ Nothing to replace!! ------------------")
child_ids = Vector{String}()
for child in n.children
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end
end=#
replace_children!(n.task, child_before, child_after)
replace_children!(task(n), child_before, child_after)
if !((child_after in n.task.t1_inputs) || (child_after in n.task.t2_inputs))
#=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
println("------------------ Did not replace anything!! ------------------")
child_ids = Vector{String}()
for child in n.children
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end
end=#
# keep track
if (track)
@@ -242,8 +253,14 @@ function invalidate_caches!(graph::DAG, operation::NodeFusion)
# delete the operation from all caches of nodes involved in the operation
# TODO: filter is very slow
filter!(!=(operation), operation.input[1].nodeFusions)
filter!(!=(operation), operation.input[3].nodeFusions)
for n in [1, 3]
for i in eachindex(operation.input[n].nodeFusions)
if operation == operation.input[n].nodeFusions[i]
splice!(operation.input[n].nodeFusions, i)
break
end
end
end
operation.input[2].nodeFusion = missing

View File

@@ -30,10 +30,10 @@ function show(io::IO, graph::DAG)
nodeDict = Dict{Type, Int64}()
noEdges = 0
for node in graph.nodes
if haskey(nodeDict, typeof(node.task))
nodeDict[typeof(node.task)] = nodeDict[typeof(node.task)] + 1
if haskey(nodeDict, typeof(task(node)))
nodeDict[typeof(task(node))] = nodeDict[typeof(task(node))] + 1
else
nodeDict[typeof(node.task)] = 1
nodeDict[typeof(task(node))] = 1
end
noEdges += length(parents(node))
end

View File

@@ -43,3 +43,12 @@ function get_entry_nodes(graph::DAG)
end
return result
end
"""
operation_stack_length(graph::DAG)
Return the number of operations applied to the graph.
"""
function operation_stack_length(graph::DAG)
return length(graph.appliedOperations) + length(graph.operationsToApply)
end

View File

@@ -24,7 +24,7 @@ To get the set of possible operations, use [`get_operations`](@ref).
The members of the object should not be manually accessed, instead always use the provided interface functions.
"""
mutable struct DAG
nodes::Set{Node}
nodes::Set{Union{DataTaskNode, ComputeTaskNode}}
# The operations currently applied to the set of nodes
appliedOperations::Stack{AppliedOperation}
@@ -36,7 +36,7 @@ mutable struct DAG
possibleOperations::PossibleOperations
# The set of nodes whose possible operations need to be reevaluated
dirtyNodes::Set{Node}
dirtyNodes::Set{Union{DataTaskNode, ComputeTaskNode}}
# "snapshot" system: keep track of added/removed nodes/edges since last snapshot
# these are muted in insert_node! etc.