1use core::convert::Infallible;
4use core::fmt::Debug;
5use core::future::Future;
6use core::ops::{Add, Div};
7use core::pin::Pin;
8use core::sync::atomic::{AtomicUsize, Ordering};
9use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
10
11use futures::task::AtomicWaker;
12use generic_array::{ArrayLength, GenericArray};
13use typenum::operator_aliases::{Quot, Sum};
14use typenum::{Const, ToUInt, U};
15
16use crate::datastore::generational;
17
/// `usize::BITS` as a `typenum` unsigned integer: the number of per-future
/// flag bits stored in one `AtomicUsize` word.
type UsizeBits = U<{ usize::BITS as usize }>;
/// `usize::BITS - 1` at the type level. It is added to a length before dividing
/// by `UsizeBits` so that the division rounds up.
type UsizeBitsMinusOne = typenum::operator_aliases::Sub1<UsizeBits>;
20
/// Type-level helper computing `LEN + (usize::BITS - 1)` for a `Const<LEN>`.
///
/// This is the first step of the rounded-up division done by `DivCeilUsizeBits`.
trait AddUsizeBitsMinusOne {
    /// The `typenum` unsigned integer `LEN + usize::BITS - 1`.
    type Output;
}
26
/// Type-level helper computing `ceil(LEN / usize::BITS)` for a `Const<LEN>`.
///
/// The result is the number of `usize` words needed to hold `LEN` bits.
trait DivCeilUsizeBits {
    /// The `typenum` unsigned integer `(LEN + usize::BITS - 1) / usize::BITS`.
    type Output;
}
32
/// Private bound placed on the executor types for their const generic `LEN`.
///
/// It exposes the storage size derived from `LEN`. It is implemented for every
/// `Const<LEN>` for which the type-level word-count computation succeeds.
trait Internal {
    /// Number of `usize` words needed to store one bit per future, as a
    /// `generic_array` length.
    type LengthInWords: ArrayLength;
}
42
/// `U<LEN> + (usize::BITS - 1)`. It is available whenever `LEN` converts to a
/// `typenum` unsigned integer that supports this addition.
impl<const LEN: usize> AddUsizeBitsMinusOne for Const<LEN>
where
    Const<LEN>: ToUInt<Output: Add<UsizeBitsMinusOne>>,
{
    type Output = Sum<U<LEN>, UsizeBitsMinusOne>;
}
50
/// `(LEN + usize::BITS - 1) / usize::BITS`, which is `LEN / usize::BITS` rounded
/// up, computed with `typenum` (truncating) integer division.
impl<const LEN: usize> DivCeilUsizeBits for Const<LEN>
where
    Const<LEN>: AddUsizeBitsMinusOne<Output: Div<UsizeBits>>,
{
    type Output = Quot<<Const<LEN> as AddUsizeBitsMinusOne>::Output, UsizeBits>;
}
58
/// Blanket impl: every `Const<LEN>` whose rounded-up word count is a valid
/// `ArrayLength` implements `Internal`.
impl<const LEN: usize> Internal for Const<LEN>
where
    Const<LEN>: DivCeilUsizeBits<Output: ArrayLength>,
{
    type LengthInWords = <Const<LEN> as DivCeilUsizeBits>::Output;
}
71
/// Shorthand for the number of `usize` words backing a flag set of `LEN` bits.
type LengthInWords<const LEN: usize> = <Const<LEN> as Internal>::LengthInWords;
74
/// State shared between an executor and the wakers of its `LEN` futures.
///
/// Each future owns one bit in `active`. A set bit means the future has been
/// woken and is due to be polled on the next round.
#[derive(Debug)]
struct WakerShared<const LEN: usize>
where
    Const<LEN>: Internal,
{
    /// Waker of the task driving the executor. It is registered by
    /// `register_current` and woken every time a bit is set.
    waker: AtomicWaker,

    /// One bit per future, packed into `usize` words. Bit
    /// `i % usize::BITS` of word `i / usize::BITS` belongs to future `i`.
    active: GenericArray<AtomicUsize, LengthInWords<LEN>>,
}
87
/// Locates the flag bit that tracks future `index`.
///
/// Returns `(word, mask)`, where:
/// - `word` is the position of the `usize` word holding the bit;
/// - `mask` is a value with only that bit set.
fn get_active_index_and_mask(index: usize) -> (usize, usize) {
    const WORD_BITS: usize = usize::BITS as usize;
    let mask: usize = 1 << (index % WORD_BITS);
    (index / WORD_BITS, mask)
}
95
96impl<const LEN: usize> WakerShared<LEN>
97where
98 Const<LEN>: Internal,
99{
100 const fn new() -> Self {
102 let active = {
103 let mut active = GenericArray::uninit();
105
106 let mut index = 0;
108 let slice = active.as_mut_slice();
109 while index < slice.len() {
110 slice[index].write(AtomicUsize::new(usize::MAX));
111 index += 1;
112 }
113
114 unsafe { GenericArray::assume_init(active) }
116 };
117
118 Self {
119 waker: AtomicWaker::new(),
120 active,
121 }
122 }
123
124 fn reset(&self, index: usize) -> bool {
126 let (active_word, mask) = self.get_active_ref_and_mask(index);
127 let previous_value = active_word.fetch_and(!mask, Ordering::Relaxed);
128 (previous_value & mask) != 0
130 }
131
132 fn reset_all(&self) -> impl Iterator<Item = usize> + use<'_, LEN> {
134 (0..LEN).filter(|&index| self.reset(index))
135 }
136
137 fn set(&self, index: usize) -> bool {
140 let (active_word, mask) = self.get_active_ref_and_mask(index);
141 let previous_value = active_word.fetch_or(mask, Ordering::Relaxed);
142
143 self.waker.wake();
144
145 (previous_value & mask) != 0
147 }
148
149 async fn register_current(&self) {
151 core::future::poll_fn(|ctx| {
152 self.waker.register(ctx.waker());
153 Poll::Ready(())
154 })
155 .await;
156 }
157
158 fn get_active_ref_and_mask(&self, index: usize) -> (&AtomicUsize, usize) {
160 let (index, mask) = get_active_index_and_mask(index);
161 (&self.active[index], mask)
162 }
163}
164
/// Per-future waker: waking it sets bit `index` in the shared flag set.
///
/// `as_waker` requires `&'static self`, so instances used as real wakers must
/// live for the rest of the program (here: inside a `static ExecutorShared`).
#[derive(Debug)]
struct BitWaker<const LEN: usize>
where
    Const<LEN>: Internal,
{
    /// Index of the future this waker belongs to. It is `usize::MAX` for the
    /// placeholder built by `invalid`.
    index: usize,

    /// Flag set updated on wake. It is `None` only for the placeholder built
    /// by `invalid`.
    shared: Option<&'static WakerShared<LEN>>,
}
177
178impl<const LEN: usize> BitWaker<LEN>
179where
180 Const<LEN>: Internal,
181{
182 const VTABLE: &RawWakerVTable = &RawWakerVTable::new(
185 |ptr| RawWaker::new(ptr, Self::VTABLE),
187 |ptr| unsafe { &*ptr.cast::<Self>() }.wake_by_ref(),
189 |ptr| unsafe { &*ptr.cast::<Self>() }.wake_by_ref(),
191 |_| {},
193 );
194
195 const fn invalid() -> Self {
197 Self {
198 index: usize::MAX,
199 shared: None,
200 }
201 }
202
203 const fn new(index: usize, shared: &'static WakerShared<LEN>) -> Self {
205 assert!(index < LEN, "Future index out of bounds.");
206 Self {
207 index,
208 shared: Some(shared),
209 }
210 }
211
212 fn wake_by_ref(&self) {
214 self.shared.unwrap().set(self.index);
215 }
216
217 fn as_waker(&'static self) -> Waker {
219 let pointer = (&raw const *self).cast();
220 unsafe { Waker::new(pointer, Self::VTABLE) }
222 }
223}
224
/// Statically allocated state shared by an [`Executor`] and the wakers of its
/// `LEN` futures.
///
/// It must be stored in a `static` and initialized with
/// [`ExecutorShared::new`], passing a reference to that same static:
///
/// ```ignore
/// static SHARED: ExecutorShared<2> = ExecutorShared::new(&SHARED);
/// ```
#[derive(Debug)]
#[expect(private_bounds)]
pub struct ExecutorShared<const LEN: usize>
where
    Const<LEN>: Internal,
{
    /// Flag set (one bit per future) plus the outer-waker slot.
    shared: WakerShared<LEN>,
    /// One waker per future. Entry `i` sets bit `i` of the flag set that was
    /// passed to `new` by reference.
    bit_wakers: [BitWaker<LEN>; LEN],
}
235
236#[expect(private_bounds)]
237impl<const LEN: usize> ExecutorShared<LEN>
238where
239 Const<LEN>: Internal,
240{
241 pub const fn new(&'static self) -> Self {
250 let mut bit_wakers = [const { BitWaker::invalid() }; LEN];
251 let mut index = 0;
252 while index < LEN {
253 bit_wakers[index] = BitWaker::new(index, &self.shared);
254 index += 1;
255 }
256 Self {
257 shared: WakerShared::new(),
258 bit_wakers,
259 }
260 }
261}
262
/// A minimal executor for a fixed set of `LEN` never-completing futures
/// (`Output = Infallible`).
///
/// Every future is polled on the first round, and afterwards only when its
/// waker has fired. The executor is itself driven as a future, see
/// [`Executor::run`]. Each iteration of `run`:
/// 1. registers the outer task's waker;
/// 2. polls the woken futures once;
/// 3. yields until one of them is woken again.
///
/// After every polling round it calls `increment_generation` on its
/// `generational::Source`.
#[expect(private_bounds)]
pub struct Executor<'a, const LEN: usize>
where
    Const<LEN>: Internal,
{
    /// Generation source advanced after every polling round.
    source: Pin<&'a generational::Source>,
    /// Flag set and per-future wakers. It lives in a `static`.
    shared: &'static ExecutorShared<LEN>,
    /// The futures being driven. Future `i` is woken by waker `i` of `shared`.
    futures: [Pin<&'a mut (dyn Future<Output = Infallible> + 'a)>; LEN],
}
305
306impl<const LEN: usize> core::fmt::Debug for Executor<'_, LEN>
307where
308 Const<LEN>: Internal,
309{
310 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
311 f.debug_struct("Executor")
312 .field("source", &self.source)
313 .field("shared", &self.shared)
314 .field("futures", &"<opaque>")
315 .finish()
316 }
317}
318
319#[expect(private_bounds)]
320impl<'a, const LEN: usize> Executor<'a, LEN>
321where
322 Const<LEN>: Internal,
323{
324 pub fn new(
326 shared: &'static ExecutorShared<LEN>,
327 source: Pin<&'a generational::Source>,
328 futures: [Pin<&'a mut (dyn Future<Output = Infallible> + 'a)>; LEN],
329 ) -> Self {
330 Self {
331 source,
332 shared,
333 futures,
334 }
335 }
336
337 pub(crate) fn run_once(&mut self) -> bool {
339 let mut polled = false;
340
341 for index in self.shared.shared.reset_all() {
342 let future = &mut self.futures[index];
343 let waker = self.shared.bit_wakers[index].as_waker();
344 let mut context = Context::from_waker(&waker);
345 match future.as_mut().poll(&mut context) {
346 Poll::Pending => {}
347 }
348 polled = true;
349 }
350
351 self.source.increment_generation();
352
353 polled
354 }
355
356 pub async fn run(mut self) -> ! {
358 loop {
359 self.shared.shared.register_current().await;
360
361 self.run_once();
364
365 let mut yielded = false;
368 core::future::poll_fn(|_| {
369 if yielded {
370 Poll::Ready(())
371 } else {
372 yielded = true;
373 Poll::Pending
374 }
375 })
376 .await;
377 }
378 }
379}
380
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use core::pin::pin;
    use core::task::Poll;
    use std::vec::Vec;

    use super::{BitWaker, Executor, ExecutorShared, WakerShared, get_active_index_and_mask};
    use crate::datastore::generational;

    /// A future count that fills exactly two `usize` words of waker bits.
    const TWO_WORDS: usize = usize::BITS as usize * 2;

    /// Checks the word index and bit mask around the first word boundary.
    #[test]
    fn calculate_indices() {
        // First bit of word 0.
        assert_eq!(get_active_index_and_mask(0), (0, 1 << 0));

        // Second bit of word 0.
        assert_eq!(get_active_index_and_mask(1), (0, 1 << 1));

        assert_eq!(
            // Last bit of word 0.
            get_active_index_and_mask(usize::BITS as usize - 1),
            (0, 1 << (usize::BITS as usize - 1))
        );

        // First bit of word 1.
        assert_eq!(get_active_index_and_mask(usize::BITS as usize), (1, 1 << 0));

        assert_eq!(
            // Second bit of word 1.
            get_active_index_and_mask(usize::BITS as usize + 1),
            (1, 1 << 1)
        );
    }

    /// A freshly created `WakerShared` reports every future index as woken,
    /// for a length of zero and for lengths around the one-word boundary.
    #[test]
    fn waker_shared_initializes_as_all_awake() {
        assert_eq!(
            Vec::from_iter(WakerShared::<0>::new().reset_all()),
            Vec::<usize>::new()
        );
        assert_eq!(
            Vec::from_iter(WakerShared::<1>::new().reset_all()),
            Vec::from_iter(0..1)
        );
        assert_eq!(
            Vec::from_iter(WakerShared::<{ usize::BITS as usize - 1 }>::new().reset_all()),
            Vec::from_iter(0..usize::BITS as usize - 1)
        );
        assert_eq!(
            Vec::from_iter(WakerShared::<{ usize::BITS as usize }>::new().reset_all()),
            Vec::from_iter(0..usize::BITS as usize)
        );
        assert_eq!(
            Vec::from_iter(WakerShared::<{ usize::BITS as usize + 1 }>::new().reset_all()),
            Vec::from_iter(0..usize::BITS as usize + 1)
        );
    }

    /// `BitWaker::new` accepts (and can wake) every index below `LEN`, and
    /// panics for index `LEN`.
    #[test]
    fn bitwaker_valid_indexes() {
        static SHARED: WakerShared<TWO_WORDS> = WakerShared::new();
        let mut i = 0;
        while i < TWO_WORDS {
            BitWaker::new(i, &SHARED).wake_by_ref();
            i += 1;
        }
        // Here `i == TWO_WORDS`, one past the last valid index.
        assert!(std::panic::catch_unwind(|| BitWaker::new(i, &SHARED)).is_err());
    }

    /// Reaches code paths the behavioral tests do not: runtime calls of the
    /// `const` constructors, and the `Debug` impl of `Executor`.
    #[test]
    fn extra_code_coverage() {
        static SHARED: ExecutorShared<1> = ExecutorShared::new(&SHARED);

        // The static above is evaluated at compile time; also call `new` at
        // runtime.
        let _ = ExecutorShared::new(&SHARED);

        let source = pin!(generational::Source::new());
        let futures = [pin!(async move { core::future::pending().await }) as _];
        let executor = Executor::new(&SHARED, source.as_ref(), futures);

        let _ = std::format!("{executor:?}");

        let _ = BitWaker::<1>::invalid();
    }

    /// End-to-end test.
    ///
    /// A future wakes itself on its first poll, through both `wake_by_ref` and
    /// a cloned waker's `wake`. It is then polled again and signals the test
    /// thread. Because `run` never returns, the executor runs on a detached
    /// thread and the test only waits (up to 1 s) for the signal.
    #[cfg(not(miri))]
    #[test]
    fn executor() {
        let (tx, rx) = std::sync::mpsc::channel();

        std::thread::spawn({
            move || {
                let source = pin!(generational::Source::new());

                static SHARED: ExecutorShared<1> = ExecutorShared::new(&SHARED);
                let futures = [pin!(async move {
                    let mut yielded = false;
                    core::future::poll_fn(|cx| {
                        if yielded {
                            Poll::Ready(())
                        } else {
                            yielded = true;
                            cx.waker().wake_by_ref();

                            // Deliberately clone and then `wake`, to exercise
                            // the clone and wake entries of the waker vtable.
                            #[expect(clippy::waker_clone_wake)]
                            cx.waker().clone().wake();

                            Poll::Pending
                        }
                    })
                    .await;
                    let _ = tx.send(());
                    // Futures driven by the executor must never complete.
                    std::future::pending().await
                }) as _];

                let executor = Executor::new(&SHARED, source.as_ref(), futures);

                futures::executor::block_on(executor.run());
            }
        });

        assert!(rx.recv_timeout(std::time::Duration::from_secs(1)).is_ok());
    }
}
511}