veecle_os_runtime/datastore/
generational.rs

//! [`generational`](self) supports synchronizing the "generation" of data from a producer to multiple consumers.
//! The producer should own a [`Source`], which is in charge of notifying consumers whenever the generation is
//! incremented. Each consumer should hold its own [`Waiter`] referencing that `Source`, which lets it wait for an
//! update to the generation.
//!
//! See the `tests` module for an example.
7
8use core::cell::Cell;
9use core::pin::Pin;
10use core::task::{Poll, Waker};
11
12use pin_cell::{PinCell, PinMut};
13use pin_project::pin_project;
14use wakerset::{ExtractedWakers, WakerList, WakerSlot};
15
16/// Tracks the current generation, waking tasks that are `await`ing associated [`Waiter`]s when it increments.
17#[derive(Debug, Default)]
18#[pin_project]
19pub struct Source {
20    generation: Cell<usize>,
21    #[pin]
22    list: PinCell<WakerList>,
23}
24
25impl Source {
26    pub(crate) fn new() -> Self {
27        Self::default()
28    }
29
30    /// Returns a new waiter for this source.
31    ///
32    /// # Panics
33    ///
34    /// If called more times than the `waiter_count` provided on init.
35    pub(crate) fn waiter(self: Pin<&Self>) -> Waiter<'_> {
36        Waiter::new(self)
37    }
38
39    /// Increments the generation of the current [`Source`] and notifies any waiting [`Waiter`]s they can continue.
40    pub(crate) fn increment_generation(self: Pin<&Self>) {
41        self.generation.set(self.generation.get() + 1);
42
43        let round = PinMut::as_mut(&mut self.project_ref().list.borrow_mut()).begin_extraction();
44        let mut wakers = ExtractedWakers::new();
45        let mut more = true;
46        while more {
47            more = PinMut::as_mut(&mut self.project_ref().list.borrow_mut())
48                .extract_some_wakers(round, &mut wakers);
49            wakers.wake_all();
50        }
51    }
52
53    fn link(self: Pin<&Self>, slot: Pin<&mut WakerSlot>, waker: Waker) {
54        PinMut::as_mut(&mut self.project_ref().list.borrow_mut()).link(slot, waker)
55    }
56
57    fn unlink(self: Pin<&Self>, slot: Pin<&mut WakerSlot>) {
58        PinMut::as_mut(&mut self.project_ref().list.borrow_mut()).unlink(slot)
59    }
60}
61
62/// Tracks the last seen generation of a [`Source`], when `await`ed will resolve once the source is at a newer
63/// generation.
64#[derive(Debug)]
65pub(crate) struct Waiter<'a> {
66    generation: usize,
67    source: Pin<&'a Source>,
68}
69
70impl<'a> Waiter<'a> {
71    /// Creates a new [`Waiter`].
72    pub(crate) fn new(source: Pin<&'a Source>) -> Self {
73        Self {
74            generation: source.generation.get(),
75            source,
76        }
77    }
78
79    /// Updates the generation from the source [`Source`] to allow waiting for the next generation.
80    pub(crate) fn update_generation(&mut self) {
81        self.generation = self.source.generation.get();
82    }
83
84    pub(crate) async fn wait(&self) -> Result<(), MissedUpdate> {
85        // Using a guard here makes sure that the slot is unlinked if this future is dropped before completing.
86        struct Guard<'a, 'b> {
87            source: Pin<&'a Source>,
88            slot: Pin<&'b mut WakerSlot>,
89        }
90
91        impl Drop for Guard<'_, '_> {
92            fn drop(&mut self) {
93                if self.slot.is_linked() {
94                    self.source.unlink(self.slot.as_mut());
95                }
96            }
97        }
98        use core::pin::pin;
99
100        let mut guard = Guard {
101            source: self.source,
102            slot: pin!(WakerSlot::new()),
103        };
104
105        core::future::poll_fn(|cx| {
106            let current = self.source.generation.get();
107
108            // If the generation is the same, we need to register the waker to be woken
109            // on next update. Else, it means we already got an update so we can return
110            // from the future.
111            if current == self.generation {
112                self.source.link(guard.slot.as_mut(), cx.waker().clone());
113                return Poll::Pending;
114            }
115
116            let expected = self.generation + 1;
117            if current != expected {
118                return Poll::Ready(Err(MissedUpdate { expected, current }));
119            }
120
121            Poll::Ready(Ok(()))
122        })
123        .await
124    }
125}
126
/// Indicates that the [`Source`] has had multiple generation updates since the last time
/// [`Waiter::update_generation`] was called. Depending on the use case, this may mean some data values were missed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct MissedUpdate {
    /// The generation the waiter expected to observe next.
    pub(crate) expected: usize,
    /// The generation the source was actually at.
    pub(crate) current: usize,
}
133
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::cell::Cell;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll};

    use crate::datastore::generational;

    #[test]
    fn example() {
        let source = pin!(generational::Source::new());

        let counter = Cell::new(0);
        let sum = Cell::new(0);
        let mut waiter = source.as_ref().waiter();

        let mut future = pin!(async {
            loop {
                let _ = waiter.wait().await;
                waiter.update_generation();
                sum.set(sum.get() + counter.get());
            }
        });

        let mut context = Context::from_waker(futures::task::noop_waker_ref());

        for i in 1..10 {
            // Before incrementing the generation, nothing should happen.
            assert!(future.as_mut().poll(&mut context).is_pending());
            assert_eq!(sum.get(), (i - 1) * i / 2);

            counter.set(i);
            source.as_ref().increment_generation();

            // After incrementing the generation it should run.
            assert!(future.as_mut().poll(&mut context).is_pending());
            assert_eq!(sum.get(), i * (i + 1) / 2);
        }
    }

    #[test]
    fn missed_update() {
        let source = pin!(generational::Source::new());
        let waiter = source.as_ref().waiter();

        // The waiter has seen generation 0, but the source skips ahead to generation 2.
        source.as_ref().increment_generation();
        source.as_ref().increment_generation();

        let mut context = Context::from_waker(futures::task::noop_waker_ref());
        let mut wait = pin!(waiter.wait());

        match wait.as_mut().poll(&mut context) {
            Poll::Ready(Err(generational::MissedUpdate { expected, current })) => {
                assert_eq!(expected, 1);
                assert_eq!(current, 2);
            }
            _ => panic!("expected a `MissedUpdate` error"),
        }
    }
}