1use jiff_sqlx::ToSqlx;
3use pavex::methods;
4use pavex::time::Timestamp;
5use pavex_session::SessionStore;
6use pavex_session::{
7 SessionId,
8 store::{
9 SessionRecord, SessionRecordRef, SessionStorageBackend,
10 errors::{
11 ChangeIdError, CreateError, DeleteError, DeleteExpiredError, DuplicateIdError,
12 LoadError, UnknownIdError, UpdateError, UpdateTtlError,
13 },
14 },
15};
16use sqlx::{
17 PgPool,
18 postgres::{PgDatabaseError, PgQueryResult},
19};
20use std::num::NonZeroUsize;
21
22#[derive(Debug, Clone)]
23pub struct PostgresSessionStore(sqlx::PgPool);
35
36#[methods]
37impl From<PostgresSessionStore> for SessionStore {
38 #[singleton]
39 fn from(value: PostgresSessionStore) -> Self {
40 SessionStore::new(value)
41 }
42}
43
44#[methods]
45impl PostgresSessionStore {
46 #[singleton]
51 pub fn new(pool: PgPool) -> Self {
52 Self(pool)
53 }
54
55 pub fn migration_query() -> &'static str {
69 "-- Create the sessions table if it doesn’t exist
70CREATE TABLE IF NOT EXISTS sessions (
71 id UUID PRIMARY KEY,
72 deadline TIMESTAMPTZ NOT NULL,
73 state JSONB NOT NULL
74);
75
76-- Create the index on the deadline column if it doesn’t exist
77DO $$
78BEGIN
79 IF NOT EXISTS (
80 SELECT 1 FROM pg_indexes
81 WHERE schemaname = current_schema()
82 AND tablename = 'sessions'
83 AND indexname = 'idx_sessions_deadline'
84 ) THEN
85 CREATE INDEX idx_sessions_deadline ON sessions(deadline);
86 END IF;
87END $$;"
88 }
89
90 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
98 use sqlx::Executor as _;
99
100 self.0.execute(Self::migration_query()).await?;
101 Ok(())
102 }
103}
104
105#[async_trait::async_trait]
106impl SessionStorageBackend for PostgresSessionStore {
107 #[tracing::instrument(name = "Create server-side session record", level = tracing::Level::INFO, skip_all)]
109 async fn create(
110 &self,
111 id: &SessionId,
112 record: SessionRecordRef<'_>,
113 ) -> Result<(), CreateError> {
114 let deadline = Timestamp::now() + record.ttl;
115 let state = serde_json::to_value(record.state)?;
116 let query = sqlx::query(
117 "INSERT INTO sessions (id, deadline, state) \
118 VALUES ($1, $2, $3) \
119 ON CONFLICT (id) DO UPDATE \
120 SET deadline = EXCLUDED.deadline, state = EXCLUDED.state \
121 WHERE sessions.deadline < (now() AT TIME ZONE 'UTC')",
122 )
123 .bind(id.inner())
124 .bind(deadline.to_sqlx())
125 .bind(state);
126
127 match query.execute(&self.0).await {
128 Ok(_) => Ok(()),
130 Err(e) => {
131 if let Err(e) = as_duplicated_id_error(&e, id) {
133 Err(e.into())
134 } else {
135 Err(CreateError::Other(e.into()))
136 }
137 }
138 }
139 }
140
141 #[tracing::instrument(name = "Update server-side session record", level = tracing::Level::INFO, skip_all)]
145 async fn update(
146 &self,
147 id: &SessionId,
148 record: SessionRecordRef<'_>,
149 ) -> Result<(), UpdateError> {
150 let new_deadline = Timestamp::now() + record.ttl;
151 let new_state = serde_json::to_value(record.state)?;
152 let query = sqlx::query(
153 "UPDATE sessions \
154 SET deadline = $1, state = $2 \
155 WHERE id = $3 AND deadline > (now() AT TIME ZONE 'UTC')",
156 )
157 .bind(new_deadline.to_sqlx())
158 .bind(new_state)
159 .bind(id.inner());
160
161 match query.execute(&self.0).await {
162 Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
163 Err(e) => Err(UpdateError::Other(e.into())),
164 }
165 }
166
167 #[tracing::instrument(name = "Update TTL for server-side session record", level = tracing::Level::INFO, skip_all)]
171 async fn update_ttl(
172 &self,
173 id: &SessionId,
174 ttl: std::time::Duration,
175 ) -> Result<(), UpdateTtlError> {
176 let new_deadline = Timestamp::now() + ttl;
177 let query = sqlx::query(
178 "UPDATE sessions \
179 SET deadline = $1 \
180 WHERE id = $2 AND deadline > (now() AT TIME ZONE 'UTC')",
181 )
182 .bind(new_deadline.to_sqlx())
183 .bind(id.inner());
184 match query.execute(&self.0).await {
185 Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
186 Err(e) => Err(UpdateTtlError::Other(e.into())),
187 }
188 }
189
190 #[tracing::instrument(name = "Load server-side session record", level = tracing::Level::INFO, skip_all)]
196 async fn load(&self, session_id: &SessionId) -> Result<Option<SessionRecord>, LoadError> {
197 let row = sqlx::query(
198 "SELECT deadline, state \
199 FROM sessions \
200 WHERE id = $1 AND deadline > (now() AT TIME ZONE 'UTC')",
201 )
202 .bind(session_id.inner())
203 .fetch_optional(&self.0)
204 .await
205 .map_err(|e| LoadError::Other(e.into()))?;
206 row.map(|r| {
207 use anyhow::Context as _;
208 use sqlx::Row as _;
209
210 let deadline = r
211 .try_get::<jiff_sqlx::Timestamp, _>(0)
212 .context("Failed to deserialize the retrieved session deadline")
213 .map_err(LoadError::DeserializationError)?
214 .to_jiff();
215 let state: serde_json::Value = r
216 .try_get(1)
217 .context("Failed to deserialize the retrieved session state")
218 .map_err(LoadError::DeserializationError)?;
219 let ttl = deadline - Timestamp::now();
220 Ok(SessionRecord {
221 ttl: ttl.try_into().unwrap_or(std::time::Duration::ZERO),
223 state: serde_json::from_value(state)
224 .context("Failed to deserialize the retrieved session state")
225 .map_err(LoadError::DeserializationError)?,
226 })
227 })
228 .transpose()
229 }
230
231 #[tracing::instrument(name = "Delete server-side session record", level = tracing::Level::INFO, skip_all)]
235 async fn delete(&self, id: &SessionId) -> Result<(), DeleteError> {
236 let query = sqlx::query(
237 "DELETE FROM sessions \
238 WHERE id = $1 AND deadline > (now() AT TIME ZONE 'UTC')",
239 )
240 .bind(id.inner());
241 match query.execute(&self.0).await {
242 Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
243 Err(e) => Err(DeleteError::Other(e.into())),
244 }
245 }
246
247 #[tracing::instrument(name = "Change id for server-side session record", level = tracing::Level::INFO, skip_all)]
251 async fn change_id(&self, old_id: &SessionId, new_id: &SessionId) -> Result<(), ChangeIdError> {
252 let query = sqlx::query(
253 "UPDATE sessions \
254 SET id = $1 \
255 WHERE id = $2 AND deadline > (now() AT TIME ZONE 'UTC')",
256 )
257 .bind(new_id.inner())
258 .bind(old_id.inner());
259 match query.execute(&self.0).await {
260 Ok(r) => as_unknown_id_error(&r, old_id).map_err(Into::into),
261 Err(e) => {
262 if let Err(e) = as_duplicated_id_error(&e, new_id) {
263 Err(e.into())
264 } else {
265 Err(ChangeIdError::Other(e.into()))
266 }
267 }
268 }
269 }
270
271 async fn delete_expired(
316 &self,
317 batch_size: Option<NonZeroUsize>,
318 ) -> Result<usize, DeleteExpiredError> {
319 let query = if let Some(batch_size) = batch_size {
320 let batch_size: i64 = batch_size.get().try_into().unwrap_or(i64::MAX);
321 sqlx::query("DELETE FROM sessions WHERE deadline < (now() AT TIME ZONE 'UTC') LIMIT $1")
322 .bind(batch_size)
323 } else {
324 sqlx::query("DELETE FROM sessions WHERE deadline < (now() AT TIME ZONE 'UTC')")
325 };
326 let r = query.execute(&self.0).await.map_err(|e| {
327 let e: anyhow::Error = e.into();
328 e
329 })?;
330 Ok(r.rows_affected().try_into().unwrap_or(usize::MAX))
331 }
332}
333
334fn as_duplicated_id_error(e: &sqlx::Error, id: &SessionId) -> Result<(), DuplicateIdError> {
335 if let Some(e) = e.as_database_error() {
336 if let Some(e) = e.try_downcast_ref::<PgDatabaseError>() {
337 if e.code() == "23505" && e.column() == Some("id") {
341 return Err(DuplicateIdError { id: id.to_owned() });
342 }
343 }
344 }
345 Ok(())
346}
347
348fn as_unknown_id_error(r: &PgQueryResult, id: &SessionId) -> Result<(), UnknownIdError> {
349 if r.rows_affected() == 0 {
351 return Err(UnknownIdError { id: id.to_owned() });
352 }
353 assert_eq!(
355 r.rows_affected(),
356 1,
357 "More than one session record was affected, even though the session ID is used as primary key. Something is deeply wrong here!"
358 );
359 Ok(())
360}