Still remove NodeFusion
Some checks failed
MetagraphOptimization_CI / test (push) Failing after 7m23s
MetagraphOptimization_CI / docs (push) Successful in 7m57s

This commit is contained in:
Anton Reinhard 2024-08-19 14:02:46 +02:00
parent 97ccb3f3fb
commit e9bd1f2939
No known key found for this signature in database
GPG Key ID: D65083A1729C9270
11 changed files with 11 additions and 157 deletions

View File

@ -19,7 +19,6 @@ export AbstractTask
export AbstractComputeTask
export AbstractDataTask
export DataTask
export FusedComputeTask
export PossibleOperations
export GraphProperties
@ -44,7 +43,6 @@ export is_valid, is_scheduled
# graph operation related
export Operation
export AppliedOperation
export NodeFusion
export NodeReduction
export NodeSplit
export push_operation!
@ -88,7 +86,7 @@ export GlobalMetricEstimator, CDCost
# optimization
export AbstractOptimizer, GreedyOptimizer, RandomWalkOptimizer
export ReductionOptimizer, SplitOptimizer, FusionOptimizer
export ReductionOptimizer, SplitOptimizer
export optimize_step!, optimize!
export fixpoint_reached, optimize_to_fixpoint!
@ -166,7 +164,6 @@ include("optimization/interface.jl")
include("optimization/greedy.jl")
include("optimization/random_walk.jl")
include("optimization/reduce.jl")
include("optimization/fuse.jl")
include("optimization/split.jl")
include("models/interface.jl")

View File

@ -10,8 +10,7 @@ Representation of a [`DAG`](@ref)'s cost as estimated by the [`GlobalMetricEstim
!!! note
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
For example, for node fusions this will always be 0, since the computeEffort is zero.
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
It will still work as intended when adding/subtracting to/from a `graph_cost` estimate.
"""
const CDCost = NamedTuple{(:data, :computeEffort, :computeIntensity), Tuple{Float64, Float64, Float64}}
@ -55,10 +54,6 @@ function graph_cost(estimator::GlobalMetricEstimator, graph::DAG)
)::CDCost
end
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeFusion)
return (data = -data(operation.input[2].task), computeEffort = 0.0, computeIntensity = 0.0)::CDCost
end
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeReduction)
s = length(operation.input) - 1
return (

View File

@ -169,66 +169,6 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
return nothing
end
function replace_children!(task::FusedComputeTask, before, after)
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
replace!(task.t1_inputs, before => after)
replace!(task.t2_inputs, before => after)
# recursively descend down the tree, but only in the tasks where we're replacing things
if replacedIn1 > 0
replace_children!(task.first_task, before, after)
end
if replacedIn2 > 0
replace_children!(task.second_task, before, after)
end
return nothing
end
function replace_children!(task::AbstractTask, before, after)
return nothing
end
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
# only need to update fused compute tasks
if !(typeof(task(n)) <: FusedComputeTask)
return nothing
end
taskBefore = copy(task(n))
#=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
println("------------------ Nothing to replace!! ------------------")
child_ids = Vector{String}()
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end=#
replace_children!(task(n), child_before, child_after)
#=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
println("------------------ Did not replace anything!! ------------------")
child_ids = Vector{String}()
for child in children(n)
push!(child_ids, "$(child.id)")
end
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
@assert false
end=#
# keep track
if (track)
push!(graph.diff.updatedChildren, (n, taskBefore))
end
end
"""
get_snapshot_diff(graph::DAG)

View File

@ -40,19 +40,11 @@ function is_valid(graph::DAG)
for ns in graph.possibleOperations.nodeSplits
@assert is_valid(graph, ns)
end
for nf in graph.possibleOperations.nodeFusions
@assert is_valid(graph, nf)
end
for node in graph.dirtyNodes
@assert node in graph "Dirty Node is not part of the graph!"
@assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!"
@assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!"
if (typeof(node) <: DataTaskNode)
@assert ismissing(node.nodeFusion) "Dirty DataTaskNode has a Node Fusion!"
elseif (typeof(node) <: ComputeTaskNode)
@assert isempty(node.nodeFusions) "Dirty ComputeTaskNode has Node Fusions!"
end
end
@assert is_connected(graph) "Graph is not connected!"

View File

@ -0,0 +1,3 @@
## Deprecation Warning
These models are deprecated and should not be used anymore. They will be dropped entirely soon.

View File

@ -1,5 +1,3 @@
# TODO generalize this somehow, probably using AllSpin/AllPol as defaults
"""
parse_process(string::AbstractString, model::QEDModel)
@ -8,10 +6,10 @@ Parse a string representation of a process, such as "ke->ke" into the correspond
function parse_process(
str::AbstractString,
model::QEDModel,
inphpol::AbstractDefinitePolarization,
inelspin::AbstractDefiniteSpin,
outphpol::AbstractDefinitePolarization,
outelspin::AbstractDefiniteSpin,
inphpol::AbstractDefinitePolarization = PolX(),
inelspin::AbstractDefiniteSpin = SpinUp(),
outphpol::AbstractDefinitePolarization = PolX(),
outelspin::AbstractDefiniteSpin = SpinUp(),
)
inParticles = Dict(
ParticleStateful{Incoming, Photon, SFourMomentum} => 0,

View File

@ -20,45 +20,6 @@ See also: [`revert_operation!`](@ref).
"""
abstract type AppliedOperation end
"""
NodeFusion <: Operation
The NodeFusion operation. Represents the fusing of a chain of compute node -> data node -> compute node.
After the node fusion is applied, the graph has 2 fewer nodes and edges, and a new [`FusedComputeTask`](@ref) with the two input compute nodes as parts.
# Requirements for successful application
A chain of (n1, n2, n3) can be fused if:
- All nodes are in the graph.
- (n1, n2) is an edge in the graph.
- (n2, n3) is an edge in the graph.
- n2 has exactly one parent (n3) and exactly one child (n1).
- n1 has exactly one parent (n2).
[`is_valid_node_fusion_input`](@ref) can be used to `@assert` these requirements.
See also: [`can_fuse`](@ref)
"""
struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
Operation
input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
end
"""
AppliedNodeFusion <: AppliedOperation
The applied version of the [`NodeFusion`](@ref).
"""
struct AppliedNodeFusion{
TaskType1 <: AbstractComputeTask,
TaskType2 <: AbstractDataTask,
TaskType3 <: AbstractComputeTask,
} <: AppliedOperation
operation::NodeFusion{TaskType1, TaskType2, TaskType3}
diff::Diff
end
"""
NodeReduction <: Operation

View File

@ -74,9 +74,6 @@ function mem(graph::DAG)
size += sizeof(graph.operationsToApply)
size += sizeof(graph.possibleOperations)
for op in graph.possibleOperations.nodeFusions
size += mem(op)
end
for op in graph.possibleOperations.nodeReductions
size += mem(op)
end

View File

@ -1,6 +1,5 @@
using SafeTestsets
#=
@safetestset "Utility Unit Tests " begin
include("unit_tests_utility.jl")
end
@ -19,11 +18,9 @@ end
@safetestset "ABC-Model Unit Tests " begin
include("unit_tests_abcmodel.jl")
end
=#
@safetestset "QED-Model Unit Tests " begin
include("unit_tests_qedmodel.jl")
end
#=
@safetestset "QED Feynman Diagram Generation Tests" begin
include("unit_tests_qed_diagrams.jl")
end
@ -42,4 +39,3 @@ end
@safetestset "Known Graph Tests " begin
include("known_graphs.jl")
end
=#

View File

@ -1,5 +1,5 @@
using MetagraphOptimization
using QEDbase
using QEDcore
import MetagraphOptimization.interaction_result

View File

@ -144,32 +144,7 @@ end
for i in 1:100
input = gen_process_input(process)
@test length(input.inFerms) ==
get(process.inParticles, ParticleStateful{Incoming, Electron, SFourMomentum}, 0)
@test length(input.inAntiferms) ==
get(process.inParticles, ParticleStateful{Incoming, Positron, SFourMomentum}, 0)
@test length(input.inPhotons) ==
get(process.inParticles, ParticleStateful{Incoming, Photon, SFourMomentum}, 0)
@test length(input.outFerms) ==
get(process.outParticles, ParticleStateful{Outgoing, Electron, SFourMomentum}, 0)
@test length(input.outAntiferms) ==
get(process.outParticles, ParticleStateful{Outgoing, Positron, SFourMomentum}, 0)
@test length(input.outPhotons) ==
get(process.outParticles, ParticleStateful{Outgoing, Photon, SFourMomentum}, 0)
@test isapprox(
sum([
getfield.(input.inFerms, :momentum)...,
getfield.(input.inAntiferms, :momentum)...,
getfield.(input.inPhotons, :momentum)...,
]),
sum([
getfield.(input.outFerms, :momentum)...,
getfield.(input.outAntiferms, :momentum)...,
getfield.(input.outPhotons, :momentum)...,
]);
atol = sqrt(eps()),
)
@test isapprox(sum(momenta(input, Incoming())), sum(momenta(input, Outgoing())); atol = sqrt(eps()))
end
end
end