1use pavex::methods;
4use pavex::time::Timestamp;
5use pavex_session::SessionStore;
6use pavex_session::{
7 SessionId,
8 store::{
9 SessionRecord, SessionRecordRef, SessionStorageBackend,
10 errors::{
11 ChangeIdError, CreateError, DeleteError, DeleteExpiredError, DuplicateIdError,
12 LoadError, UnknownIdError, UpdateError, UpdateTtlError,
13 },
14 },
15};
16use sqlx::{
17 SqlitePool,
18 error::DatabaseError,
19 sqlite::{SqliteError, SqliteQueryResult},
20};
21use std::num::NonZeroUsize;
22
/// A server-side session storage backend that persists session records
/// in a SQLite database, accessed through a [`sqlx::SqlitePool`].
///
/// Call [`SqliteSessionStore::migrate`] (or run
/// [`SqliteSessionStore::migration_query`] yourself) to create the
/// `sessions` table before using the store.
#[derive(Debug, Clone)]
pub struct SqliteSessionStore(sqlx::SqlitePool);
43
44#[methods]
45impl From<SqliteSessionStore> for SessionStore {
46 #[singleton]
47 fn from(value: SqliteSessionStore) -> Self {
48 SessionStore::new(value)
49 }
50}
51
#[methods]
impl SqliteSessionStore {
    /// Create a new [`SqliteSessionStore`] backed by the given SQLite
    /// connection pool.
    #[singleton]
    pub fn new(pool: SqlitePool) -> Self {
        Self(pool)
    }

    /// The SQL statements required to set up the `sessions` table and its
    /// supporting index.
    ///
    /// Both statements are idempotent (`IF NOT EXISTS`), so the query can be
    /// executed repeatedly. Use this if you want to fold the setup into your
    /// own migration suite instead of calling [`Self::migrate`].
    pub fn migration_query() -> &'static str {
        "-- Create the sessions table if it doesn't exist
CREATE TABLE IF NOT EXISTS sessions (
    id TEXT PRIMARY KEY,
    deadline INTEGER NOT NULL,
    state JSONB NOT NULL
);

-- Create the index on the deadline column if it doesn't exist
CREATE INDEX IF NOT EXISTS idx_sessions_deadline ON sessions(deadline);"
    }

    /// Execute [`Self::migration_query`] against the connection pool.
    ///
    /// Safe to call more than once, since the underlying statements are
    /// idempotent.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while executing the statements.
    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
        use sqlx::Executor as _;

        self.0.execute(Self::migration_query()).await?;
        Ok(())
    }
}
102
103#[async_trait::async_trait]
104impl SessionStorageBackend for SqliteSessionStore {
105 #[tracing::instrument(name = "Create server-side session record", level = tracing::Level::INFO, skip_all)]
107 async fn create(
108 &self,
109 id: &SessionId,
110 record: SessionRecordRef<'_>,
111 ) -> Result<(), CreateError> {
112 let deadline = Timestamp::now() + record.ttl;
113 let deadline_unix = deadline.as_second();
114 let state = serde_json::to_value(record.state)?;
115 let query = sqlx::query(
116 "INSERT INTO sessions (id, deadline, state) \
117 VALUES (?, ?, ?) \
118 ON CONFLICT(id) DO UPDATE \
119 SET deadline = excluded.deadline, state = excluded.state \
120 WHERE sessions.deadline < unixepoch()",
121 )
122 .bind(id.inner().to_string())
123 .bind(deadline_unix)
124 .bind(state);
125
126 match query.execute(&self.0).await {
127 Ok(_) => Ok(()),
129 Err(e) => {
130 if let Err(e) = as_duplicated_id_error(&e, id) {
132 Err(e.into())
133 } else {
134 Err(CreateError::Other(e.into()))
135 }
136 }
137 }
138 }
139
140 #[tracing::instrument(name = "Update server-side session record", level = tracing::Level::INFO, skip_all)]
144 async fn update(
145 &self,
146 id: &SessionId,
147 record: SessionRecordRef<'_>,
148 ) -> Result<(), UpdateError> {
149 let new_deadline = Timestamp::now() + record.ttl;
150 let new_deadline_unix = new_deadline.as_second();
151 let new_state = serde_json::to_value(record.state)?;
152 let query = sqlx::query(
153 "UPDATE sessions \
154 SET deadline = ?, state = ? \
155 WHERE id = ? AND deadline > unixepoch()",
156 )
157 .bind(new_deadline_unix)
158 .bind(new_state)
159 .bind(id.inner().to_string());
160
161 match query.execute(&self.0).await {
162 Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
163 Err(e) => Err(UpdateError::Other(e.into())),
164 }
165 }
166
167 #[tracing::instrument(name = "Update TTL for server-side session record", level = tracing::Level::INFO, skip_all)]
171 async fn update_ttl(
172 &self,
173 id: &SessionId,
174 ttl: std::time::Duration,
175 ) -> Result<(), UpdateTtlError> {
176 let new_deadline = Timestamp::now() + ttl;
177 let new_deadline_unix = new_deadline.as_second();
178 let query = sqlx::query(
179 "UPDATE sessions \
180 SET deadline = ? \
181 WHERE id = ? AND deadline > unixepoch()",
182 )
183 .bind(new_deadline_unix)
184 .bind(id.inner().to_string());
185 match query.execute(&self.0).await {
186 Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
187 Err(e) => Err(UpdateTtlError::Other(e.into())),
188 }
189 }
190
191 #[tracing::instrument(name = "Load server-side session record", level = tracing::Level::INFO, skip_all)]
197 async fn load(&self, session_id: &SessionId) -> Result<Option<SessionRecord>, LoadError> {
198 let row = sqlx::query(
199 "SELECT deadline, state \
200 FROM sessions \
201 WHERE id = ? AND deadline > unixepoch()",
202 )
203 .bind(session_id.inner().to_string())
204 .fetch_optional(&self.0)
205 .await
206 .map_err(|e| LoadError::Other(e.into()))?;
207 row.map(|r| {
208 use anyhow::Context as _;
209 use sqlx::Row as _;
210
211 let deadline_unix: i64 = r
212 .try_get(0)
213 .context("Failed to deserialize the retrieved session deadline")
214 .map_err(LoadError::DeserializationError)?;
215 let deadline = Timestamp::from_second(deadline_unix)
216 .context("Failed to parse the retrieved session deadline")
217 .map_err(LoadError::DeserializationError)?;
218 let state: serde_json::Value = r
219 .try_get(1)
220 .context("Failed to deserialize the retrieved session state")
221 .map_err(LoadError::DeserializationError)?;
222 let ttl = deadline - Timestamp::now();
223 Ok(SessionRecord {
224 ttl: ttl.try_into().unwrap_or(std::time::Duration::ZERO),
226 state: serde_json::from_value(state)
227 .context("Failed to deserialize the retrieved session state")
228 .map_err(LoadError::DeserializationError)?,
229 })
230 })
231 .transpose()
232 }
233
234 #[tracing::instrument(name = "Delete server-side session record", level = tracing::Level::INFO, skip_all)]
238 async fn delete(&self, id: &SessionId) -> Result<(), DeleteError> {
239 let query = sqlx::query(
240 "DELETE FROM sessions \
241 WHERE id = ? AND deadline > unixepoch()",
242 )
243 .bind(id.inner().to_string());
244 match query.execute(&self.0).await {
245 Ok(r) => as_unknown_id_error(&r, id).map_err(Into::into),
246 Err(e) => Err(DeleteError::Other(e.into())),
247 }
248 }
249
250 #[tracing::instrument(name = "Change id for server-side session record", level = tracing::Level::INFO, skip_all)]
254 async fn change_id(&self, old_id: &SessionId, new_id: &SessionId) -> Result<(), ChangeIdError> {
255 let query = sqlx::query(
256 "UPDATE sessions \
257 SET id = ? \
258 WHERE id = ? AND deadline > unixepoch()",
259 )
260 .bind(new_id.inner().to_string())
261 .bind(old_id.inner().to_string());
262 match query.execute(&self.0).await {
263 Ok(r) => as_unknown_id_error(&r, old_id).map_err(Into::into),
264 Err(e) => {
265 if let Err(e) = as_duplicated_id_error(&e, new_id) {
266 Err(e.into())
267 } else {
268 Err(ChangeIdError::Other(e.into()))
269 }
270 }
271 }
272 }
273
274 async fn delete_expired(
319 &self,
320 batch_size: Option<NonZeroUsize>,
321 ) -> Result<usize, DeleteExpiredError> {
322 let query = if let Some(batch_size) = batch_size {
323 let batch_size: i64 = batch_size.get().try_into().unwrap_or(i64::MAX);
324 sqlx::query("DELETE FROM sessions WHERE id IN (SELECT id FROM sessions WHERE deadline < unixepoch() LIMIT ?)")
325 .bind(batch_size)
326 } else {
327 sqlx::query("DELETE FROM sessions WHERE deadline < unixepoch()")
328 };
329 let r = query.execute(&self.0).await.map_err(|e| {
330 let e: anyhow::Error = e.into();
331 e
332 })?;
333 Ok(r.rows_affected().try_into().unwrap_or(usize::MAX))
334 }
335}
336
337fn as_duplicated_id_error(e: &sqlx::Error, id: &SessionId) -> Result<(), DuplicateIdError> {
338 if let Some(e) = e.as_database_error() {
339 if let Some(e) = e.try_downcast_ref::<SqliteError>() {
340 if e.code() == Some("1555".into()) {
343 return Err(DuplicateIdError { id: id.to_owned() });
344 }
345 }
346 }
347 Ok(())
348}
349
350fn as_unknown_id_error(r: &SqliteQueryResult, id: &SessionId) -> Result<(), UnknownIdError> {
351 if r.rows_affected() == 0 {
353 return Err(UnknownIdError { id: id.to_owned() });
354 }
355 assert_eq!(
357 r.rows_affected(),
358 1,
359 "More than one session record was affected, even though the session ID is used as primary key. Something is deeply wrong here!"
360 );
361 Ok(())
362}