diff --git a/crates/tako/src/internal/scheduler/gap.rs b/crates/tako/src/internal/scheduler/gap.rs
new file mode 100644
index 000000000..063e233f3
--- /dev/null
+++ b/crates/tako/src/internal/scheduler/gap.rs
@@ -0,0 +1,285 @@
+use crate::internal::common::resources::ResourceId;
+use crate::internal::server::workerload::WorkerResources;
+use crate::internal::solver::{ConstraintType, LpSolver};
+use crate::resources::{ResourceAmount, ResourceRequestVariants, ResourceRqId, ResourceRqMap};
+use crate::{Map, ResourceVariantId};
+use hashbrown::Equivalent;
+use std::cell::RefCell;
+
+/// Caches, per (request, worker resources) pair, the resources left free on a worker
+/// once a non-trivial high-priority request has been packed onto it.
+///
+/// The `RefCell` allows `get_gap` to fill the cache through a shared reference.
+#[derive(Default)]
+pub(crate) struct GapCache {
+    inner: RefCell<GapCacheInner>,
+}
+
+/// Owned cache key: a high-priority request and the worker resources it is packed into.
+#[derive(Hash, PartialEq, Eq)]
+struct GapKey {
+    rq: ResourceRqId,
+    resources: WorkerResources,
+}
+
+/// Borrowed twin of `GapKey`; lookups through it avoid cloning `WorkerResources`.
+/// Its fields must mirror `GapKey` so that both derive the same hash.
+#[derive(Hash, PartialEq, Eq)]
+struct GapKeyRef<'a> {
+    rq: ResourceRqId,
+    resources: &'a WorkerResources,
+}
+
+// Lets a `GapKeyRef` look up entries stored under a `GapKey`.
+impl<'a> Equivalent<GapKey> for GapKeyRef<'a> {
+    fn equivalent(&self, key: &GapKey) -> bool {
+        self.rq == key.rq && self.resources == &key.resources
+    }
+}
+
+/// Cache payload: the free resources computed for each key.
+#[derive(Default)]
+struct GapCacheInner {
+    resource_gaps: Map<GapKey, WorkerResources>,
+}
+
+impl GapCache {
+    /// Returns how many `low_priority_rq` tasks fit into the worker resources that are
+    /// left over after `high_priority_rq` tasks have been packed onto the worker.
+    ///
+    /// * `resources` - total resources of the worker.
+    /// * `assigned_tasks` - (request id, variant id) of tasks already assigned to the
+    ///   worker; tasks of requests other than `high_priority_rq` shrink the leftover.
+    ///
+    /// Returns 0 if either request is multi-node, or if a trivial high-priority
+    /// request asks for all of some resource. When the low-priority request has several
+    /// variants, the smallest count over the variants is returned.
+    pub fn get_gap(
+        &self,
+        high_priority_rq: ResourceRqId,
+        low_priority_rq: ResourceRqId,
+        resources: &WorkerResources,
+        assigned_tasks: impl Iterator<Item = (ResourceRqId, ResourceVariantId)>,
+        resource_rq_map: &ResourceRqMap,
+    ) -> u32 {
+        let h_rqv = resource_rq_map.get(high_priority_rq);
+        if h_rqv.is_multi_node() {
+            return 0;
+        }
+        let l_rqv = resource_rq_map.get(low_priority_rq);
+        if l_rqv.is_multi_node() {
+            return 0;
+        }
+        let mut free: WorkerResources = if let Some(h_rq) = h_rqv.trivial_request() {
+            if h_rq.entries().iter().any(|r| r.request.amount_is_all()) {
+                return 0;
+            }
+            let count = resources.task_max_count_for_request(h_rq);
+            let mut resources = resources.clone();
+            resources.remove_multiple(h_rq, count);
+            resources
+        } else {
+            let key = GapKeyRef {
+                rq: high_priority_rq,
+                resources,
+            };
+            // End the shared borrow before the miss path borrows mutably; a `Ref` that is
+            // still alive at `borrow_mut()` would panic at runtime.
+            let cached = self.inner.borrow().resource_gaps.get(&key).cloned();
+            cached.unwrap_or_else(|| {
+                let free = compute_gap_resources(h_rqv, resources);
+                self.inner.borrow_mut().resource_gaps.insert(
+                    GapKey {
+                        rq: high_priority_rq,
+                        resources: resources.clone(),
+                    },
+                    free.clone(),
+                );
+                free
+            })
+        };
+        // Subtract what tasks of other requests already hold. Tasks of `high_priority_rq`
+        // are skipped (presumably they are already covered by the packing above).
+        for (rq_id, rv_id) in assigned_tasks {
+            if rq_id != high_priority_rq {
+                let rq = resource_rq_map.get(rq_id).get(rv_id);
+                free.remove(rq);
+            }
+        }
+        // Be pessimistic: report the variant of the low-priority request that fits least.
+        l_rqv
+            .requests()
+            .iter()
+            .map(|rq| free.task_max_count_for_request(rq))
+            .min()
+            .unwrap_or(0)
+    }
+}
+
+/// Computes, for each resource of the worker, the amount that stays free when tasks of
+/// the variants in `rqv` are packed to consume as much of that resource as possible.
+///
+/// One program is solved per worker resource: there is one natural-number variable per
+/// variant (a task count) whose objective coefficient is the amount of the resource a
+/// task of that variant takes, and one capacity constraint per resource index used by
+/// `rqv`. The objective value is presumably maximized by `LpSolver::new(false)`.
+/// A resource whose program has no solution is reported as having nothing free.
+fn compute_gap_resources(
+    rqv: &ResourceRequestVariants,
+    resources: &WorkerResources,
+) -> WorkerResources {
+    let Some(max_resource_idx) = rqv
+        .requests()
+        .iter()
+        .flat_map(|rq| rq.entries().iter().map(|r| r.resource_id.as_usize()))
+        .max()
+    else {
+        // No variant requests any resource.
+        return WorkerResources::new(Vec::new().into());
+    };
+    let n_resources = max_resource_idx + 1;
+    let gap_res: Vec<ResourceAmount> = resources
+        .iter_pairs()
+        .map(|(r_id, r_amount)| {
+            let mut solver = LpSolver::new(false);
+            let mut cst = vec![Vec::new(); n_resources];
+            let vars: Vec<_> = rqv
+                .requests()
+                .iter()
+                .map(|rq| {
+                    // Amount of `r_id` one task takes; falls back to the whole worker amount
+                    // (assumes `get_amount` is `None` only for "all" requests - TODO confirm).
+                    let a = rq.get_amount(r_id).unwrap_or(resources.get(r_id)).as_f64();
+                    solver.add_nat_variable(a)
+                })
+                .collect();
+            for (i, rq) in rqv.requests().iter().enumerate() {
+                for entry in rq.entries() {
+                    // An "all" request takes the whole worker amount of the entry's resource.
+                    let a = entry
+                        .request
+                        .amount_or_none_if_all()
+                        .unwrap_or(resources.get(entry.resource_id))
+                        .as_f64();
+                    cst[entry.resource_id.as_usize()].push((vars[i], a));
+                }
+            }
+            // The tasks together must not exceed the worker capacity of any resource.
+            for (idx, c) in cst.into_iter().enumerate() {
+                let r_id = ResourceId::new(idx as u32);
+                solver.add_constraint(
+                    ConstraintType::Max,
+                    resources.get(r_id).as_f64(),
+                    c.into_iter(),
+                );
+            }
+            let Some((_, v)) = solver.solve() else {
+                // No solution: report this resource as having nothing free.
+                return ResourceAmount::ZERO;
+            };
+            // Saturate so that rounding the solver result can never underflow the amount.
+            r_amount.saturating_sub(ResourceAmount::from_float(v.round() as f32))
+        })
+        .collect();
+    WorkerResources::new(gap_res.into())
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::internal::server::core::CoreSplitMut;
+    use std::iter;
+
+    use crate::tests::utils::env::TestEnv;
+    use
crate::tests::utils::task::TaskBuilder; + use crate::tests::utils::worker::WorkerBuilder; + use crate::{TaskId, WorkerId}; + + fn compute_gap(rt: &mut TestEnv, high_task: TaskId, low_task: TaskId, w: WorkerId) -> u32 { + let CoreSplitMut { + task_map, + worker_map, + scheduler_state, + request_map, + .. + } = rt.core().split_mut(); + let h_rq = task_map.get_task(high_task).resource_rq_id; + let l_rq = task_map.get_task(low_task).resource_rq_id; + let res = &worker_map.get_worker(w).resources; + scheduler_state + .gap_cache + .get_gap(h_rq, l_rq, &res, iter::empty(), request_map) + } + + #[test] + fn test_compute_gap() { + let mut rt = TestEnv::new(); + rt.new_named_resource("foo"); + rt.new_named_resource("bar"); + let w = rt.new_worker(&WorkerBuilder::new(4)); + let t1 = rt.new_task_cpus(2); + let t2 = rt.new_task_cpus(1); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 0); + let t1 = rt.new_task_cpus(3); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 1); + let t2 = rt.new_task_cpus(2); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 0); + + let w = rt.new_worker(&WorkerBuilder::new(12).res_sum("foo", 2).res_sum("bar", 1)); + let t1 = rt.new_task_cpus(4); + let t2 = rt.new_task_cpus(2); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 0); + let t1 = rt.new_task_cpus(5); + let t2 = rt.new_task_cpus(1); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 2); + let t1 = rt.new_task(&TaskBuilder::new().cpus(5).add_resource(1, 2)); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 7); + let t2 = rt.new_task(&TaskBuilder::new().cpus(1).add_resource(1, 1)); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 0); + let t1 = rt.new_task(&TaskBuilder::new().cpus(5).add_resource(1, 2)); + let t2 = rt.new_task(&TaskBuilder::new().cpus(1).add_resource(2, 1)); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 1); + let t1 = rt.new_task( + &TaskBuilder::new() + .cpus(8) + .next_variant() + .cpus(2) + .add_resource(1, 2), + ); + let t2 = rt.new_task(&TaskBuilder::new().cpus(1)); + 
assert_eq!(compute_gap(&mut rt, t1, t2, w), 2); + let t1 = rt.new_task( + &TaskBuilder::new() + .cpus(8) + .next_variant() + .cpus(2) + .add_resource(1, 1), + ); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 0); + let t1 = rt.new_task( + &TaskBuilder::new() + .cpus(8) + .next_variant() + .cpus(2) + .add_resource(1, 2), + ); + let t2 = rt.new_task(&TaskBuilder::new().cpus(1).add_resource(2, 1)); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 1); + + let w = rt.new_worker(&WorkerBuilder::new(6).res_sum("foo", 2).res_sum("bar", 2)); + let t1 = rt.new_task( + &TaskBuilder::new() + .cpus(2) + .add_resource(1, 1) + .next_variant() + .cpus(2) + .add_resource(2, 1), + ); + let t2 = rt.new_task_cpus(1); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 0); + + let w = rt.new_worker(&WorkerBuilder::new(58)); + let t1 = rt.new_task(&TaskBuilder::new().cpus(13).next_variant().cpus(7)); + let t2 = rt.new_task_cpus(1); + assert_eq!(compute_gap(&mut rt, t1, t2, w), 2); + } +} diff --git a/crates/tako/src/internal/scheduler/gap_cache.rs b/crates/tako/src/internal/scheduler/gap_cache.rs deleted file mode 100644 index 5b6042a7b..000000000 --- a/crates/tako/src/internal/scheduler/gap_cache.rs +++ /dev/null @@ -1,228 +0,0 @@ -use crate::internal::common::resources::ResourceId; -use crate::internal::server::workerload::WorkerResources; -use crate::internal::solver::{ConstraintType, LpSolution, LpSolver}; -use crate::resources::{ResourceRequest, ResourceRequestVariants, ResourceRqId, ResourceRqMap}; -use crate::{Map, ResourceVariantId}; -use hashbrown::Equivalent; -use std::cell::RefCell; - -#[derive(Default)] -pub(crate) struct GapCache { - inner: RefCell, -} - -#[derive(Hash, PartialEq, Eq)] -struct GapKey { - high_priority_rq: ResourceRqId, - low_priority_rq: ResourceRqId, - low_priority_variant: ResourceVariantId, - resources: WorkerResources, -} - -#[derive(Hash, PartialEq, Eq)] -struct GapKeyRef<'a> { - high_priority_rq: ResourceRqId, - low_priority_rq: ResourceRqId, - 
low_priority_variant: ResourceVariantId, - resources: &'a WorkerResources, -} - -impl<'a> Equivalent for GapKeyRef<'a> { - fn equivalent(&self, key: &GapKey) -> bool { - self.high_priority_rq == key.high_priority_rq - && self.low_priority_rq == key.low_priority_rq - && self.resources == &key.resources - } -} - -#[derive(Default)] -struct GapCacheInner { - resource_gaps: Map, -} - -impl GapCache { - pub fn get_gap( - &self, - high_priority_rq: ResourceRqId, - low_priority_rq: ResourceRqId, - low_priority_variant: ResourceVariantId, - resources: &WorkerResources, - resource_rq_map: &ResourceRqMap, - ) -> u32 { - let key = GapKeyRef { - high_priority_rq, - low_priority_rq, - low_priority_variant, - resources, - }; - let mut inner = self.inner.borrow_mut(); - - if let Some(gap) = inner.resource_gaps.get(&key).copied() { - gap - } else { - let gap = compute_gap( - resource_rq_map.get(high_priority_rq), - resource_rq_map - .get(low_priority_rq) - .get(low_priority_variant), - resources, - ); - inner.resource_gaps.insert( - GapKey { - high_priority_rq, - low_priority_rq, - low_priority_variant, - resources: resources.clone(), - }, - gap, - ); - gap - } - } -} - -fn compute_gap( - high_priority_rqv: &ResourceRequestVariants, - low_priority_rq: &ResourceRequest, - resources: &WorkerResources, -) -> u32 { - if high_priority_rqv.is_multi_node() || low_priority_rq.is_multi_node() { - return 0; - } - if high_priority_rqv.is_trivial() { - let high_priority_rq = high_priority_rqv.get(0.into()); - let count = resources.task_max_count_for_request(high_priority_rq); - let mut resources = resources.clone(); - resources.remove_multiple(high_priority_rq, count); - resources.task_max_count_for_request(low_priority_rq) - } else { - if high_priority_rqv - .requests() - .iter() - .any(|rq| rq.entries().iter().any(|r| r.request.amount_is_all())) - { - return 0; - } - low_priority_rq - .entries() - .iter() - .map(|entry| { - let mut solver = LpSolver::new(false); - let vars: Vec<_> = 
high_priority_rqv - .requests() - .iter() - .map(|rq| { - solver.add_nat_variable(rq.get_amount(entry.resource_id).unwrap().as_f64()) - }) - .collect(); - let max_resource: usize = high_priority_rqv - .requests() - .iter() - .flat_map(|rq| rq.entries().iter().map(|r| r.resource_id.as_usize())) - .max() - .unwrap_or(0); - let mut cst = vec![Vec::new(); max_resource + 1]; - for (i, rq) in high_priority_rqv.requests().iter().enumerate() { - for entry in rq.entries() { - cst[entry.resource_id.as_usize()].push(( - vars[i], - entry.request.amount_or_none_if_all().unwrap().as_f64(), - )); - } - } - for (idx, c) in cst.into_iter().enumerate() { - let r_id = ResourceId::new(idx as u32); - solver.add_constraint( - ConstraintType::Max, - resources.get(r_id).as_f64(), - c.into_iter(), - ); - } - let Some((solution, _)) = solver.solve() else { - return 0; - }; - let mut resources = resources.clone(); - for (var, rq) in vars.iter().zip(high_priority_rqv.requests().iter()) { - resources.remove_multiple_masked( - rq, - solution.get_value(*var).round() as u32, - entry.resource_id, - ); - } - resources.task_max_count_for_request(low_priority_rq) - }) - .min() - .unwrap_or(0) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::tests::utils::env::TestEnv; - use crate::tests::utils::resources::ResBuilder; - - use crate::tests::utils::worker::WorkerBuilder; - - #[test] - fn test_compute_gap() { - let mut rt = TestEnv::new(); - rt.new_named_resource("foo"); - rt.new_named_resource("bar"); - let w = rt.new_worker(&WorkerBuilder::new(4)); - let rqv1 = ResBuilder::default().cpus(2).finish_v(); - let rq2 = ResBuilder::default().cpus(1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 0); - let rqv1 = ResBuilder::default().cpus(3).finish_v(); - let rq2 = ResBuilder::default().cpus(1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 1); - let rqv1 = ResBuilder::default().cpus(3).finish_v(); - let rq2 = 
ResBuilder::default().cpus(2).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 0); - - let w = rt.new_worker(&WorkerBuilder::new(12).res_sum("foo", 2).res_sum("bar", 1)); - let rqv1 = ResBuilder::default().cpus(4).finish_v(); - let rq2 = ResBuilder::default().cpus(2).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 0); - let rqv1 = ResBuilder::default().cpus(5).finish_v(); - let rq2 = ResBuilder::default().cpus(1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 2); - let rqv1 = ResBuilder::default().cpus(5).add_compact(1, 2).finish_v(); - let rq2 = ResBuilder::default().cpus(1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 7); - let rqv1 = ResBuilder::default().cpus(5).add_compact(1, 2).finish_v(); - let rq2 = ResBuilder::default().cpus(1).add_compact(1, 1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 0); - let rqv1 = ResBuilder::default().cpus(5).add_compact(1, 2).finish_v(); - let rq2 = ResBuilder::default().cpus(1).add_compact(2, 1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 1); - - let mut rqv1 = ResBuilder::default().cpus(8).finish_v(); - rqv1.add_varint(ResBuilder::default().cpus(2).add_compact(1, 2)); - let rq2 = ResBuilder::default().cpus(1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 2); - - let mut rqv1 = ResBuilder::default().cpus(8).finish_v(); - rqv1.add_varint(ResBuilder::default().cpus(2).add_compact(1, 1)); - let rq2 = ResBuilder::default().cpus(1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 0); - - let mut rqv1 = ResBuilder::default().cpus(8).finish_v(); - rqv1.add_varint(ResBuilder::default().cpus(2).add_compact(1, 2)); - let rq2 = ResBuilder::default().cpus(1).add_compact(2, 1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 1); - - let w = rt.new_worker(&WorkerBuilder::new(6).res_sum("foo", 
2).res_sum("bar", 2)); - let mut rqv1 = ResBuilder::default().cpus(2).add_compact(1, 1).finish_v(); - rqv1.add_varint(ResBuilder::default().cpus(2).add_compact(2, 1)); - let rq2 = ResBuilder::default().cpus(1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 0); - - let w = rt.new_worker(&WorkerBuilder::new(58)); - let mut rqv1 = ResBuilder::default().cpus(13).finish_v(); - rqv1.add_varint(ResBuilder::default().cpus(7)); - let rq2 = ResBuilder::default().cpus(1).finish(); - assert_eq!(compute_gap(&rqv1, &rq2, &rt.worker(w).resources), 2); - } -} diff --git a/crates/tako/src/internal/scheduler/mod.rs b/crates/tako/src/internal/scheduler/mod.rs index bc7377c0d..b33aa6d1c 100644 --- a/crates/tako/src/internal/scheduler/mod.rs +++ b/crates/tako/src/internal/scheduler/mod.rs @@ -1,5 +1,5 @@ mod batches; -mod gap_cache; +mod gap; mod main; mod mapping; pub(crate) mod query; diff --git a/crates/tako/src/internal/scheduler/solver.rs b/crates/tako/src/internal/scheduler/solver.rs index c9a29deaf..c2cc35916 100644 --- a/crates/tako/src/internal/scheduler/solver.rs +++ b/crates/tako/src/internal/scheduler/solver.rs @@ -22,7 +22,7 @@ pub(crate) fn run_scheduling_solver( let n_resources = core.resource_map().n_resources(); let CoreSplit { - task_map: _, + task_map, worker_map, task_queues: _, request_map, @@ -190,7 +190,14 @@ pub(crate) fn run_scheduling_solver( solver.set_name(|| format!("mn_{}_{}", batch.resource_rq_id, group_name)); let v = solver.add_nat_variable(0.0); solver.set_name(|| format!("MN size for rq{}", batch.resource_rq_id)); - constraint_extra_var(&mut solver, ConstraintType::Eq, 0.0, &temp, v, -n_nodes); + constraint_extra_var( + &mut solver, + ConstraintType::Eq, + 0.0, + temp.iter().copied(), + v, + -n_nodes, + ); tasks_count_vars .entry(batch.resource_rq_id) .or_default() @@ -215,7 +222,14 @@ pub(crate) fn run_scheduling_solver( let vars = tasks_count_vars.get(&blocker_rq_id).unwrap(); solver.set_name(|| format!("blocker 
rq{blocker_rq_id} at size {size}")); let bound = size as f64; - constraint_extra_var(solver, ConstraintType::Min, bound, vars, new_v, bound); + constraint_extra_var( + solver, + ConstraintType::Min, + bound, + vars.iter().copied(), + new_v, + bound, + ); new_v }) }; @@ -254,10 +268,66 @@ pub(crate) fn run_scheduling_solver( } } else { for w in &workers { + let Some(sn_assignment) = w.sn_assignment() else { + continue; + }; if !w.is_capable_to_run_rqv(blocker_rqv, now) { continue; } - for v_id in batch_rqv.variant_ids() { + let gap = scheduler_cache.gap_cache.get_gap( + *blocker_rq_id, + batch.resource_rq_id, + &w.resources, + sn_assignment.assigned_tasks.iter().map(|task_id| { + let t = task_map.get_task(*task_id); + (t.resource_rq_id, t.rv_id().unwrap()) + }), + request_map, + ); + if gap > 0 { + let vars = batch_rqv.variant_ids().filter_map(|v_id| { + placements.get(&(w.id, batch.resource_rq_id, v_id)).copied() + }); + let cut_size = cut.size as f64; + if let Some(s) = blocking_size { + let blocking_v = get_bvar(&mut solver, *blocker_rq_id, *s); + solver.set_name(|| { + format!( + "w{}: if #rq{blocker_rq_id} < {s} then limit #rq{} to {} + {} (gap) where both rqs may run", + w.id, batch.resource_rq_id, cut.size, gap + ) + }); + constraint_extra_var( + &mut solver, + ConstraintType::Max, + cut_size + batch_size + gap as f64, + vars, + blocking_v, + batch_size, + ); + } else { + solver.set_name(|| { + format!( + "w{}: limit #rq{} to {} + {} (gap) where it can run with rq{blocker_rq_id}", + w.id, batch.resource_rq_id, cut.size, gap, // order matches the message: cut size, then gap + ) + }); + solver.add_constraint( + ConstraintType::Max, + cut_size + gap as f64, + vars.map(|v| (v, 1.0)), + ); + } + } else { + for v_id in batch_rqv.variant_ids() { + if let Some(var) = + placements.get(&(w.id, batch.resource_rq_id, v_id)) + { + zero_cond.push(*var); + } + } + } + /*for v_id in batch_rqv.variant_ids() { if let Some(var) = placements.get(&(w.id, batch.resource_rq_id, v_id)) { let gap = 
scheduler_cache.gap_cache.get_gap( *blocker_rq_id, @@ -298,7 +368,7 @@ pub(crate) fn run_scheduling_solver( zero_cond.push(*var); } } - } + }*/ } } if zero_cond.is_empty() { @@ -317,7 +387,7 @@ pub(crate) fn run_scheduling_solver( &mut solver, ConstraintType::Max, batch_size + cut_size, - &zero_cond, + zero_cond.iter().copied(), blocking_v, batch_size, ); @@ -508,15 +578,13 @@ fn constraint_extra_var( solver: &mut LpSolver, constraint_type: ConstraintType, limit_value: f64, - vars: &[Variable], + vars: impl Iterator, var: Variable, coef: f64, ) { solver.add_constraint( constraint_type, limit_value, - vars.iter() - .map(|v| (*v, 1.0)) - .chain(std::iter::once((var, coef))), + vars.map(|v| (v, 1.0)).chain(std::iter::once((var, coef))), ); } diff --git a/crates/tako/src/internal/scheduler/state.rs b/crates/tako/src/internal/scheduler/state.rs index ec1d552a6..68974c1aa 100644 --- a/crates/tako/src/internal/scheduler/state.rs +++ b/crates/tako/src/internal/scheduler/state.rs @@ -1,4 +1,4 @@ -use crate::internal::scheduler::gap_cache::GapCache; +use crate::internal::scheduler::gap::GapCache; use crate::{Map, ResourceVariantId, TaskId, WorkerId}; pub struct SchedulerConfig { diff --git a/crates/tako/src/internal/server/core.rs b/crates/tako/src/internal/server/core.rs index e8df3a991..c122d15e1 100644 --- a/crates/tako/src/internal/server/core.rs +++ b/crates/tako/src/internal/server/core.rs @@ -290,9 +290,9 @@ impl Core { match worker.assignment() { WorkerAssignment::Sn(s) => { if assigned == Some(*worker_id) { - assert!(s.assign_tasks.contains(&task_id)); + assert!(s.assigned_tasks.contains(&task_id)); } else { - assert!(!s.assign_tasks.contains(&task_id)); + assert!(!s.assigned_tasks.contains(&task_id)); } if prefilled == Some(*worker_id) { assert!(s.prefilled_tasks.contains(&task_id)); @@ -325,7 +325,7 @@ impl Core { worker .sn_assignment() .unwrap() - .assign_tasks + .assigned_tasks .contains(task_id) ); assert!( @@ -414,7 +414,7 @@ impl Core { } else { match 
worker.assignment() { WorkerAssignment::Sn(sn) => { - assert!(!sn.assign_tasks.contains(&task_id)); + assert!(!sn.assigned_tasks.contains(&task_id)); } WorkerAssignment::Mn(mn) => { assert_ne!(mn.task_id, task_id); diff --git a/crates/tako/src/internal/server/reactor.rs b/crates/tako/src/internal/server/reactor.rs index 3f56febfb..14f01a051 100644 --- a/crates/tako/src/internal/server/reactor.rs +++ b/crates/tako/src/internal/server/reactor.rs @@ -82,7 +82,7 @@ pub(crate) fn on_remove_worker( let mut retracted = Vec::new(); match worker.assignment() { WorkerAssignment::Sn(sn) => { - for task_id in &sn.assign_tasks { + for task_id in &sn.assigned_tasks { let task = task_map.get_task_mut(*task_id); if task.is_sn_running() { running_tasks.push(*task_id); diff --git a/crates/tako/src/internal/server/task.rs b/crates/tako/src/internal/server/task.rs index 79a8ae552..b979e670c 100644 --- a/crates/tako/src/internal/server/task.rs +++ b/crates/tako/src/internal/server/task.rs @@ -247,6 +247,15 @@ impl Task { } } + pub(crate) fn rv_id(&self) -> Option { + match self.state { + TaskRuntimeState::Running { rv_id, .. } | TaskRuntimeState::Assigned { rv_id, .. 
} => { + Some(rv_id) + } + _ => None, + } + } + pub(crate) fn increment_instance_id(&mut self) { self.instance_id = InstanceId::new(self.instance_id.as_num() + 1); } diff --git a/crates/tako/src/internal/server/worker.rs b/crates/tako/src/internal/server/worker.rs index b1f7fe558..923714530 100644 --- a/crates/tako/src/internal/server/worker.rs +++ b/crates/tako/src/internal/server/worker.rs @@ -39,8 +39,8 @@ pub struct MultiNodeTaskAssignment { #[derive(Debug)] pub struct SingleNodeTaskAssignment { - // This is list of single node assigned tasks - pub assign_tasks: Set, + // The set of single node assigned tasks + pub assigned_tasks: Set, pub free_resources: WorkerResources, pub prefilled_tasks: Set, } @@ -53,7 +53,7 @@ pub enum WorkerAssignment { impl WorkerAssignment { fn empty_sn(wr: &WorkerResources) -> Self { Self::Sn(SingleNodeTaskAssignment { - assign_tasks: Default::default(), + assigned_tasks: Default::default(), free_resources: wr.clone(), prefilled_tasks: Default::default(), }) @@ -113,6 +113,16 @@ impl Worker { } } + #[inline] + pub fn sn_assignment_and_resources( + &self, + ) -> Option<(&SingleNodeTaskAssignment, &WorkerResources)> { + match &self.assignment { + WorkerAssignment::Sn(a) => Some((a, &self.resources)), + WorkerAssignment::Mn(_) => None, + } + } + #[inline] pub fn mn_assignment(&self) -> Option<&MultiNodeTaskAssignment> { match &self.assignment { @@ -141,12 +151,12 @@ impl Worker { match &self.assignment { WorkerAssignment::Sn(a) => { let mut running_tasks = 0; - a.assign_tasks.iter().for_each(|task_id| { + a.assigned_tasks.iter().for_each(|task_id| { if task_map.get_task(*task_id).is_sn_running() { running_tasks += 1; } }); - let assigned_tasks = a.assign_tasks.len() as u32; + let assigned_tasks = a.assigned_tasks.len() as u32; WorkerRuntimeInfo::SingleNodeTasks { assigned_tasks, running_tasks, @@ -170,7 +180,7 @@ impl Worker { pub fn is_free(&self) -> bool { (match &self.assignment { - WorkerAssignment::Sn(a) => 
a.assign_tasks.is_empty(), + WorkerAssignment::Sn(a) => a.assigned_tasks.is_empty(), WorkerAssignment::Mn(_a) => false, }) && !self.is_stopping() } @@ -179,7 +189,7 @@ impl Worker { match &mut self.assignment { WorkerAssignment::Sn(a) => { a.free_resources.remove(rq); - assert!(a.assign_tasks.insert(task_id)); + assert!(a.assigned_tasks.insert(task_id)); } WorkerAssignment::Mn(_) => unreachable!(), } @@ -203,7 +213,7 @@ impl Worker { match &mut self.assignment { WorkerAssignment::Sn(a) => { assert!(a.prefilled_tasks.remove(&task_id)); - assert!(a.assign_tasks.insert(task_id)); + assert!(a.assigned_tasks.insert(task_id)); a.free_resources.remove(rq); } WorkerAssignment::Mn(_) => unreachable!(), @@ -213,8 +223,8 @@ impl Worker { pub fn remove_sn_task(&mut self, task_id: TaskId, rq: &ResourceRequest) { match &mut self.assignment { WorkerAssignment::Sn(a) => { - assert!(a.assign_tasks.remove(&task_id)); - if a.assign_tasks.is_empty() { + assert!(a.assigned_tasks.remove(&task_id)); + if a.assigned_tasks.is_empty() { self.idle_timestamp = Instant::now(); } a.free_resources.add(rq, &self.resources); @@ -231,7 +241,7 @@ impl Worker { ) { if let Some(a) = self.sn_assignment() { let mut resources = self.resources.clone(); - for task_id in a.assign_tasks.iter() { + for task_id in a.assigned_tasks.iter() { let task = task_map.get_task(*task_id); let (worker_id, rv_id) = match &task.state { TaskRuntimeState::Assigned { worker_id, rv_id } @@ -244,6 +254,7 @@ impl Worker { }; assert_eq!(self.id, worker_id); let rq = request_map.get(task.resource_rq_id).get(rv_id); + assert!(resources.is_capable_to_run_request(rq)); resources.remove(rq); } assert_eq!(a.free_resources, resources); @@ -360,7 +371,7 @@ impl Worker { "id": self.id, "assignment": match &self.assignment { WorkerAssignment::Sn(a) => json! ({ - "assigned_tasks": &a.assign_tasks + "assigned_tasks": &a.assigned_tasks }), WorkerAssignment::Mn(a) => json! 
({ "task_id": a.task_id, diff --git a/crates/tako/src/internal/server/workerload.rs b/crates/tako/src/internal/server/workerload.rs index e01b1684c..cdcbb0254 100644 --- a/crates/tako/src/internal/server/workerload.rs +++ b/crates/tako/src/internal/server/workerload.rs @@ -19,6 +19,10 @@ pub struct WorkerResources { } impl WorkerResources { + pub(crate) fn new(n_resources: ResourceVec) -> Self { + Self { n_resources } + } + pub(crate) fn get(&self, resource_id: ResourceId) -> ResourceAmount { self.n_resources .get(resource_id) @@ -152,8 +156,8 @@ impl WorkerResources { pub fn remove(&mut self, rq: &ResourceRequest) { for entry in rq.entries() { if let Some(amount) = entry.request.amount_or_none_if_all() { - assert!(self.n_resources[entry.resource_id] >= amount); - self.n_resources[entry.resource_id] -= amount; + self.n_resources[entry.resource_id] = + self.n_resources[entry.resource_id].saturating_sub(amount); } else { self.n_resources[entry.resource_id] = ResourceAmount::ZERO; } @@ -164,8 +168,8 @@ impl WorkerResources { for entry in rq.entries() { if let Some(amount) = entry.request.amount_or_none_if_all() { let a = amount.times(n); - assert!(self.n_resources[entry.resource_id] >= a); - self.n_resources[entry.resource_id] -= a; + self.n_resources[entry.resource_id] = + self.n_resources[entry.resource_id].saturating_sub(a); } else { self.n_resources[entry.resource_id] = ResourceAmount::ZERO; } @@ -177,8 +181,8 @@ impl WorkerResources { if entry.resource_id == r_id { if let Some(amount) = entry.request.amount_or_none_if_all() { let a = amount.times(n); - assert!(self.n_resources[entry.resource_id] >= a); - self.n_resources[entry.resource_id] -= a; + self.n_resources[entry.resource_id] = + self.n_resources[entry.resource_id].saturating_sub(a); } else { self.n_resources[entry.resource_id] = ResourceAmount::ZERO; } diff --git a/crates/tako/src/internal/tests/test_reactor.rs b/crates/tako/src/internal/tests/test_reactor.rs index 8e0329fb3..31f436c92 100644 --- 
a/crates/tako/src/internal/tests/test_reactor.rs +++ b/crates/tako/src/internal/tests/test_reactor.rs @@ -607,9 +607,9 @@ fn lost_worker_with_running_and_assign_tasks() { fn check_worker_tasks_exact(core: &Core, worker_id: WorkerId, tasks: &[TaskId]) { let worker = core.get_worker(worker_id.into()); let sn = worker.sn_assignment().unwrap(); - assert_eq!(sn.assign_tasks.len(), tasks.len()); + assert_eq!(sn.assigned_tasks.len(), tasks.len()); for task in tasks { - assert!(sn.assign_tasks.contains(task)); + assert!(sn.assigned_tasks.contains(task)); } } @@ -617,7 +617,7 @@ fn worker_has_task(core: &Core, worker_id: WorkerId, task_id: TaskId) -> bool { core.get_worker(worker_id.into()) .sn_assignment() .unwrap() - .assign_tasks + .assigned_tasks .contains(&task_id) } @@ -814,7 +814,6 @@ fn test_prefill_submit_high_priority() { _ => panic!("Invalid worker msg"), } comm.emptiness_check(); - dbg!(&rt.task(t2).state); match rt.task(t2).state { TaskRuntimeState::Retracting { worker_id } => { assert_eq!(worker_id, w1); @@ -988,7 +987,7 @@ fn test_steal_running() { *rt.worker(w1) .sn_assignment() .unwrap() - .assign_tasks + .assigned_tasks .iter() .next() .unwrap(), @@ -1017,7 +1016,7 @@ fn test_steal_failed() { *rt.worker(w1) .sn_assignment() .unwrap() - .assign_tasks + .assigned_tasks .iter() .next() .unwrap(), diff --git a/crates/tako/src/internal/tests/test_scheduler_mapping.rs b/crates/tako/src/internal/tests/test_scheduler_mapping.rs index 0291dabf3..2fdf52f95 100644 --- a/crates/tako/src/internal/tests/test_scheduler_mapping.rs +++ b/crates/tako/src/internal/tests/test_scheduler_mapping.rs @@ -32,7 +32,7 @@ fn test_schedule_mapping_do_not_change() { rt.worker(w1) .sn_assignment() .unwrap() - .assign_tasks + .assigned_tasks .contains(&t1); let m = rt.schedule_mapping(); diff --git a/crates/tako/src/internal/tests/test_scheduler_sn.rs b/crates/tako/src/internal/tests/test_scheduler_sn.rs index 11771f71e..082b90db1 100644 --- 
a/crates/tako/src/internal/tests/test_scheduler_sn.rs +++ b/crates/tako/src/internal/tests/test_scheduler_sn.rs @@ -510,7 +510,7 @@ fn test_schedule_gap_filling3() { for w in ws { let mut cpus = 0; let mut t3count = 0; - for t in &rt.worker(w).sn_assignment().unwrap().assign_tasks { + for t in &rt.worker(w).sn_assignment().unwrap().assigned_tasks { if ts2.contains(t) { cpus += 9; } else { @@ -824,7 +824,7 @@ fn test_no_deps_scattering_2() { let mut counts: Vec<_> = rt .core() .get_workers() - .map(|w| w.sn_assignment().unwrap().assign_tasks.len()) + .map(|w| w.sn_assignment().unwrap().assigned_tasks.len()) .collect(); counts.sort(); assert_eq!(counts, expected); @@ -1015,9 +1015,9 @@ fn test_generic_resource_balancing3() { let w = rt.worker(w1); let a = w.sn_assignment().unwrap(); - assert_eq!(a.assign_tasks.len(), 2); + assert_eq!(a.assigned_tasks.len(), 2); assert!( - a.assign_tasks + a.assigned_tasks .iter() .all(|t| rt.task(*t).resource_rq_id == rq1) ); @@ -1032,7 +1032,7 @@ fn test_generic_resource_balancing3() { let w = rt.worker(w2); let a = w.sn_assignment().unwrap(); - assert_eq!(a.assign_tasks.len(), 2); + assert_eq!(a.assigned_tasks.len(), 2); assert_eq!(a.prefilled_tasks.len(), 57); assert_eq!( a.prefilled_tasks @@ -1265,8 +1265,14 @@ fn test_prefill_steal() { assert_eq!(r, vec![(w2, rv), (w2, rv)]); assert_eq!(prefill_count(&mut rt, w1), 3); assert_eq!(prefill_count(&mut rt, w2), 0); - assert_eq!(rt.worker(w1).sn_assignment().unwrap().assign_tasks.len(), 1); - assert_eq!(rt.worker(w2).sn_assignment().unwrap().assign_tasks.len(), 5); + assert_eq!( + rt.worker(w1).sn_assignment().unwrap().assigned_tasks.len(), + 1 + ); + assert_eq!( + rt.worker(w2).sn_assignment().unwrap().assigned_tasks.len(), + 5 + ); assert_eq!(rt.core().split_mut().scheduler_state.redirects.len(), 2); let (t, _) = rt .core() @@ -1300,23 +1306,48 @@ fn test_prefill_steal() { } #[test] -pub fn test_schedule_variant_gap() { +pub fn test_schedule_running() { let mut rt = TestEnv::new(); 
- rt.new_named_resource("gpus"); - // 8 cpus OR 1 cpus + 2 gpus - rt.new_tasks( - 10, - &TaskBuilder::new() - .user_priority(10) - .cpus(8) - .next_variant() - .cpus(4) - .add_resource(1, 2), - ); + let w = rt.new_worker(&WorkerBuilder::new(14)); + for _ in 0..8 { + rt.new_task_running(&TaskBuilder::new(), w); + } let ts = rt.new_tasks(10, &TaskBuilder::new()); - rt.new_worker(&WorkerBuilder::new(14).res_sum("gpus", 4)); rt.schedule(); - assert_eq!(ts.iter().filter(|t| rt.task(**t).is_assigned()).count(), 2); + assert_eq!( + rt.worker(w).sn_assignment().unwrap().assigned_tasks.len(), + 14 + ); + assert_eq!(ts.iter().filter(|t| rt.task(**t).is_assigned()).count(), 6); +} + +#[test] +pub fn test_schedule_variant_gap1() { + for running in [0, 1, 2] { + let mut rt = TestEnv::new(); + rt.new_named_resource("gpus"); + let w = rt.new_worker(&WorkerBuilder::new(14).res_sum("gpus", 4)); + for _ in 0..running { + rt.new_task_running(&TaskBuilder::new(), w); + } + + // 8 cpus OR 1 cpus + 2 gpus + rt.new_tasks( + 10, + &TaskBuilder::new() + .user_priority(10) + .cpus(8) + .next_variant() + .cpus(4) + .add_resource(1, 2), + ); + let ts = rt.new_tasks(10, &TaskBuilder::new()); + rt.schedule(); + assert_eq!( + ts.iter().filter(|t| rt.task(**t).is_assigned()).count(), + 2 - running + ); + } } #[test] diff --git a/crates/tako/src/internal/tests/utils/env.rs b/crates/tako/src/internal/tests/utils/env.rs index e03fadac9..be84d7e0b 100644 --- a/crates/tako/src/internal/tests/utils/env.rs +++ b/crates/tako/src/internal/tests/utils/env.rs @@ -131,7 +131,11 @@ impl TestEnv { } pub fn worker_tasks(&self, worker_id: WorkerId) -> &Set { - &self.worker(worker_id).sn_assignment().unwrap().assign_tasks + &self + .worker(worker_id) + .sn_assignment() + .unwrap() + .assigned_tasks } pub fn new_worker(&mut self, builder: &WorkerBuilder) -> WorkerId {