Add device info to nodes during scheduling
This commit is contained in:
parent
3267daadfd
commit
9b28601f18
@ -29,7 +29,7 @@ export children
|
||||
export compute
|
||||
export get_properties
|
||||
export get_exit_node
|
||||
export is_valid
|
||||
export is_valid, is_scheduled
|
||||
|
||||
export Operation
|
||||
export AppliedOperation
|
||||
|
@ -59,3 +59,19 @@ function is_valid(graph::DAG)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
"""
|
||||
is_scheduled(graph::DAG)
|
||||
|
||||
Validate that the entire graph has been scheduled, i.e., every [`ComputeTaskNode`](@ref) has its `.device` set.
|
||||
"""
|
||||
function is_scheduled(graph::DAG)
|
||||
for node in graph.nodes
|
||||
if (node isa DataTaskNode)
|
||||
continue
|
||||
end
|
||||
@assert !ismissing(node.device)
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
@ -18,9 +18,9 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
|
||||
sizehint!(schedule, length(graph.nodes))
|
||||
|
||||
# keep an accumulated cost of things scheduled to this device so far
|
||||
deviceAccCost = Dict{AbstractDevice, Int}()
|
||||
deviceAccCost = PriorityQueue{AbstractDevice, Int}()
|
||||
for device in machine.devices
|
||||
deviceAccCost[device] = 0
|
||||
enqueue!(deviceAccCost, device => 0)
|
||||
end
|
||||
|
||||
node = nothing
|
||||
@ -28,6 +28,13 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
|
||||
@assert peek(nodeQueue)[2] == 0
|
||||
node = dequeue!(nodeQueue)
|
||||
|
||||
# assign the device with lowest accumulated cost to the node (if it's a compute node)
|
||||
if (isa(node, ComputeTaskNode))
|
||||
lowestDevice = peek(deviceAccCost)[1]
|
||||
node.device = lowestDevice
|
||||
deviceAccCost[lowestDevice] = compute_effort(node.task)
|
||||
end
|
||||
|
||||
push!(schedule, node)
|
||||
for parent in node.parents
|
||||
# reduce the priority of all parents by one
|
||||
|
@ -30,6 +30,9 @@ include("../examples/profiling_utilities.jl")
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
|
||||
# graph should be fully scheduled after being executed
|
||||
@test is_scheduled(graph)
|
||||
|
||||
func = get_compute_function(graph, process_2_2, machine)
|
||||
@test isapprox(func(particles_2_2), expected_result; rtol = 0.001)
|
||||
end
|
||||
@ -39,9 +42,13 @@ include("../examples/profiling_utilities.jl")
|
||||
for i in 1:1000
|
||||
graph = parse_dag(joinpath(@__DIR__, "..", "input", "AB->AB.txt"), ABCModel())
|
||||
random_walk!(graph, 50)
|
||||
|
||||
@test is_valid(graph)
|
||||
|
||||
@test isapprox(execute(graph, process_2_2, machine, particles_2_2), expected_result; rtol = 0.001)
|
||||
|
||||
# graph should be fully scheduled after being executed
|
||||
@test is_scheduled(graph)
|
||||
end
|
||||
end
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user