Add device info to nodes during scheduling

This commit is contained in:
Anton Reinhard 2023-10-12 00:29:48 +02:00
parent 3267daadfd
commit 9b28601f18
4 changed files with 33 additions and 3 deletions

View File

@ -29,7 +29,7 @@ export children
export compute
export get_properties
export get_exit_node
export is_valid
export is_valid, is_scheduled
export Operation
export AppliedOperation

View File

@ -59,3 +59,19 @@ function is_valid(graph::DAG)
return true
end
"""
is_scheduled(graph::DAG)
Validate that the entire graph has been scheduled, i.e., every [`ComputeTaskNode`](@ref) has its `.device` set.
"""
function is_scheduled(graph::DAG)
for node in graph.nodes
if (node isa DataTaskNode)
continue
end
@assert !ismissing(node.device)
end
return true
end

View File

@ -18,9 +18,9 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
sizehint!(schedule, length(graph.nodes))
# keep an accumulated cost of things scheduled to this device so far
deviceAccCost = Dict{AbstractDevice, Int}()
deviceAccCost = PriorityQueue{AbstractDevice, Int}()
for device in machine.devices
deviceAccCost[device] = 0
enqueue!(deviceAccCost, device => 0)
end
node = nothing
@ -28,6 +28,13 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
@assert peek(nodeQueue)[2] == 0
node = dequeue!(nodeQueue)
# assign the device with lowest accumulated cost to the node (if it's a compute node)
if (isa(node, ComputeTaskNode))
lowestDevice = peek(deviceAccCost)[1]
node.device = lowestDevice
deviceAccCost[lowestDevice] = compute_effort(node.task)
end
push!(schedule, node)
for parent in node.parents
# reduce the priority of all parents by one

View File

@ -30,6 +30,9 @@ include("../examples/profiling_utilities.jl")
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
# graph should be fully scheduled after being executed
@test is_scheduled(graph)
func = get_compute_function(graph, process_2_2, machine)
@test isapprox(func(particles_2_2), expected_result; rtol = 0.001)
end
@ -39,9 +42,13 @@ include("../examples/profiling_utilities.jl")
for i in 1:1000
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
random_walk!(graph, 50)
@test is_valid(graph)
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
# graph should be fully scheduled after being executed
@test is_scheduled(graph)
end
end