pavex_session_sqlx/
postgres.rs

1//! Types related to [`PostgresSessionStore`].
2use jiff_sqlx::ToSqlx;
3use pavex::methods;
4use pavex::time::Timestamp;
5use pavex_session::SessionStore;
6use pavex_session::{
7    SessionId,
8    store::{
9        SessionRecord, SessionRecordRef, SessionStorageBackend,
10        errors::{
11            ChangeIdError, CreateError, DeleteError, DeleteExpiredError, DuplicateIdError,
12            LoadError, UnknownIdError, UpdateError, UpdateTtlError,
13        },
14    },
15};
16use sqlx::{
17    PgPool,
18    postgres::{PgDatabaseError, PgQueryResult},
19};
20use std::num::NonZeroUsize;
21
22#[derive(Debug, Clone)]
23/// A server-side session store using Postgres as its backend.
24///
25/// # Implementation details
26///
27/// This store uses `sqlx` to interact with Postgres.
28/// All session records are stored in a single table. You can use
29/// [`migrate`](Self::migrate) to create the table and index
30/// required by the store in the database.
31/// Alternatively, you can use [`migration_query`](Self::migration_query)
32/// to get the SQL query that creates the table and index in order to run it yourself
33/// (e.g. as part of your database migration scripts).
34pub struct PostgresSessionStore(sqlx::PgPool);
35
36#[methods]
37impl From<PostgresSessionStore> for SessionStore {
38    #[singleton]
39    fn from(value: PostgresSessionStore) -> Self {
40        SessionStore::new(value)
41    }
42}
43
44#[methods]
45impl PostgresSessionStore {
46    /// Creates a new Postgres session store instance.
47    ///
48    /// It requires a pool of Postgres connections to interact with the database
49    /// where the session records are stored.
50    #[singleton]
51    pub fn new(pool: PgPool) -> Self {
52        Self(pool)
53    }
54
55    /// Return the query used to create the sessions table and index.
56    ///
57    /// # Implementation details
58    ///
59    /// The query is designed to be idempotent, meaning it can be run multiple times
60    /// without causing any issues. If the table and index already exist, the query
61    /// does nothing.
62    ///
63    /// # Alternatives
64    ///
65    /// You can use this method to add the query to your database migration scripts.
66    /// Alternatively, you can use [`migrate`](Self::migrate)
67    /// to run the query directly on the database.
68    pub fn migration_query() -> &'static str {
69        "-- Create the sessions table if it doesn’t exist
70CREATE TABLE IF NOT EXISTS sessions (
71    id UUID PRIMARY KEY,
72    deadline TIMESTAMPTZ NOT NULL,
73    state JSONB NOT NULL
74);
75
76-- Create the index on the deadline column if it doesn’t exist
77DO $$
78BEGIN
79    IF NOT EXISTS (
80        SELECT 1 FROM pg_indexes
81        WHERE schemaname = current_schema()
82            AND tablename = 'sessions'
83            AND indexname = 'idx_sessions_deadline'
84    ) THEN
85        CREATE INDEX idx_sessions_deadline ON sessions(deadline);
86    END IF;
87END $$;"
88    }
89
90    /// Create the sessions table and index in the database.
91    ///
92    /// This method is idempotent, meaning it can be called multiple times without
93    /// causing any issues. If the table and index already exist, this method does nothing.
94    ///
95    /// If you prefer to run the query yourself, rely on [`migration_query`](Self::migration_query)
96    /// to get the SQL that's being executed.
97    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
98        use sqlx::Executor as _;
99
100        self.0.execute(Self::migration_query()).await?;
101        Ok(())
102    }
103}
104
105#[async_trait::async_trait]
106impl SessionStorageBackend for PostgresSessionStore {
107    /// Creates a new session record in the store using the provided ID.
108    #[tracing::instrument(name = "Create server-side session record", level = tracing::Level::INFO, skip_all)]
109    async fn create(
110        &self,
111        id: &SessionId,
112        record: SessionRecordRef<'_>,
113    ) -> Result<(), CreateError> {
114        let deadline = Timestamp::now() + record.ttl;
115        let state = serde_json::to_value(record.state)?;
116        let query = sqlx::query(
117            "INSERT INTO sessions (id, deadline, state) \
118            VALUES ($1, $2, $3) \
119            ON CONFLICT (id) DO UPDATE \
120            SET deadline = EXCLUDED.deadline, state = EXCLUDED.state \
121            WHERE sessions.deadline < (now() AT TIME ZONE 'UTC')",
122        )
123        .bind(id.inner())
124        .bind(deadline.to_sqlx())
125        .bind(state);
126
127        match query.execute(&self.0).await {
128            // All good, we created the session record.
129            Ok(_) => Ok(()),
130            Err(e) => {
131                // Return the specialized error variant if the ID is already in use
132                if let Err(e) = as_duplicated_id_error(&e, id) {
133                    Err(e.into())
134                } else {
135                    Err(CreateError::Other(e.into()))
136                }
137            }
138        }
139    }
140
141    /// Update the state of an existing session in the store.
142    ///
143    /// It overwrites the existing record with the provided one.
144    #[tracing::instrument(name = "Update server-side session record", level = tracing::Level::INFO, skip_all)]
145    async fn update(
146        &self,
147        id: &SessionId,
148        record: SessionRecordRef<'_>,
149    ) -> Result<(), UpdateError> {
150        let new_deadline = Timestamp::now() + record.ttl;
151        let new_state = serde_json::to_value(record.state)?;
152        let query = sqlx::query(
153            "UPDATE sessions \
154            SET deadline = $1, state = $2 \
155            WHERE id = $3 AND deadline > (now() AT TIME ZONE 'UTC')",
156        )
157        .bind(new_deadline.to_sqlx())
158        .bind(new_state)
159        .bind(id.inner());
160
161        match query.execute(&self.0).await {
162            Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
163            Err(e) => Err(UpdateError::Other(e.into())),
164        }
165    }
166
167    /// Update the TTL of an existing session record in the store.
168    ///
169    /// It leaves the session state unchanged.
170    #[tracing::instrument(name = "Update TTL for server-side session record", level = tracing::Level::INFO, skip_all)]
171    async fn update_ttl(
172        &self,
173        id: &SessionId,
174        ttl: std::time::Duration,
175    ) -> Result<(), UpdateTtlError> {
176        let new_deadline = Timestamp::now() + ttl;
177        let query = sqlx::query(
178            "UPDATE sessions \
179            SET deadline = $1 \
180            WHERE id = $2 AND deadline > (now() AT TIME ZONE 'UTC')",
181        )
182        .bind(new_deadline.to_sqlx())
183        .bind(id.inner());
184        match query.execute(&self.0).await {
185            Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
186            Err(e) => Err(UpdateTtlError::Other(e.into())),
187        }
188    }
189
190    /// Loads an existing session record from the store using the provided ID.
191    ///
192    /// If a session with the given ID exists, it is returned. If the session
193    /// does not exist or has been invalidated (e.g., expired), `None` is
194    /// returned.
195    #[tracing::instrument(name = "Load server-side session record", level = tracing::Level::INFO, skip_all)]
196    async fn load(&self, session_id: &SessionId) -> Result<Option<SessionRecord>, LoadError> {
197        let row = sqlx::query(
198            "SELECT deadline, state \
199            FROM sessions \
200            WHERE id = $1 AND deadline > (now() AT TIME ZONE 'UTC')",
201        )
202        .bind(session_id.inner())
203        .fetch_optional(&self.0)
204        .await
205        .map_err(|e| LoadError::Other(e.into()))?;
206        row.map(|r| {
207            use anyhow::Context as _;
208            use sqlx::Row as _;
209
210            let deadline = r
211                .try_get::<jiff_sqlx::Timestamp, _>(0)
212                .context("Failed to deserialize the retrieved session deadline")
213                .map_err(LoadError::DeserializationError)?
214                .to_jiff();
215            let state: serde_json::Value = r
216                .try_get(1)
217                .context("Failed to deserialize the retrieved session state")
218                .map_err(LoadError::DeserializationError)?;
219            let ttl = deadline - Timestamp::now();
220            Ok(SessionRecord {
221                // This conversion only fails if the duration is negative, which should not happen
222                ttl: ttl.try_into().unwrap_or(std::time::Duration::ZERO),
223                state: serde_json::from_value(state)
224                    .context("Failed to deserialize the retrieved session state")
225                    .map_err(LoadError::DeserializationError)?,
226            })
227        })
228        .transpose()
229    }
230
231    /// Deletes a session record from the store using the provided ID.
232    ///
233    /// If the session exists, it is removed from the store.
234    #[tracing::instrument(name = "Delete server-side session record", level = tracing::Level::INFO, skip_all)]
235    async fn delete(&self, id: &SessionId) -> Result<(), DeleteError> {
236        let query = sqlx::query(
237            "DELETE FROM sessions \
238            WHERE id = $1 AND deadline > (now() AT TIME ZONE 'UTC')",
239        )
240        .bind(id.inner());
241        match query.execute(&self.0).await {
242            Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
243            Err(e) => Err(DeleteError::Other(e.into())),
244        }
245    }
246
247    /// Change the session id associated with an existing session record.
248    ///
249    /// The server-side state is left unchanged.
250    #[tracing::instrument(name = "Change id for server-side session record", level = tracing::Level::INFO, skip_all)]
251    async fn change_id(&self, old_id: &SessionId, new_id: &SessionId) -> Result<(), ChangeIdError> {
252        let query = sqlx::query(
253            "UPDATE sessions \
254            SET id = $1 \
255            WHERE id = $2 AND deadline > (now() AT TIME ZONE 'UTC')",
256        )
257        .bind(new_id.inner())
258        .bind(old_id.inner());
259        match query.execute(&self.0).await {
260            Ok(r) => as_unknown_id_error(&r, old_id).map_err(Into::into),
261            Err(e) => {
262                if let Err(e) = as_duplicated_id_error(&e, new_id) {
263                    Err(e.into())
264                } else {
265                    Err(ChangeIdError::Other(e.into()))
266                }
267            }
268        }
269    }
270
271    /// Delete expired sessions from the database.
272    ///
273    /// If `batch_size` is provided, the query will delete at most `batch_size` expired sessions.
274    /// In either case, if successful, the method returns the number of expired sessions that
275    /// have been deleted.
276    ///
277    /// # When should you delete in batches?
278    ///
279    /// If there are a lot of expired sessions in the database, deleting them all at once can
280    /// cause performance issues. By deleting in batches, you can limit the number of sessions
281    /// deleted in a single query, reducing the impact.
282    ///
283    /// # Example
284    ///
285    /// Delete expired sessions in batches of 1000:
286    ///
287    /// ```no_run
288    /// use pavex_session::SessionStore;
289    /// use pavex_session_sqlx::PostgresSessionStore;
290    /// use pavex_tracing::fields::{
291    ///     error_details,
292    ///     error_message,
293    ///     ERROR_DETAILS,
294    ///     ERROR_MESSAGE
295    /// };
296    /// use std::time::Duration;
297    ///
298    /// # async fn delete_expired_sessions(pool: sqlx::PgPool) {
299    /// let backend = PostgresSessionStore::new(pool);
300    /// let store = SessionStore::new(backend);
301    /// let batch_size = Some(1000.try_into().unwrap());
302    /// let batch_sleep = Duration::from_secs(60);
303    /// loop {
304    ///     if let Err(e) = store.delete_expired(batch_size).await {
305    ///         tracing::event!(
306    ///             tracing::Level::ERROR,
307    ///             { ERROR_MESSAGE } = error_message(&e),
308    ///             { ERROR_DETAILS } = error_details(&e),
309    ///             "Failed to delete a batch of expired sessions",
310    ///         );
311    ///     }
312    ///     tokio::time::sleep(batch_sleep).await;
313    /// }
314    /// # }
315    async fn delete_expired(
316        &self,
317        batch_size: Option<NonZeroUsize>,
318    ) -> Result<usize, DeleteExpiredError> {
319        let query = if let Some(batch_size) = batch_size {
320            let batch_size: i64 = batch_size.get().try_into().unwrap_or(i64::MAX);
321            sqlx::query("DELETE FROM sessions WHERE deadline < (now() AT TIME ZONE 'UTC') LIMIT $1")
322                .bind(batch_size)
323        } else {
324            sqlx::query("DELETE FROM sessions WHERE deadline < (now() AT TIME ZONE 'UTC')")
325        };
326        let r = query.execute(&self.0).await.map_err(|e| {
327            let e: anyhow::Error = e.into();
328            e
329        })?;
330        Ok(r.rows_affected().try_into().unwrap_or(usize::MAX))
331    }
332}
333
334fn as_duplicated_id_error(e: &sqlx::Error, id: &SessionId) -> Result<(), DuplicateIdError> {
335    if let Some(e) = e.as_database_error() {
336        if let Some(e) = e.try_downcast_ref::<PgDatabaseError>() {
337            // Check if the error is due to a duplicate ID
338            // See https://www.postgresql.org/docs/current/errcodes-appendix.html
339            // for the list of error codes for Postgres
340            if e.code() == "23505" && e.column() == Some("id") {
341                return Err(DuplicateIdError { id: id.to_owned() });
342            }
343        }
344    }
345    Ok(())
346}
347
348fn as_unknown_id_error(r: &PgQueryResult, id: &SessionId) -> Result<(), UnknownIdError> {
349    // Check if the session record was changed
350    if r.rows_affected() == 0 {
351        return Err(UnknownIdError { id: id.to_owned() });
352    }
353    // Sanity check
354    assert_eq!(
355        r.rows_affected(),
356        1,
357        "More than one session record was affected, even though the session ID is used as primary key. Something is deeply wrong here!"
358    );
359    Ok(())
360}