Still remove NodeFusion
This commit is contained in:
parent
97ccb3f3fb
commit
e9bd1f2939
@ -19,7 +19,6 @@ export AbstractTask
|
||||
export AbstractComputeTask
|
||||
export AbstractDataTask
|
||||
export DataTask
|
||||
export FusedComputeTask
|
||||
export PossibleOperations
|
||||
export GraphProperties
|
||||
|
||||
@ -44,7 +43,6 @@ export is_valid, is_scheduled
|
||||
# graph operation related
|
||||
export Operation
|
||||
export AppliedOperation
|
||||
export NodeFusion
|
||||
export NodeReduction
|
||||
export NodeSplit
|
||||
export push_operation!
|
||||
@ -88,7 +86,7 @@ export GlobalMetricEstimator, CDCost
|
||||
|
||||
# optimization
|
||||
export AbstractOptimizer, GreedyOptimizer, RandomWalkOptimizer
|
||||
export ReductionOptimizer, SplitOptimizer, FusionOptimizer
|
||||
export ReductionOptimizer, SplitOptimizer
|
||||
export optimize_step!, optimize!
|
||||
export fixpoint_reached, optimize_to_fixpoint!
|
||||
|
||||
@ -166,7 +164,6 @@ include("optimization/interface.jl")
|
||||
include("optimization/greedy.jl")
|
||||
include("optimization/random_walk.jl")
|
||||
include("optimization/reduce.jl")
|
||||
include("optimization/fuse.jl")
|
||||
include("optimization/split.jl")
|
||||
|
||||
include("models/interface.jl")
|
||||
|
@ -10,8 +10,7 @@ Representation of a [`DAG`](@ref)'s cost as estimated by the [`GlobalMetricEstim
|
||||
|
||||
|
||||
!!! note
|
||||
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
|
||||
For example, for node fusions this will always be 0, since the computeEffort is zero.
|
||||
Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs.
|
||||
It will still work as intended when adding/subtracting to/from a `graph_cost` estimate.
|
||||
"""
|
||||
const CDCost = NamedTuple{(:data, :computeEffort, :computeIntensity), Tuple{Float64, Float64, Float64}}
|
||||
@ -55,10 +54,6 @@ function graph_cost(estimator::GlobalMetricEstimator, graph::DAG)
|
||||
)::CDCost
|
||||
end
|
||||
|
||||
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeFusion)
|
||||
return (data = -data(operation.input[2].task), computeEffort = 0.0, computeIntensity = 0.0)::CDCost
|
||||
end
|
||||
|
||||
function operation_effect(estimator::GlobalMetricEstimator, graph::DAG, operation::NodeReduction)
|
||||
s = length(operation.input) - 1
|
||||
return (
|
||||
|
@ -169,66 +169,6 @@ function remove_edge!(graph::DAG, node1::Node, node2::Node; track = true, invali
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::FusedComputeTask, before, after)
|
||||
replacedIn1 = length(findall(x -> x == before, task.t1_inputs))
|
||||
replacedIn2 = length(findall(x -> x == before, task.t2_inputs))
|
||||
|
||||
#@assert replacedIn1 >= 1 || replacedIn2 >= 1 "Nothing to replace while replacing $before with $after in $(task.t1_inputs...) and $(task.t2_inputs...)"
|
||||
|
||||
replace!(task.t1_inputs, before => after)
|
||||
replace!(task.t2_inputs, before => after)
|
||||
|
||||
# recursively descend down the tree, but only in the tasks where we're replacing things
|
||||
if replacedIn1 > 0
|
||||
replace_children!(task.first_task, before, after)
|
||||
end
|
||||
if replacedIn2 > 0
|
||||
replace_children!(task.second_task, before, after)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function replace_children!(task::AbstractTask, before, after)
|
||||
return nothing
|
||||
end
|
||||
|
||||
function update_child!(graph::DAG, n::Node, child_before::Symbol, child_after::Symbol; track = true)
|
||||
# only need to update fused compute tasks
|
||||
if !(typeof(task(n)) <: FusedComputeTask)
|
||||
return nothing
|
||||
end
|
||||
|
||||
taskBefore = copy(task(n))
|
||||
|
||||
#=if !((child_before in task(n).t1_inputs) || (child_before in task(n).t2_inputs))
|
||||
println("------------------ Nothing to replace!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in children(n)
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end=#
|
||||
|
||||
replace_children!(task(n), child_before, child_after)
|
||||
|
||||
#=if !((child_after in task(n).t1_inputs) || (child_after in task(n).t2_inputs))
|
||||
println("------------------ Did not replace anything!! ------------------")
|
||||
child_ids = Vector{String}()
|
||||
for child in children(n)
|
||||
push!(child_ids, "$(child.id)")
|
||||
end
|
||||
println("From $(child_before) to $(child_after) in $n with children $(child_ids)")
|
||||
@assert false
|
||||
end=#
|
||||
|
||||
# keep track
|
||||
if (track)
|
||||
push!(graph.diff.updatedChildren, (n, taskBefore))
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
get_snapshot_diff(graph::DAG)
|
||||
|
||||
|
@ -40,19 +40,11 @@ function is_valid(graph::DAG)
|
||||
for ns in graph.possibleOperations.nodeSplits
|
||||
@assert is_valid(graph, ns)
|
||||
end
|
||||
for nf in graph.possibleOperations.nodeFusions
|
||||
@assert is_valid(graph, nf)
|
||||
end
|
||||
|
||||
for node in graph.dirtyNodes
|
||||
@assert node in graph "Dirty Node is not part of the graph!"
|
||||
@assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!"
|
||||
@assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!"
|
||||
if (typeof(node) <: DataTaskNode)
|
||||
@assert ismissing(node.nodeFusion) "Dirty DataTaskNode has a Node Fusion!"
|
||||
elseif (typeof(node) <: ComputeTaskNode)
|
||||
@assert isempty(node.nodeFusions) "Dirty ComputeTaskNode has Node Fusions!"
|
||||
end
|
||||
end
|
||||
|
||||
@assert is_connected(graph) "Graph is not connected!"
|
||||
|
3
src/models/physics_models/README.md
Normal file
3
src/models/physics_models/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
## Deprecation Warning
|
||||
|
||||
These models are deprecated and should not be used anymore. They will be dropped entirely soon.
|
@ -1,5 +1,3 @@
|
||||
|
||||
# TODO generalize this somehow, probably using AllSpin/AllPol as defaults
|
||||
"""
|
||||
parse_process(string::AbstractString, model::QEDModel)
|
||||
|
||||
@ -8,10 +6,10 @@ Parse a string representation of a process, such as "ke->ke" into the correspond
|
||||
function parse_process(
|
||||
str::AbstractString,
|
||||
model::QEDModel,
|
||||
inphpol::AbstractDefinitePolarization,
|
||||
inelspin::AbstractDefiniteSpin,
|
||||
outphpol::AbstractDefinitePolarization,
|
||||
outelspin::AbstractDefiniteSpin,
|
||||
inphpol::AbstractDefinitePolarization = PolX(),
|
||||
inelspin::AbstractDefiniteSpin = SpinUp(),
|
||||
outphpol::AbstractDefinitePolarization = PolX(),
|
||||
outelspin::AbstractDefiniteSpin = SpinUp(),
|
||||
)
|
||||
inParticles = Dict(
|
||||
ParticleStateful{Incoming, Photon, SFourMomentum} => 0,
|
||||
|
@ -20,45 +20,6 @@ See also: [`revert_operation!`](@ref).
|
||||
"""
|
||||
abstract type AppliedOperation end
|
||||
|
||||
"""
|
||||
NodeFusion <: Operation
|
||||
|
||||
The NodeFusion operation. Represents the fusing of a chain of compute node -> data node -> compute node.
|
||||
|
||||
After the node fusion is applied, the graph has 2 fewer nodes and edges, and a new [`FusedComputeTask`](@ref) with the two input compute nodes as parts.
|
||||
|
||||
# Requirements for successful application
|
||||
|
||||
A chain of (n1, n2, n3) can be fused if:
|
||||
- All nodes are in the graph.
|
||||
- (n1, n2) is an edge in the graph.
|
||||
- (n2, n3) is an edge in the graph.
|
||||
- n2 has exactly one parent (n3) and exactly one child (n1).
|
||||
- n1 has exactly one parent (n2).
|
||||
|
||||
[`is_valid_node_fusion_input`](@ref) can be used to `@assert` these requirements.
|
||||
|
||||
See also: [`can_fuse`](@ref)
|
||||
"""
|
||||
struct NodeFusion{TaskType1 <: AbstractComputeTask, TaskType2 <: AbstractDataTask, TaskType3 <: AbstractComputeTask} <:
|
||||
Operation
|
||||
input::Tuple{ComputeTaskNode{TaskType1}, DataTaskNode{TaskType2}, ComputeTaskNode{TaskType3}}
|
||||
end
|
||||
|
||||
"""
|
||||
AppliedNodeFusion <: AppliedOperation
|
||||
|
||||
The applied version of the [`NodeFusion`](@ref).
|
||||
"""
|
||||
struct AppliedNodeFusion{
|
||||
TaskType1 <: AbstractComputeTask,
|
||||
TaskType2 <: AbstractDataTask,
|
||||
TaskType3 <: AbstractComputeTask,
|
||||
} <: AppliedOperation
|
||||
operation::NodeFusion{TaskType1, TaskType2, TaskType3}
|
||||
diff::Diff
|
||||
end
|
||||
|
||||
"""
|
||||
NodeReduction <: Operation
|
||||
|
||||
|
@ -74,9 +74,6 @@ function mem(graph::DAG)
|
||||
size += sizeof(graph.operationsToApply)
|
||||
|
||||
size += sizeof(graph.possibleOperations)
|
||||
for op in graph.possibleOperations.nodeFusions
|
||||
size += mem(op)
|
||||
end
|
||||
for op in graph.possibleOperations.nodeReductions
|
||||
size += mem(op)
|
||||
end
|
||||
|
@ -1,6 +1,5 @@
|
||||
using SafeTestsets
|
||||
|
||||
#=
|
||||
@safetestset "Utility Unit Tests " begin
|
||||
include("unit_tests_utility.jl")
|
||||
end
|
||||
@ -19,11 +18,9 @@ end
|
||||
@safetestset "ABC-Model Unit Tests " begin
|
||||
include("unit_tests_abcmodel.jl")
|
||||
end
|
||||
=#
|
||||
@safetestset "QED-Model Unit Tests " begin
|
||||
include("unit_tests_qedmodel.jl")
|
||||
end
|
||||
#=
|
||||
@safetestset "QED Feynman Diagram Generation Tests" begin
|
||||
include("unit_tests_qed_diagrams.jl")
|
||||
end
|
||||
@ -42,4 +39,3 @@ end
|
||||
@safetestset "Known Graph Tests " begin
|
||||
include("known_graphs.jl")
|
||||
end
|
||||
=#
|
@ -1,5 +1,5 @@
|
||||
using MetagraphOptimization
|
||||
using QEDbase
|
||||
using QEDcore
|
||||
|
||||
import MetagraphOptimization.interaction_result
|
||||
|
||||
|
@ -144,32 +144,7 @@ end
|
||||
|
||||
for i in 1:100
|
||||
input = gen_process_input(process)
|
||||
@test length(input.inFerms) ==
|
||||
get(process.inParticles, ParticleStateful{Incoming, Electron, SFourMomentum}, 0)
|
||||
@test length(input.inAntiferms) ==
|
||||
get(process.inParticles, ParticleStateful{Incoming, Positron, SFourMomentum}, 0)
|
||||
@test length(input.inPhotons) ==
|
||||
get(process.inParticles, ParticleStateful{Incoming, Photon, SFourMomentum}, 0)
|
||||
@test length(input.outFerms) ==
|
||||
get(process.outParticles, ParticleStateful{Outgoing, Electron, SFourMomentum}, 0)
|
||||
@test length(input.outAntiferms) ==
|
||||
get(process.outParticles, ParticleStateful{Outgoing, Positron, SFourMomentum}, 0)
|
||||
@test length(input.outPhotons) ==
|
||||
get(process.outParticles, ParticleStateful{Outgoing, Photon, SFourMomentum}, 0)
|
||||
|
||||
@test isapprox(
|
||||
sum([
|
||||
getfield.(input.inFerms, :momentum)...,
|
||||
getfield.(input.inAntiferms, :momentum)...,
|
||||
getfield.(input.inPhotons, :momentum)...,
|
||||
]),
|
||||
sum([
|
||||
getfield.(input.outFerms, :momentum)...,
|
||||
getfield.(input.outAntiferms, :momentum)...,
|
||||
getfield.(input.outPhotons, :momentum)...,
|
||||
]);
|
||||
atol = sqrt(eps()),
|
||||
)
|
||||
@test isapprox(sum(momenta(input, Incoming())), sum(momenta(input, Outgoing())); atol = sqrt(eps()))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user