pavex_session_sqlx/
mysql.rs

1//! Types related to [`MySqlSessionStore`].
2use pavex::methods;
3use pavex::time::Timestamp;
4use pavex_session::SessionStore;
5use pavex_session::{
6    SessionId,
7    store::{
8        SessionRecord, SessionRecordRef, SessionStorageBackend,
9        errors::{
10            ChangeIdError, CreateError, DeleteError, DeleteExpiredError, DuplicateIdError,
11            LoadError, UnknownIdError, UpdateError, UpdateTtlError,
12        },
13    },
14};
15use sqlx::{
16    MySqlPool,
17    mysql::{MySqlDatabaseError, MySqlQueryResult},
18};
19use std::num::NonZeroUsize;
20
21#[derive(Debug, Clone)]
22/// A server-side session store using MySQL as its backend.
23///
24/// # Implementation details
25///
26/// This store uses `sqlx` to interact with MySQL.
27/// All session records are stored in a single table. You can use
28/// [`migrate`](Self::migrate) to create the table and index
29/// required by the store in the database.
30/// Alternatively, you can use [`migration_query`](Self::migration_query)
31/// to get the SQL query that creates the table and index in order to run it yourself
32/// (e.g. as part of your database migration scripts).
33///
34/// # MySQL version requirements
35///
36/// This implementation requires MySQL 5.7.8+ or MariaDB 10.2+ for JSON support.
37/// For optimal performance with JSON operations, MySQL 8.0+ is recommended.
38pub struct MySqlSessionStore(sqlx::MySqlPool);
39
40#[methods]
41impl From<MySqlSessionStore> for SessionStore {
42    #[singleton]
43    fn from(value: MySqlSessionStore) -> Self {
44        SessionStore::new(value)
45    }
46}
47
48#[methods]
49impl MySqlSessionStore {
50    /// Creates a new MySQL session store instance.
51    ///
52    /// It requires a pool of MySQL connections to interact with the database
53    /// where the session records are stored.
54    #[singleton]
55    pub fn new(pool: MySqlPool) -> Self {
56        Self(pool)
57    }
58
59    /// Return the query used to create the sessions table and index.
60    ///
61    /// # Implementation details
62    ///
63    /// The query is designed to be idempotent, meaning it can be run multiple times
64    /// without causing any issues. If the table and index already exist, the query
65    /// does nothing.
66    ///
67    /// # MySQL version requirements
68    ///
69    /// This query requires MySQL 5.7.8+ or MariaDB 10.2+ for JSON column support.
70    ///
71    /// # Alternatives
72    ///
73    /// You can use this method to add the query to your database migration scripts.
74    /// Alternatively, you can use [`migrate`](Self::migrate)
75    /// to run the query directly on the database.
76    pub fn migration_query() -> &'static str {
77        "-- Create the sessions table if it doesn't exist
78CREATE TABLE IF NOT EXISTS sessions (
79    id CHAR(36) PRIMARY KEY,
80    deadline BIGINT NOT NULL,
81    state JSON NOT NULL,
82    INDEX idx_sessions_deadline (deadline)
83);"
84    }
85
86    /// Create the sessions table and index in the database.
87    ///
88    /// This method is idempotent, meaning it can be called multiple times without
89    /// causing any issues. If the table and index already exist, this method does nothing.
90    ///
91    /// If you prefer to run the query yourself, rely on [`migration_query`](Self::migration_query)
92    /// to get the SQL that's being executed.
93    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
94        use sqlx::Executor as _;
95
96        self.0.execute(Self::migration_query()).await?;
97        Ok(())
98    }
99}
100
101#[async_trait::async_trait]
102impl SessionStorageBackend for MySqlSessionStore {
103    /// Creates a new session record in the store using the provided ID.
104    /// When a conflicting session id is present, we perform a simple upsert.
105    /// This is a deliberate decision given we can't return an error and keep atomicity
106    /// at the same time.
107    ///
108    /// Even when using a guard clause, which we'd expect to amount to a noop:
109    ///
110    /// ON DUPLICATE KEY UPDATE
111    ///    deadline = IF(sessions.deadline < UNIX_TIMESTAMP(), VALUES(deadline), sessions.deadline),
112    ///    state = IF(sessions.deadline < UNIX_TIMESTAMP(), VALUES(state), sessions.state)
113    ///
114    /// affected_rows() is still non-zero. This seems to be a kink in MySQL.
115    #[tracing::instrument(name = "Create server-side session record", level = tracing::Level::INFO, skip_all)]
116    async fn create(
117        &self,
118        id: &SessionId,
119        record: SessionRecordRef<'_>,
120    ) -> Result<(), CreateError> {
121        let deadline = Timestamp::now() + record.ttl;
122        let deadline_unix = deadline.as_second();
123        let state = serde_json::to_value(record.state)?;
124        let query = sqlx::query(
125            "INSERT INTO sessions (id, deadline, state) \
126            VALUES (?, ?, ?) \
127            ON DUPLICATE KEY UPDATE \
128            deadline = VALUES(deadline), state = VALUES(state)",
129        )
130        .bind(id.inner().to_string())
131        .bind(deadline_unix)
132        .bind(state);
133
134        match query.execute(&self.0).await {
135            // All good, we created the session record.
136            Ok(_) => Ok(()),
137            Err(e) => Err(CreateError::Other(e.into())),
138        }
139    }
140
141    /// Update the state of an existing session in the store.
142    ///
143    /// It overwrites the existing record with the provided one.
144    #[tracing::instrument(name = "Update server-side session record", level = tracing::Level::INFO, skip_all)]
145    async fn update(
146        &self,
147        id: &SessionId,
148        record: SessionRecordRef<'_>,
149    ) -> Result<(), UpdateError> {
150        let new_deadline = Timestamp::now() + record.ttl;
151        let new_deadline_unix = new_deadline.as_second();
152        let new_state = serde_json::to_value(record.state)?;
153        let query = sqlx::query(
154            "UPDATE sessions \
155            SET deadline = ?, state = ? \
156            WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
157        )
158        .bind(new_deadline_unix)
159        .bind(new_state)
160        .bind(id.inner().to_string());
161
162        match query.execute(&self.0).await {
163            Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
164            Err(e) => Err(UpdateError::Other(e.into())),
165        }
166    }
167
168    /// Update the TTL of an existing session record in the store.
169    ///
170    /// It leaves the session state unchanged.
171    #[tracing::instrument(name = "Update TTL for server-side session record", level = tracing::Level::INFO, skip_all)]
172    async fn update_ttl(
173        &self,
174        id: &SessionId,
175        ttl: std::time::Duration,
176    ) -> Result<(), UpdateTtlError> {
177        let new_deadline = Timestamp::now() + ttl;
178        let new_deadline_unix = new_deadline.as_second();
179        let query = sqlx::query(
180            "UPDATE sessions \
181            SET deadline = ? \
182            WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
183        )
184        .bind(new_deadline_unix)
185        .bind(id.inner().to_string());
186        match query.execute(&self.0).await {
187            Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
188            Err(e) => Err(UpdateTtlError::Other(e.into())),
189        }
190    }
191
192    /// Loads an existing session record from the store using the provided ID.
193    ///
194    /// If a session with the given ID exists, it is returned. If the session
195    /// does not exist or has been invalidated (e.g., expired), `None` is
196    /// returned.
197    #[tracing::instrument(name = "Load server-side session record", level = tracing::Level::INFO, skip_all)]
198    async fn load(&self, session_id: &SessionId) -> Result<Option<SessionRecord>, LoadError> {
199        let row = sqlx::query(
200            "SELECT deadline, state \
201            FROM sessions \
202            WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
203        )
204        .bind(session_id.inner().to_string())
205        .fetch_optional(&self.0)
206        .await
207        .map_err(|e| LoadError::Other(e.into()))?;
208        row.map(|r| {
209            use anyhow::Context as _;
210            use sqlx::Row as _;
211
212            let deadline_unix: i64 = r
213                .try_get(0)
214                .context("Failed to deserialize the retrieved session deadline")
215                .map_err(LoadError::DeserializationError)?;
216            let deadline = Timestamp::from_second(deadline_unix)
217                .context("Failed to parse the retrieved session deadline")
218                .map_err(LoadError::DeserializationError)?;
219            let state: serde_json::Value = r
220                .try_get(1)
221                .context("Failed to deserialize the retrieved session state")
222                .map_err(LoadError::DeserializationError)?;
223            let ttl = deadline - Timestamp::now();
224            Ok(SessionRecord {
225                // This conversion only fails if the duration is negative, which should not happen
226                ttl: ttl.try_into().unwrap_or(std::time::Duration::ZERO),
227                state: serde_json::from_value(state)
228                    .context("Failed to deserialize the retrieved session state")
229                    .map_err(LoadError::DeserializationError)?,
230            })
231        })
232        .transpose()
233    }
234
235    /// Deletes a session record from the store using the provided ID.
236    ///
237    /// If the session exists, it is removed from the store.
238    #[tracing::instrument(name = "Delete server-side session record", level = tracing::Level::INFO, skip_all)]
239    async fn delete(&self, id: &SessionId) -> Result<(), DeleteError> {
240        let query = sqlx::query(
241            "DELETE FROM sessions \
242            WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
243        )
244        .bind(id.inner().to_string());
245        match query.execute(&self.0).await {
246            Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
247            Err(e) => Err(DeleteError::Other(e.into())),
248        }
249    }
250
251    /// Change the session id associated with an existing session record.
252    ///
253    /// The server-side state is left unchanged.
254    #[tracing::instrument(name = "Change id for server-side session record", level = tracing::Level::INFO, skip_all)]
255    async fn change_id(&self, old_id: &SessionId, new_id: &SessionId) -> Result<(), ChangeIdError> {
256        let query = sqlx::query(
257            "UPDATE sessions \
258            SET id = ? \
259            WHERE id = ? AND deadline > UNIX_TIMESTAMP()",
260        )
261        .bind(new_id.inner().to_string())
262        .bind(old_id.inner().to_string());
263        match query.execute(&self.0).await {
264            Ok(r) => as_unknown_id_error(&r, old_id).map_err(Into::into),
265            Err(e) => {
266                if let Err(e) = as_duplicated_id_error(&e, new_id) {
267                    Err(e.into())
268                } else {
269                    Err(ChangeIdError::Other(e.into()))
270                }
271            }
272        }
273    }
274
275    /// Delete expired sessions from the database.
276    ///
277    /// If `batch_size` is provided, the query will delete at most `batch_size` expired sessions.
278    /// In either case, if successful, the method returns the number of expired sessions that
279    /// have been deleted.
280    ///
281    /// # When should you delete in batches?
282    ///
283    /// If there are a lot of expired sessions in the database, deleting them all at once can
284    /// cause performance issues. By deleting in batches, you can limit the number of sessions
285    /// deleted in a single query, reducing the impact.
286    ///
287    /// # Example
288    ///
289    /// Delete expired sessions in batches of 1000:
290    ///
291    /// ```no_run
292    /// use pavex_session::SessionStore;
293    /// use pavex_session_sqlx::MySqlSessionStore;
294    /// use pavex_tracing::fields::{
295    ///     error_details,
296    ///     error_message,
297    ///     ERROR_DETAILS,
298    ///     ERROR_MESSAGE
299    /// };
300    /// use std::time::Duration;
301    ///
302    /// # async fn delete_expired_sessions(pool: sqlx::MySqlPool) {
303    /// let backend = MySqlSessionStore::new(pool);
304    /// let store = SessionStore::new(backend);
305    /// let batch_size = Some(1000.try_into().unwrap());
306    /// let batch_sleep = Duration::from_secs(60);
307    /// loop {
308    ///     if let Err(e) = store.delete_expired(batch_size).await {
309    ///         tracing::event!(
310    ///             tracing::Level::ERROR,
311    ///             { ERROR_MESSAGE } = error_message(&e),
312    ///             { ERROR_DETAILS } = error_details(&e),
313    ///             "Failed to delete a batch of expired sessions",
314    ///         );
315    ///     }
316    ///     tokio::time::sleep(batch_sleep).await;
317    /// }
318    /// # }
319    /// ```
320    async fn delete_expired(
321        &self,
322        batch_size: Option<NonZeroUsize>,
323    ) -> Result<usize, DeleteExpiredError> {
324        let query = if let Some(batch_size) = batch_size {
325            let batch_size: u64 = batch_size.get().try_into().unwrap_or(u64::MAX);
326            sqlx::query("DELETE FROM sessions WHERE deadline < UNIX_TIMESTAMP() LIMIT ?")
327                .bind(batch_size)
328        } else {
329            sqlx::query("DELETE FROM sessions WHERE deadline < UNIX_TIMESTAMP()")
330        };
331        let r = query.execute(&self.0).await.map_err(|e| {
332            let e: anyhow::Error = e.into();
333            e
334        })?;
335        Ok(r.rows_affected().try_into().unwrap_or(usize::MAX))
336    }
337}
338
339fn as_duplicated_id_error(e: &sqlx::Error, id: &SessionId) -> Result<(), DuplicateIdError> {
340    if let Some(e) = e.as_database_error() {
341        if let Some(e) = e.try_downcast_ref::<MySqlDatabaseError>() {
342            // Check if the error is due to a duplicate ID
343            // MySQL error code 1062 is for duplicate entry
344            if e.number() == 1062 {
345                return Err(DuplicateIdError { id: id.to_owned() });
346            }
347        }
348    }
349    Ok(())
350}
351
352fn as_unknown_id_error(r: &MySqlQueryResult, id: &SessionId) -> Result<(), UnknownIdError> {
353    // Check if the session record was changed
354    if r.rows_affected() == 0 {
355        return Err(UnknownIdError { id: id.to_owned() });
356    }
357    // Sanity check
358    assert_eq!(
359        r.rows_affected(),
360        1,
361        "More than one session record was affected, even though the session ID is used as primary key. Something is deeply wrong here!"
362    );
363    Ok(())
364}