pavex_session_sqlx/
sqlite.rs

1//! Types related to [`SqliteSessionStore`].
2
3use pavex::methods;
4use pavex::time::Timestamp;
5use pavex_session::SessionStore;
6use pavex_session::{
7    SessionId,
8    store::{
9        SessionRecord, SessionRecordRef, SessionStorageBackend,
10        errors::{
11            ChangeIdError, CreateError, DeleteError, DeleteExpiredError, DuplicateIdError,
12            LoadError, UnknownIdError, UpdateError, UpdateTtlError,
13        },
14    },
15};
16use sqlx::{
17    SqlitePool,
18    error::DatabaseError,
19    sqlite::{SqliteError, SqliteQueryResult},
20};
21use std::num::NonZeroUsize;
22
23#[derive(Debug, Clone)]
24/// A server-side session store using SQLite as its backend.
25///
26/// # Implementation details
27///
28/// This store uses `sqlx` to interact with SQLite.
29/// All session records are stored in a single table with JSONB for efficient
30/// binary JSON storage (requires SQLite 3.45.0+). You can use
31/// [`migrate`](Self::migrate) to create the table and index
32/// required by the store in the database.
33/// Alternatively, you can use [`migration_query`](Self::migration_query)
34/// to get the SQL query that creates the table and index in order to run it yourself
35/// (e.g. as part of your database migration scripts).
36///
37/// # JSONB Support
38///
39/// This implementation uses SQLite's JSONB format for storing session state,
40/// which provides better performance (5-10% smaller size, ~50% faster processing)
41/// compared to plain text JSON. JSONB is supported in SQLite 3.45.0 and later.
42pub struct SqliteSessionStore(sqlx::SqlitePool);
43
44#[methods]
45impl From<SqliteSessionStore> for SessionStore {
46    #[singleton]
47    fn from(value: SqliteSessionStore) -> Self {
48        SessionStore::new(value)
49    }
50}
51
52#[methods]
53impl SqliteSessionStore {
54    /// Creates a new SQLite session store instance.
55    ///
56    /// It requires a pool of SQLite connections to interact with the database
57    /// where the session records are stored.
58    #[singleton]
59    pub fn new(pool: SqlitePool) -> Self {
60        Self(pool)
61    }
62
63    /// Return the query used to create the sessions table and index.
64    ///
65    /// # Implementation details
66    ///
67    /// The query is designed to be idempotent, meaning it can be run multiple times
68    /// without causing any issues. If the table and index already exist, the query
69    /// does nothing.
70    ///
71    /// # Alternatives
72    ///
73    /// You can use this method to add the query to your database migration scripts.
74    /// Alternatively, you can use [`migrate`](Self::migrate)
75    /// to run the query directly on the database.
76    pub fn migration_query() -> &'static str {
77        "-- Create the sessions table if it doesn't exist
78CREATE TABLE IF NOT EXISTS sessions (
79    id TEXT PRIMARY KEY,
80    deadline INTEGER NOT NULL,
81    state JSONB NOT NULL
82);
83
84-- Create the index on the deadline column if it doesn't exist
85CREATE INDEX IF NOT EXISTS idx_sessions_deadline ON sessions(deadline);"
86    }
87
88    /// Create the sessions table and index in the database.
89    ///
90    /// This method is idempotent, meaning it can be called multiple times without
91    /// causing any issues. If the table and index already exist, this method does nothing.
92    ///
93    /// If you prefer to run the query yourself, rely on [`migration_query`](Self::migration_query)
94    /// to get the SQL that's being executed.
95    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
96        use sqlx::Executor as _;
97
98        self.0.execute(Self::migration_query()).await?;
99        Ok(())
100    }
101}
102
103#[async_trait::async_trait]
104impl SessionStorageBackend for SqliteSessionStore {
105    /// Creates a new session record in the store using the provided ID.
106    #[tracing::instrument(name = "Create server-side session record", level = tracing::Level::INFO, skip_all)]
107    async fn create(
108        &self,
109        id: &SessionId,
110        record: SessionRecordRef<'_>,
111    ) -> Result<(), CreateError> {
112        let deadline = Timestamp::now() + record.ttl;
113        let deadline_unix = deadline.as_second();
114        let state = serde_json::to_value(record.state)?;
115        let query = sqlx::query(
116            "INSERT INTO sessions (id, deadline, state) \
117            VALUES (?, ?, ?) \
118            ON CONFLICT(id) DO UPDATE \
119            SET deadline = excluded.deadline, state = excluded.state \
120            WHERE sessions.deadline < unixepoch()",
121        )
122        .bind(id.inner().to_string())
123        .bind(deadline_unix)
124        .bind(state);
125
126        match query.execute(&self.0).await {
127            // All good, we created the session record.
128            Ok(_) => Ok(()),
129            Err(e) => {
130                // Return the specialized error variant if the ID is already in use
131                if let Err(e) = as_duplicated_id_error(&e, id) {
132                    Err(e.into())
133                } else {
134                    Err(CreateError::Other(e.into()))
135                }
136            }
137        }
138    }
139
140    /// Update the state of an existing session in the store.
141    ///
142    /// It overwrites the existing record with the provided one.
143    #[tracing::instrument(name = "Update server-side session record", level = tracing::Level::INFO, skip_all)]
144    async fn update(
145        &self,
146        id: &SessionId,
147        record: SessionRecordRef<'_>,
148    ) -> Result<(), UpdateError> {
149        let new_deadline = Timestamp::now() + record.ttl;
150        let new_deadline_unix = new_deadline.as_second();
151        let new_state = serde_json::to_value(record.state)?;
152        let query = sqlx::query(
153            "UPDATE sessions \
154            SET deadline = ?, state = ? \
155            WHERE id = ? AND deadline > unixepoch()",
156        )
157        .bind(new_deadline_unix)
158        .bind(new_state)
159        .bind(id.inner().to_string());
160
161        match query.execute(&self.0).await {
162            Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
163            Err(e) => Err(UpdateError::Other(e.into())),
164        }
165    }
166
167    /// Update the TTL of an existing session record in the store.
168    ///
169    /// It leaves the session state unchanged.
170    #[tracing::instrument(name = "Update TTL for server-side session record", level = tracing::Level::INFO, skip_all)]
171    async fn update_ttl(
172        &self,
173        id: &SessionId,
174        ttl: std::time::Duration,
175    ) -> Result<(), UpdateTtlError> {
176        let new_deadline = Timestamp::now() + ttl;
177        let new_deadline_unix = new_deadline.as_second();
178        let query = sqlx::query(
179            "UPDATE sessions \
180            SET deadline = ? \
181            WHERE id = ? AND deadline > unixepoch()",
182        )
183        .bind(new_deadline_unix)
184        .bind(id.inner().to_string());
185        match query.execute(&self.0).await {
186            Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
187            Err(e) => Err(UpdateTtlError::Other(e.into())),
188        }
189    }
190
191    /// Loads an existing session record from the store using the provided ID.
192    ///
193    /// If a session with the given ID exists, it is returned. If the session
194    /// does not exist or has been invalidated (e.g., expired), `None` is
195    /// returned.
196    #[tracing::instrument(name = "Load server-side session record", level = tracing::Level::INFO, skip_all)]
197    async fn load(&self, session_id: &SessionId) -> Result<Option<SessionRecord>, LoadError> {
198        let row = sqlx::query(
199            "SELECT deadline, state \
200            FROM sessions \
201            WHERE id = ? AND deadline > unixepoch()",
202        )
203        .bind(session_id.inner().to_string())
204        .fetch_optional(&self.0)
205        .await
206        .map_err(|e| LoadError::Other(e.into()))?;
207        row.map(|r| {
208            use anyhow::Context as _;
209            use sqlx::Row as _;
210
211            let deadline_unix: i64 = r
212                .try_get(0)
213                .context("Failed to deserialize the retrieved session deadline")
214                .map_err(LoadError::DeserializationError)?;
215            let deadline = Timestamp::from_second(deadline_unix)
216                .context("Failed to parse the retrieved session deadline")
217                .map_err(LoadError::DeserializationError)?;
218            let state: serde_json::Value = r
219                .try_get(1)
220                .context("Failed to deserialize the retrieved session state")
221                .map_err(LoadError::DeserializationError)?;
222            let ttl = deadline - Timestamp::now();
223            Ok(SessionRecord {
224                // This conversion only fails if the duration is negative, which should not happen
225                ttl: ttl.try_into().unwrap_or(std::time::Duration::ZERO),
226                state: serde_json::from_value(state)
227                    .context("Failed to deserialize the retrieved session state")
228                    .map_err(LoadError::DeserializationError)?,
229            })
230        })
231        .transpose()
232    }
233
234    /// Deletes a session record from the store using the provided ID.
235    ///
236    /// If the session exists, it is removed from the store.
237    #[tracing::instrument(name = "Delete server-side session record", level = tracing::Level::INFO, skip_all)]
238    async fn delete(&self, id: &SessionId) -> Result<(), DeleteError> {
239        let query = sqlx::query(
240            "DELETE FROM sessions \
241            WHERE id = ? AND deadline > unixepoch()",
242        )
243        .bind(id.inner().to_string());
244        match query.execute(&self.0).await {
245            Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
246            Err(e) => Err(DeleteError::Other(e.into())),
247        }
248    }
249
250    /// Change the session id associated with an existing session record.
251    ///
252    /// The server-side state is left unchanged.
253    #[tracing::instrument(name = "Change id for server-side session record", level = tracing::Level::INFO, skip_all)]
254    async fn change_id(&self, old_id: &SessionId, new_id: &SessionId) -> Result<(), ChangeIdError> {
255        let query = sqlx::query(
256            "UPDATE sessions \
257            SET id = ? \
258            WHERE id = ? AND deadline > unixepoch()",
259        )
260        .bind(new_id.inner().to_string())
261        .bind(old_id.inner().to_string());
262        match query.execute(&self.0).await {
263            Ok(r) => as_unknown_id_error(&r, old_id).map_err(Into::into),
264            Err(e) => {
265                if let Err(e) = as_duplicated_id_error(&e, new_id) {
266                    Err(e.into())
267                } else {
268                    Err(ChangeIdError::Other(e.into()))
269                }
270            }
271        }
272    }
273
274    /// Delete expired sessions from the database.
275    ///
276    /// If `batch_size` is provided, the query will delete at most `batch_size` expired sessions.
277    /// In either case, if successful, the method returns the number of expired sessions that
278    /// have been deleted.
279    ///
280    /// # When should you delete in batches?
281    ///
282    /// If there are a lot of expired sessions in the database, deleting them all at once can
283    /// cause performance issues. By deleting in batches, you can limit the number of sessions
284    /// deleted in a single query, reducing the impact.
285    ///
286    /// # Example
287    ///
288    /// Delete expired sessions in batches of 1000:
289    ///
290    /// ```no_run
291    /// use pavex_session::SessionStore;
292    /// use pavex_session_sqlx::SqliteSessionStore;
293    /// use pavex_tracing::fields::{
294    ///     error_details,
295    ///     error_message,
296    ///     ERROR_DETAILS,
297    ///     ERROR_MESSAGE
298    /// };
299    /// use std::time::Duration;
300    ///
301    /// # async fn delete_expired_sessions(pool: sqlx::SqlitePool) {
302    /// let backend = SqliteSessionStore::new(pool);
303    /// let store = SessionStore::new(backend);
304    /// let batch_size = Some(1000.try_into().unwrap());
305    /// let batch_sleep = Duration::from_secs(60);
306    /// loop {
307    ///     if let Err(e) = store.delete_expired(batch_size).await {
308    ///         tracing::event!(
309    ///             tracing::Level::ERROR,
310    ///             { ERROR_MESSAGE } = error_message(&e),
311    ///             { ERROR_DETAILS } = error_details(&e),
312    ///             "Failed to delete a batch of expired sessions",
313    ///         );
314    ///     }
315    ///     tokio::time::sleep(batch_sleep).await;
316    /// }
317    /// # }
318    async fn delete_expired(
319        &self,
320        batch_size: Option<NonZeroUsize>,
321    ) -> Result<usize, DeleteExpiredError> {
322        let query = if let Some(batch_size) = batch_size {
323            let batch_size: i64 = batch_size.get().try_into().unwrap_or(i64::MAX);
324            sqlx::query("DELETE FROM sessions WHERE id IN (SELECT id FROM sessions WHERE deadline < unixepoch() LIMIT ?)")
325                .bind(batch_size)
326        } else {
327            sqlx::query("DELETE FROM sessions WHERE deadline < unixepoch()")
328        };
329        let r = query.execute(&self.0).await.map_err(|e| {
330            let e: anyhow::Error = e.into();
331            e
332        })?;
333        Ok(r.rows_affected().try_into().unwrap_or(usize::MAX))
334    }
335}
336
337fn as_duplicated_id_error(e: &sqlx::Error, id: &SessionId) -> Result<(), DuplicateIdError> {
338    if let Some(e) = e.as_database_error() {
339        if let Some(e) = e.try_downcast_ref::<SqliteError>() {
340            // Check if the error is due to a duplicate ID
341            // SQLite constraint violation error code is "1555" (SQLITE_CONSTRAINT_PRIMARYKEY)
342            if e.code() == Some("1555".into()) {
343                return Err(DuplicateIdError { id: id.to_owned() });
344            }
345        }
346    }
347    Ok(())
348}
349
350fn as_unknown_id_error(r: &SqliteQueryResult, id: &SessionId) -> Result<(), UnknownIdError> {
351    // Check if the session record was changed
352    if r.rows_affected() == 0 {
353        return Err(UnknownIdError { id: id.to_owned() });
354    }
355    // Sanity check
356    assert_eq!(
357        r.rows_affected(),
358        1,
359        "More than one session record was affected, even though the session ID is used as primary key. Something is deeply wrong here!"
360    );
361    Ok(())
362}