pavex_session_memory_store/
lib.rs

1//! An in-memory session store for `pavex_session`, geared towards testing and local development.
2use pavex::{methods, time::Timestamp};
3use std::{borrow::Cow, collections::HashMap, num::NonZeroUsize, sync::Arc, time::Duration};
4use tokio::sync::{Mutex, MutexGuard};
5
6use pavex_session::{
7    SessionId, SessionStore,
8    store::{
9        SessionRecord, SessionRecordRef, SessionStorageBackend,
10        errors::{
11            ChangeIdError, CreateError, DeleteError, DeleteExpiredError, DuplicateIdError,
12            LoadError, UnknownIdError, UpdateError, UpdateTtlError,
13        },
14    },
15};
16
17#[derive(Clone)]
18/// An in-memory session store.
19///
20/// # Limitations
21///
22/// This store won't persist data between server restarts.
23/// It also won't synchronize data between multiple server instances.
24/// It is primarily intended for testing and local development.
25pub struct InMemorySessionStore(Arc<Mutex<HashMap<SessionId, StoreRecord>>>);
26
27impl std::fmt::Debug for InMemorySessionStore {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("InMemorySessionStore")
30            .finish_non_exhaustive()
31    }
32}
33
34#[methods]
35impl From<InMemorySessionStore> for SessionStore {
36    #[singleton]
37    fn from(value: InMemorySessionStore) -> Self {
38        SessionStore::new(value)
39    }
40}
41
42#[doc(hidden)]
43// Here for backwards compatibility.
44pub type SessionStoreMemory = InMemorySessionStore;
45
46#[derive(Debug)]
47struct StoreRecord {
48    state: HashMap<Cow<'static, str>, serde_json::Value>,
49    deadline: Timestamp,
50}
51impl StoreRecord {
52    fn is_stale(&self) -> bool {
53        self.deadline <= Timestamp::now()
54    }
55}
56
57impl Default for InMemorySessionStore {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63#[methods]
64impl InMemorySessionStore {
65    /// Creates a new (empty) in-memory session store.
66    #[singleton]
67    pub fn new() -> Self {
68        Self(Arc::new(Mutex::new(HashMap::new())))
69    }
70
71    fn get_mut_if_fresh<'a, 'b, 'c: 'a>(
72        guard: &'a mut MutexGuard<'c, HashMap<SessionId, StoreRecord>>,
73        id: &'b SessionId,
74    ) -> Result<&'a mut StoreRecord, UnknownIdError> {
75        let Some(old_record) = guard.get_mut(id) else {
76            return Err(UnknownIdError { id: id.to_owned() });
77        };
78        if old_record.is_stale() {
79            return Err(UnknownIdError { id: id.to_owned() });
80        }
81        Ok(old_record)
82    }
83
84    /// Deletes a session record from the store using the provided ID.
85    ///
86    /// If the session exists, it is removed from the store.
87    fn _delete(
88        guard: &mut MutexGuard<'_, HashMap<SessionId, StoreRecord>>,
89        id: &SessionId,
90    ) -> Result<StoreRecord, UnknownIdError> {
91        let Some(old_record) = guard.remove(id) else {
92            return Err(UnknownIdError { id: id.to_owned() });
93        };
94        if old_record.is_stale() {
95            return Err(UnknownIdError { id: id.to_owned() });
96        }
97        Ok(old_record)
98    }
99}
100
101#[async_trait::async_trait]
102impl SessionStorageBackend for InMemorySessionStore {
103    /// Creates a new session record in the store using the provided ID.
104    #[tracing::instrument(name = "Create server-side session record", level = tracing::Level::TRACE, skip_all)]
105    async fn create(
106        &self,
107        id: &SessionId,
108        record: SessionRecordRef<'_>,
109    ) -> Result<(), CreateError> {
110        let mut guard = self.0.lock().await;
111        if Self::get_mut_if_fresh(&mut guard, id).is_ok() {
112            return Err(CreateError::DuplicateId(DuplicateIdError { id: *id }));
113        }
114
115        guard.insert(
116            *id,
117            StoreRecord {
118                state: record.state.into_owned(),
119                deadline: Timestamp::now() + record.ttl,
120            },
121        );
122        Ok(())
123    }
124
125    /// Update the state of an existing session in the store.
126    ///
127    /// It overwrites the existing record with the provided one.
128    #[tracing::instrument(name = "Update server-side session record", level = tracing::Level::TRACE, skip_all)]
129    async fn update(
130        &self,
131        id: &SessionId,
132        record: SessionRecordRef<'_>,
133    ) -> Result<(), UpdateError> {
134        let mut guard = self.0.lock().await;
135        let old_record = Self::get_mut_if_fresh(&mut guard, id)?;
136        *old_record = StoreRecord {
137            state: record.state.into_owned(),
138            deadline: Timestamp::now() + record.ttl,
139        };
140        Ok(())
141    }
142
143    /// Update the TTL of an existing session record in the store.
144    ///
145    /// It leaves the session state unchanged.
146    #[tracing::instrument(name = "Update TTL for server-side session record", level = tracing::Level::TRACE, skip_all)]
147    async fn update_ttl(
148        &self,
149        id: &SessionId,
150        ttl: std::time::Duration,
151    ) -> Result<(), UpdateTtlError> {
152        let mut guard = self.0.lock().await;
153        let old_record = Self::get_mut_if_fresh(&mut guard, id)?;
154        old_record.deadline = Timestamp::now() + ttl;
155        Ok(())
156    }
157
158    /// Loads an existing session record from the store using the provided ID.
159    ///
160    /// If a session with the given ID exists, it is returned. If the session
161    /// does not exist or has been invalidated (e.g., expired), `None` is
162    /// returned.
163    #[tracing::instrument(name = "Load server-side session record", level = tracing::Level::TRACE, skip_all)]
164    async fn load(&self, session_id: &SessionId) -> Result<Option<SessionRecord>, LoadError> {
165        let mut guard = self.0.lock().await;
166        let outcome = match Self::get_mut_if_fresh(&mut guard, session_id) {
167            Ok(old_record) => Some(SessionRecord {
168                state: old_record.state.clone(),
169                ttl: (old_record.deadline - Timestamp::now())
170                    .try_into()
171                    .unwrap_or(Duration::from_millis(0)),
172            }),
173            Err(_) => None,
174        };
175        Ok(outcome)
176    }
177
178    /// Deletes a session record from the store using the provided ID.
179    ///
180    /// If the session exists, it is removed from the store.
181    #[tracing::instrument(name = "Delete server-side session record", level = tracing::Level::TRACE, skip_all)]
182    async fn delete(&self, id: &SessionId) -> Result<(), DeleteError> {
183        let mut guard = self.0.lock().await;
184        Self::_delete(&mut guard, id)?;
185        Ok(())
186    }
187
188    /// Change the session id associated with an existing session record.
189    ///
190    /// The server-side state is left unchanged.
191    #[tracing::instrument(name = "Change id for server-side session record", level = tracing::Level::TRACE, skip_all)]
192    async fn change_id(&self, old_id: &SessionId, new_id: &SessionId) -> Result<(), ChangeIdError> {
193        let mut guard = self.0.lock().await;
194        if Self::get_mut_if_fresh(&mut guard, new_id).is_ok() {
195            return Err(DuplicateIdError {
196                id: new_id.to_owned(),
197            }
198            .into());
199        }
200        let record = Self::_delete(&mut guard, old_id)?;
201        guard.insert(*new_id, record);
202        Ok(())
203    }
204
205    /// Delete all expired records from the store.
206    #[tracing::instrument(name = "Delete expired records", level = tracing::Level::TRACE, skip_all)]
207    async fn delete_expired(
208        &self,
209        batch_size: Option<NonZeroUsize>,
210    ) -> Result<usize, DeleteExpiredError> {
211        let mut guard = self.0.lock().await;
212        let now = Timestamp::now();
213        let mut stale_ids = Vec::new();
214        for (id, record) in guard.iter() {
215            if record.deadline <= now {
216                stale_ids.push(*id);
217                if let Some(batch_size) = batch_size
218                    && stale_ids.len() >= batch_size.get()
219                {
220                    break;
221                }
222            }
223        }
224        let num_deleted = stale_ids.len();
225        for id in stale_ids {
226            guard.remove(&id);
227        }
228        Ok(num_deleted)
229    }
230}