1use pavex::methods;
3use pavex::time::Timestamp;
4use pavex_session::SessionStore;
5use pavex_session::{
6 SessionId,
7 store::{
8 SessionRecord, SessionRecordRef, SessionStorageBackend,
9 errors::{
10 ChangeIdError, CreateError, DeleteError, DeleteExpiredError, DuplicateIdError,
11 LoadError, UnknownIdError, UpdateError, UpdateTtlError,
12 },
13 },
14};
15use sqlx::{
16 MySqlPool,
17 mysql::{MySqlDatabaseError, MySqlQueryResult},
18};
19use std::num::NonZeroUsize;
20
21#[derive(Debug, Clone)]
22pub struct MySqlSessionStore(sqlx::MySqlPool);
39
40#[methods]
41impl From<MySqlSessionStore> for SessionStore {
42 #[singleton]
43 fn from(value: MySqlSessionStore) -> Self {
44 SessionStore::new(value)
45 }
46}
47
48#[methods]
49impl MySqlSessionStore {
50 #[singleton]
55 pub fn new(pool: MySqlPool) -> Self {
56 Self(pool)
57 }
58
59 pub fn migration_query() -> &'static str {
77 "-- Create the sessions table if it doesn't exist
78CREATE TABLE IF NOT EXISTS sessions (
79 id CHAR(36) PRIMARY KEY,
80 deadline BIGINT NOT NULL,
81 state JSON NOT NULL,
82 INDEX idx_sessions_deadline (deadline)
83);"
84 }
85
86 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
94 use sqlx::Executor as _;
95
96 self.0.execute(Self::migration_query()).await?;
97 Ok(())
98 }
99}
100
101#[async_trait::async_trait]
102impl SessionStorageBackend for MySqlSessionStore {
103 #[tracing::instrument(name = "Create server-side session record", level = tracing::Level::INFO, skip_all)]
116 async fn create(
117 &self,
118 id: &SessionId,
119 record: SessionRecordRef<'_>,
120 ) -> Result<(), CreateError> {
121 let deadline = Timestamp::now() + record.ttl;
122 let deadline_unix = deadline.as_second();
123 let state = serde_json::to_value(record.state)?;
124 let query = sqlx::query(
125 "INSERT INTO sessions (id, deadline, state) \
126 VALUES (?, ?, ?) \
127 ON DUPLICATE KEY UPDATE \
128 deadline = VALUES(deadline), state = VALUES(state)",
129 )
130 .bind(id.inner().to_string())
131 .bind(deadline_unix)
132 .bind(state);
133
134 match query.execute(&self.0).await {
135 Ok(_) => Ok(()),
137 Err(e) => Err(CreateError::Other(e.into())),
138 }
139 }
140
141 #[tracing::instrument(name = "Update server-side session record", level = tracing::Level::INFO, skip_all)]
145 async fn update(
146 &self,
147 id: &SessionId,
148 record: SessionRecordRef<'_>,
149 ) -> Result<(), UpdateError> {
150 let new_deadline = Timestamp::now() + record.ttl;
151 let new_deadline_unix = new_deadline.as_second();
152 let new_state = serde_json::to_value(record.state)?;
153 let query = sqlx::query(
154 "UPDATE sessions \
155 SET deadline = ?, state = ? \
156 WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
157 )
158 .bind(new_deadline_unix)
159 .bind(new_state)
160 .bind(id.inner().to_string());
161
162 match query.execute(&self.0).await {
163 Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
164 Err(e) => Err(UpdateError::Other(e.into())),
165 }
166 }
167
168 #[tracing::instrument(name = "Update TTL for server-side session record", level = tracing::Level::INFO, skip_all)]
172 async fn update_ttl(
173 &self,
174 id: &SessionId,
175 ttl: std::time::Duration,
176 ) -> Result<(), UpdateTtlError> {
177 let new_deadline = Timestamp::now() + ttl;
178 let new_deadline_unix = new_deadline.as_second();
179 let query = sqlx::query(
180 "UPDATE sessions \
181 SET deadline = ? \
182 WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
183 )
184 .bind(new_deadline_unix)
185 .bind(id.inner().to_string());
186 match query.execute(&self.0).await {
187 Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
188 Err(e) => Err(UpdateTtlError::Other(e.into())),
189 }
190 }
191
192 #[tracing::instrument(name = "Load server-side session record", level = tracing::Level::INFO, skip_all)]
198 async fn load(&self, session_id: &SessionId) -> Result<Option<SessionRecord>, LoadError> {
199 let row = sqlx::query(
200 "SELECT deadline, state \
201 FROM sessions \
202 WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
203 )
204 .bind(session_id.inner().to_string())
205 .fetch_optional(&self.0)
206 .await
207 .map_err(|e| LoadError::Other(e.into()))?;
208 row.map(|r| {
209 use anyhow::Context as _;
210 use sqlx::Row as _;
211
212 let deadline_unix: i64 = r
213 .try_get(0)
214 .context("Failed to deserialize the retrieved session deadline")
215 .map_err(LoadError::DeserializationError)?;
216 let deadline = Timestamp::from_second(deadline_unix)
217 .context("Failed to parse the retrieved session deadline")
218 .map_err(LoadError::DeserializationError)?;
219 let state: serde_json::Value = r
220 .try_get(1)
221 .context("Failed to deserialize the retrieved session state")
222 .map_err(LoadError::DeserializationError)?;
223 let ttl = deadline - Timestamp::now();
224 Ok(SessionRecord {
225 ttl: ttl.try_into().unwrap_or(std::time::Duration::ZERO),
227 state: serde_json::from_value(state)
228 .context("Failed to deserialize the retrieved session state")
229 .map_err(LoadError::DeserializationError)?,
230 })
231 })
232 .transpose()
233 }
234
235 #[tracing::instrument(name = "Delete server-side session record", level = tracing::Level::INFO, skip_all)]
239 async fn delete(&self, id: &SessionId) -> Result<(), DeleteError> {
240 let query = sqlx::query(
241 "DELETE FROM sessions \
242 WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
243 )
244 .bind(id.inner().to_string());
245 match query.execute(&self.0).await {
246 Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
247 Err(e) => Err(DeleteError::Other(e.into())),
248 }
249 }
250
251 #[tracing::instrument(name = "Change id for server-side session record", level = tracing::Level::INFO, skip_all)]
255 async fn change_id(&self, old_id: &SessionId, new_id: &SessionId) -> Result<(), ChangeIdError> {
256 let query = sqlx::query(
257 "UPDATE sessions \
258 SET id = ? \
259 WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
260 )
261 .bind(new_id.inner().to_string())
262 .bind(old_id.inner().to_string());
263 match query.execute(&self.0).await {
264 Ok(r) => as_unknown_id_error(&r, old_id).map_err(Into::into),
265 Err(e) => {
266 if let Err(e) = as_duplicated_id_error(&e, new_id) {
267 Err(e.into())
268 } else {
269 Err(ChangeIdError::Other(e.into()))
270 }
271 }
272 }
273 }
274
275 async fn delete_expired(
321 &self,
322 batch_size: Option<NonZeroUsize>,
323 ) -> Result<usize, DeleteExpiredError> {
324 let query = if let Some(batch_size) = batch_size {
325 let batch_size: u64 = batch_size.get().try_into().unwrap_or(u64::MAX);
326 sqlx::query("DELETE FROM sessions WHERE deadline < UNIX_TIMESTAMP() LIMIT ?")
327 .bind(batch_size)
328 } else {
329 sqlx::query("DELETE FROM sessions WHERE deadline < UNIX_TIMESTAMP()")
330 };
331 let r = query.execute(&self.0).await.map_err(|e| {
332 let e: anyhow::Error = e.into();
333 e
334 })?;
335 Ok(r.rows_affected().try_into().unwrap_or(usize::MAX))
336 }
337}
338
339fn as_duplicated_id_error(e: &sqlx::Error, id: &SessionId) -> Result<(), DuplicateIdError> {
340 if let Some(e) = e.as_database_error() {
341 if let Some(e) = e.try_downcast_ref::<MySqlDatabaseError>() {
342 if e.number() == 1062 {
345 return Err(DuplicateIdError { id: id.to_owned() });
346 }
347 }
348 }
349 Ok(())
350}
351
352fn as_unknown_id_error(r: &MySqlQueryResult, id: &SessionId) -> Result<(), UnknownIdError> {
353 if r.rows_affected() == 0 {
355 return Err(UnknownIdError { id: id.to_owned() });
356 }
357 assert_eq!(
359 r.rows_affected(),
360 1,
361 "More than one session record was affected, even though the session ID is used as primary key. Something is deeply wrong here!"
362 );
363 Ok(())
364}