pavex_session_memory_store/
lib.rs1use pavex::{methods, time::Timestamp};
3use std::{borrow::Cow, collections::HashMap, num::NonZeroUsize, sync::Arc, time::Duration};
4use tokio::sync::{Mutex, MutexGuard};
5
6use pavex_session::{
7 SessionId, SessionStore,
8 store::{
9 SessionRecord, SessionRecordRef, SessionStorageBackend,
10 errors::{
11 ChangeIdError, CreateError, DeleteError, DeleteExpiredError, DuplicateIdError,
12 LoadError, UnknownIdError, UpdateError, UpdateTtlError,
13 },
14 },
15};
16
17#[derive(Clone)]
18pub struct InMemorySessionStore(Arc<Mutex<HashMap<SessionId, StoreRecord>>>);
26
27impl std::fmt::Debug for InMemorySessionStore {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("InMemorySessionStore")
30 .finish_non_exhaustive()
31 }
32}
33
34#[methods]
35impl From<InMemorySessionStore> for SessionStore {
36 #[singleton]
37 fn from(value: InMemorySessionStore) -> Self {
38 SessionStore::new(value)
39 }
40}
41
42#[doc(hidden)]
43pub type SessionStoreMemory = InMemorySessionStore;
45
46#[derive(Debug)]
47struct StoreRecord {
48 state: HashMap<Cow<'static, str>, serde_json::Value>,
49 deadline: Timestamp,
50}
51impl StoreRecord {
52 fn is_stale(&self) -> bool {
53 self.deadline <= Timestamp::now()
54 }
55}
56
57impl Default for InMemorySessionStore {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63#[methods]
64impl InMemorySessionStore {
65 #[singleton]
67 pub fn new() -> Self {
68 Self(Arc::new(Mutex::new(HashMap::new())))
69 }
70
71 fn get_mut_if_fresh<'a, 'b, 'c: 'a>(
72 guard: &'a mut MutexGuard<'c, HashMap<SessionId, StoreRecord>>,
73 id: &'b SessionId,
74 ) -> Result<&'a mut StoreRecord, UnknownIdError> {
75 let Some(old_record) = guard.get_mut(id) else {
76 return Err(UnknownIdError { id: id.to_owned() });
77 };
78 if old_record.is_stale() {
79 return Err(UnknownIdError { id: id.to_owned() });
80 }
81 Ok(old_record)
82 }
83
84 fn _delete(
88 guard: &mut MutexGuard<'_, HashMap<SessionId, StoreRecord>>,
89 id: &SessionId,
90 ) -> Result<StoreRecord, UnknownIdError> {
91 let Some(old_record) = guard.remove(id) else {
92 return Err(UnknownIdError { id: id.to_owned() });
93 };
94 if old_record.is_stale() {
95 return Err(UnknownIdError { id: id.to_owned() });
96 }
97 Ok(old_record)
98 }
99}
100
101#[async_trait::async_trait]
102impl SessionStorageBackend for InMemorySessionStore {
103 #[tracing::instrument(name = "Create server-side session record", level = tracing::Level::TRACE, skip_all)]
105 async fn create(
106 &self,
107 id: &SessionId,
108 record: SessionRecordRef<'_>,
109 ) -> Result<(), CreateError> {
110 let mut guard = self.0.lock().await;
111 if Self::get_mut_if_fresh(&mut guard, id).is_ok() {
112 return Err(CreateError::DuplicateId(DuplicateIdError { id: *id }));
113 }
114
115 guard.insert(
116 *id,
117 StoreRecord {
118 state: record.state.into_owned(),
119 deadline: Timestamp::now() + record.ttl,
120 },
121 );
122 Ok(())
123 }
124
125 #[tracing::instrument(name = "Update server-side session record", level = tracing::Level::TRACE, skip_all)]
129 async fn update(
130 &self,
131 id: &SessionId,
132 record: SessionRecordRef<'_>,
133 ) -> Result<(), UpdateError> {
134 let mut guard = self.0.lock().await;
135 let old_record = Self::get_mut_if_fresh(&mut guard, id)?;
136 *old_record = StoreRecord {
137 state: record.state.into_owned(),
138 deadline: Timestamp::now() + record.ttl,
139 };
140 Ok(())
141 }
142
143 #[tracing::instrument(name = "Update TTL for server-side session record", level = tracing::Level::TRACE, skip_all)]
147 async fn update_ttl(
148 &self,
149 id: &SessionId,
150 ttl: std::time::Duration,
151 ) -> Result<(), UpdateTtlError> {
152 let mut guard = self.0.lock().await;
153 let old_record = Self::get_mut_if_fresh(&mut guard, id)?;
154 old_record.deadline = Timestamp::now() + ttl;
155 Ok(())
156 }
157
158 #[tracing::instrument(name = "Load server-side session record", level = tracing::Level::TRACE, skip_all)]
164 async fn load(&self, session_id: &SessionId) -> Result<Option<SessionRecord>, LoadError> {
165 let mut guard = self.0.lock().await;
166 let outcome = match Self::get_mut_if_fresh(&mut guard, session_id) {
167 Ok(old_record) => Some(SessionRecord {
168 state: old_record.state.clone(),
169 ttl: (old_record.deadline - Timestamp::now())
170 .try_into()
171 .unwrap_or(Duration::from_millis(0)),
172 }),
173 Err(_) => None,
174 };
175 Ok(outcome)
176 }
177
178 #[tracing::instrument(name = "Delete server-side session record", level = tracing::Level::TRACE, skip_all)]
182 async fn delete(&self, id: &SessionId) -> Result<(), DeleteError> {
183 let mut guard = self.0.lock().await;
184 Self::_delete(&mut guard, id)?;
185 Ok(())
186 }
187
188 #[tracing::instrument(name = "Change id for server-side session record", level = tracing::Level::TRACE, skip_all)]
192 async fn change_id(&self, old_id: &SessionId, new_id: &SessionId) -> Result<(), ChangeIdError> {
193 let mut guard = self.0.lock().await;
194 if Self::get_mut_if_fresh(&mut guard, new_id).is_ok() {
195 return Err(DuplicateIdError {
196 id: new_id.to_owned(),
197 }
198 .into());
199 }
200 let record = Self::_delete(&mut guard, old_id)?;
201 guard.insert(*new_id, record);
202 Ok(())
203 }
204
205 #[tracing::instrument(name = "Delete expired records", level = tracing::Level::TRACE, skip_all)]
207 async fn delete_expired(
208 &self,
209 batch_size: Option<NonZeroUsize>,
210 ) -> Result<usize, DeleteExpiredError> {
211 let mut guard = self.0.lock().await;
212 let now = Timestamp::now();
213 let mut stale_ids = Vec::new();
214 for (id, record) in guard.iter() {
215 if record.deadline <= now {
216 stale_ids.push(*id);
217 if let Some(batch_size) = batch_size
218 && stale_ids.len() >= batch_size.get()
219 {
220 break;
221 }
222 }
223 }
224 let num_deleted = stale_ids.len();
225 for id in stale_ids {
226 guard.remove(&id);
227 }
228 Ok(num_deleted)
229 }
230}