 )]

 use std::fmt;
-use std::future::Future;
 use std::marker::PhantomData;
 use std::panic::{RefUnwindSafe, UnwindSafe};
 use std::rc::Rc;
 use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::{Arc, Mutex, RwLock, TryLockError};
+use std::sync::{Arc, Mutex, TryLockError};
 use std::task::{Poll, Waker};

 use async_lock::OnceCell;
 use async_task::{Builder, Runnable};
+use atomic_waker::AtomicWaker;
 use concurrent_queue::ConcurrentQueue;
 use futures_lite::{future, prelude::*};
 use slab::Slab;
+use thread_local::ThreadLocal;

 #[doc(no_inline)]
 pub use async_task::Task;
@@ -266,8 +267,23 @@ impl<'a> Executor<'a> {
     fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
         let state = self.state().clone();

-        // TODO: If possible, push into the current local queue and notify the ticker.
-        move |runnable| {
+        move |mut runnable| {
+            // If possible, push into the current local queue and notify the ticker.
+            if let Some(local) = state.local_queue.get() {
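+                // A failed `push` means the bounded local queue is full; the
+                // error hands the runnable back so it can fall through to the
+                // global queue below.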
+                runnable = if let Err(err) = local.queue.push(runnable) {
+                    err.into_inner()
+                } else {
+                    // Wake up this thread if it's asleep, otherwise notify another
+                    // thread to try to have the task stolen.
+                    if let Some(waker) = local.waker.take() {
+                        waker.wake();
+                    } else {
+                        state.notify();
+                    }
+                    return;
+                }
+            }
+            // If the local queue is full, fall back to pushing onto the global injector queue.
             state.queue.push(runnable).unwrap();
             state.notify();
         }
@@ -510,7 +526,16 @@ struct State {
     queue: ConcurrentQueue<Runnable>,

     /// Local queues created by runners.
-    local_queues: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
+    ///
+    /// If possible, tasks are scheduled onto the local queue, and will only defer
+    /// to the global queue when the local queue is full, or when the task is being
+    /// scheduled from a thread without a runner.
+    ///
+    /// Note: if a runner terminates and drains its local queue, any subsequent
+    /// spawn calls from the same thread will be added to the same queue, but won't
+    /// be executed until `Executor::run` is run on the thread again, or another
+    /// thread steals the task.
+    local_queue: ThreadLocal<LocalQueue>,

     /// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
     notified: AtomicBool,
@@ -527,7 +552,7 @@ impl State {
     fn new() -> State {
         State {
             queue: ConcurrentQueue::unbounded(),
-            local_queues: RwLock::new(Vec::new()),
+            local_queue: ThreadLocal::new(),
             notified: AtomicBool::new(true),
             sleepers: Mutex::new(Sleepers {
                 count: 0,
@@ -654,6 +679,12 @@ impl Ticker<'_> {
     ///
     /// Returns `false` if the ticker was already sleeping and unnotified.
     fn sleep(&mut self, waker: &Waker) -> bool {
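+        // Register the waker in this thread's local queue, so that `schedule`
+        // can wake this exact thread when a task lands on its local queue.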
+        self.state
+            .local_queue
+            .get_or_default()
+            .waker
+            .register(waker);
+
         let mut sleepers = self.state.sleepers.lock().unwrap();

         match self.sleeping {
@@ -692,7 +723,14 @@ impl Ticker<'_> {

     /// Waits for the next runnable task to run.
     async fn runnable(&mut self) -> Runnable {
-        self.runnable_with(|| self.state.queue.pop().ok()).await
+        self.runnable_with(|| {
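+            // Check the thread-local queue first, then fall back to the
+            // global queue.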
+            self.state
+                .local_queue
+                .get()
+                .and_then(|local| local.queue.pop().ok())
+                .or_else(|| self.state.queue.pop().ok())
+        })
+        .await
     }

     /// Waits for the next runnable task to run, given a function that searches for a task.
@@ -754,9 +792,6 @@ struct Runner<'a> {
     /// Inner ticker.
     ticker: Ticker<'a>,

-    /// The local queue.
-    local: Arc<ConcurrentQueue<Runnable>>,
-
     /// Bumped every time a runnable task is found.
     ticks: usize,
 }
@@ -767,38 +802,34 @@ impl Runner<'_> {
         let runner = Runner {
             state,
             ticker: Ticker::new(state),
-            local: Arc::new(ConcurrentQueue::bounded(512)),
             ticks: 0,
         };
-        state
-            .local_queues
-            .write()
-            .unwrap()
-            .push(runner.local.clone());
         runner
     }

     /// Waits for the next runnable task to run.
     async fn runnable(&mut self, rng: &mut fastrand::Rng) -> Runnable {
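+        // Fetch (or lazily initialize) this thread's local queue.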
+        let local = self.state.local_queue.get_or_default();
+
         let runnable = self
             .ticker
             .runnable_with(|| {
                 // Try the local queue.
-                if let Ok(r) = self.local.pop() {
+                if let Ok(r) = local.queue.pop() {
                     return Some(r);
                 }

                 // Try stealing from the global queue.
                 if let Ok(r) = self.state.queue.pop() {
-                    steal(&self.state.queue, &self.local);
+                    steal(&self.state.queue, &local.queue);
                     return Some(r);
                 }

                 // Try stealing from other runners.
-                let local_queues = self.state.local_queues.read().unwrap();
+                let local_queues = &self.state.local_queue;

                 // Pick a random starting point in the iterator list and rotate the list.
-                let n = local_queues.len();
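+                // `ThreadLocal` only exposes an iterator, not a length, so
+                // count the initialized local queues by iterating.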
+                let n = local_queues.iter().count();
                 let start = rng.usize(..n);
                 let iter = local_queues
                     .iter()
@@ -807,12 +838,12 @@ impl Runner<'_> {
                     .take(n);

                 // Remove this runner's local queue.
-                let iter = iter.filter(|local| !Arc::ptr_eq(local, &self.local));
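+                // The local queues are now plain references into the
+                // `ThreadLocal` rather than `Arc`s, so compare addresses
+                // to skip our own queue.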
+                let iter = iter.filter(|other| !core::ptr::eq(*other, local));

                 // Try stealing from each local queue in the list.
-                for local in iter {
-                    steal(local, &self.local);
-                    if let Ok(r) = self.local.pop() {
+                for other in iter {
+                    steal(&other.queue, &local.queue);
+                    if let Ok(r) = local.queue.pop() {
                         return Some(r);
                     }
                 }
@@ -826,7 +857,7 @@ impl Runner<'_> {

         if self.ticks % 64 == 0 {
             // Steal tasks from the global queue to ensure fair task scheduling.
-            steal(&self.state.queue, &self.local);
+            steal(&self.state.queue, &local.queue);
         }

         runnable
@@ -836,15 +867,13 @@ impl Runner<'_> {
 impl Drop for Runner<'_> {
     fn drop(&mut self) {
         // Remove the local queue.
-        self.state
-            .local_queues
-            .write()
-            .unwrap()
-            .retain(|local| !Arc::ptr_eq(local, &self.local));
-
-        // Re-schedule remaining tasks in the local queue.
-        while let Ok(r) = self.local.pop() {
-            r.schedule();
+        if let Some(local) = self.state.local_queue.get() {
+            // Re-schedule remaining tasks in the local queue.
+            for r in local.queue.try_iter() {
+                // Explicitly reschedule the runnable back onto the global
+                // queue to avoid rescheduling onto the local one.
+                self.state.queue.push(r).unwrap();
+            }
         }
     }
 }
@@ -904,18 +933,13 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
     }

     /// Debug wrapper for the local runners.
-    struct LocalRunners<'a>(&'a RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>);
+    struct LocalRunners<'a>(&'a ThreadLocal<LocalQueue>);

     impl fmt::Debug for LocalRunners<'_> {
         fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-            match self.0.try_read() {
-                Ok(lock) => f
-                    .debug_list()
-                    .entries(lock.iter().map(|queue| queue.len()))
-                    .finish(),
-                Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
-                Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),
-            }
+            f.debug_list()
+                .entries(self.0.iter().map(|local| local.queue.len()))
+                .finish()
         }
     }

@@ -935,11 +959,32 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
     f.debug_struct(name)
         .field("active", &ActiveTasks(&state.active))
         .field("global_tasks", &state.queue.len())
-        .field("local_runners", &LocalRunners(&state.local_queues))
+        .field("local_runners", &LocalRunners(&state.local_queue))
         .field("sleepers", &SleepCount(&state.sleepers))
         .finish()
 }

+/// A queue local to each thread.
+///
+/// Its `Default` implementation is used to initialize each
+/// thread's queue via `ThreadLocal::get_or_default`.
+///
+/// The local queue *must* be flushed, and all pending runnables
+/// rescheduled onto the global queue, when a runner is dropped.
+struct LocalQueue {
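+    /// Bounded queue of runnables scheduled from this thread.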
+    queue: ConcurrentQueue<Runnable>,
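+    /// Waker of the runner sleeping on this queue, if any.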
+    waker: AtomicWaker,
+}
+
+impl Default for LocalQueue {
+    fn default() -> Self {
+        Self {
+            queue: ConcurrentQueue::bounded(512),
+            waker: AtomicWaker::new(),
+        }
+    }
+}
+
 /// Runs a closure when dropped.
 struct CallOnDrop<F: FnMut()>(F);
