pavex_session/
session_.rs

1use errors::{
2    FinalizeError, ServerGetError, ServerInsertError, ServerRemoveError, SyncError,
3    ValueDeserializationError, ValueLocation, ValueSerializationError,
4};
5use pavex::cookie::{RemovalCookie, ResponseCookie};
6use pavex::methods;
7use pavex::time::SignedDuration;
8use serde::Serialize;
9use serde::de::DeserializeOwned;
10use serde_json::Value;
11use std::borrow::Cow;
12use std::cell::OnceCell;
13use std::collections::HashMap;
14use std::marker::PhantomData;
15use std::sync::MutexGuard;
16
17use crate::SessionConfig;
18use crate::SessionId;
19use crate::SessionStore;
20use crate::State;
21use crate::config::{
22    MissingServerState, ServerStateCreation, SessionCookieKind, TtlExtensionTrigger,
23};
24use crate::incoming::IncomingSession;
25use crate::store::SessionRecordRef;
26use crate::store::errors::{ChangeIdError, DeleteError, LoadError};
27use crate::wire::WireClientState;
28
/// The current HTTP session.
///
/// # Implementation notes
///
/// ## Not `Clone`
///
/// The session is a stateful object that holds the client-side and server-side state
/// of the session, tracking all changes to both states. As a result, `Session` does
/// not implement the `Clone` trait.
///
/// ## Not `Send` nor `Sync`
///
/// The session object is designed to be used within the lifetime of the request
/// it refers to.
/// When Pavex receives a new request, it assigns it to a specific worker thread,
/// where all the processing for that request takes place.
///
/// Given the above, we optimized `Session`'s internals for single-thread usage
/// and decided not to implement `Send` and `Sync` for it.
/// `Sync` is ruled out by the `OnceCell`-based fields (`server_state` and
/// `invalidated`), while `Send` is ruled out by the `_unsend` marker field.
pub struct Session<'store> {
    /// The session ID, together with any pending change to it
    /// (e.g. a rename requested via `cycle_id`).
    id: CurrentSessionId,
    /// The server state is loaded lazily, hence the `OnceCell` wrapper.
    ///
    /// An empty cell means that the server-side state hasn't been loaded
    /// from the store yet.
    server_state: OnceCell<ServerState>,
    /// The client-side state, tagged according to whether it has been modified.
    client_state: ClientState,
    /// # Internal invariant
    ///
    /// If the session has been invalidated, `server_state` MUST
    /// contain `ServerState::MarkedForDeletion`.
    invalidated: InvalidationFlag,
    /// The backend where the server-side state is persisted.
    store: &'store SessionStore,
    /// The session configuration (cookie settings, server-side state TTL, etc.).
    config: &'store SessionConfig,
    /// This field is used to prevent `Send` being implemented for `Session`.
    _unsend: PhantomUnsend,
}
63
64impl std::fmt::Debug for Session<'_> {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("Session")
67            .field("id", &"**redacted**")
68            .field("server_state", &self.server_state)
69            .field("client_state", &self.client_state)
70            .field("invalidated", &self.invalidated)
71            .field("store", &self.store)
72            .field("config", &self.config)
73            .finish()
74    }
75}
76
/// A one-way boolean flag backed by a `OnceCell<()>`: it starts unset and,
/// once set, it can never be unset again.
#[derive(Clone)]
struct InvalidationFlag(OnceCell<()>);

impl std::fmt::Debug for InvalidationFlag {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let is_invalidated = self.is_invalidated();
        f.debug_struct("InvalidationFlag")
            .field("is_invalidated", &is_invalidated)
            .finish()
    }
}

impl InvalidationFlag {
    /// Build a flag in the "not invalidated" state.
    fn new() -> Self {
        InvalidationFlag(OnceCell::new())
    }

    /// Raise the flag.
    ///
    /// Raising a flag that has already been raised is a no-op.
    fn invalidate(&self) {
        // Fills the cell if it is empty, leaves it untouched otherwise.
        self.0.get_or_init(|| ());
    }

    /// Returns `true` if [`invalidate`](Self::invalidate) has been called at least once.
    fn is_invalidated(&self) -> bool {
        matches!(self.0.get(), Some(()))
    }
}
105
/// A zero-sized marker that is not `Send`: `MutexGuard` does not implement `Send`,
/// and neither does any type that embeds it (here, via `PhantomData`).
///
/// See <https://stackoverflow.com/questions/62713667/how-to-implement-send-or-sync-for-a-type>
type PhantomUnsend = PhantomData<MutexGuard<'static, ()>>;
108
/// The ID of the current session, including any pending ID change that
/// still has to be propagated to the session store.
#[derive(Clone, PartialEq, Eq)]
enum CurrentSessionId {
    /// An ID with no pending change—e.g. the one extracted from the incoming request.
    Existing(SessionId),
    /// The ID has been cycled (see `Session::cycle_id`): when syncing with the store,
    /// the record stored under `old` must be moved under `new`.
    ///
    /// # Internal invariant
    ///
    /// `old` is always different from `new`.
    ToBeRenamed {
        old: SessionId,
        new: SessionId,
    },
    /// A random ID, generated because the incoming request didn't carry a session.
    NewlyGenerated(SessionId),
}
121
122impl CurrentSessionId {
123    fn new_id(&self) -> SessionId {
124        match self {
125            Self::Existing(id) => *id,
126            Self::ToBeRenamed { new, .. } => *new,
127            Self::NewlyGenerated(id) => *id,
128        }
129    }
130
131    fn old_id(&self) -> Option<SessionId> {
132        match self {
133            Self::Existing(id) => Some(*id),
134            Self::ToBeRenamed { old, .. } => Some(*old),
135            Self::NewlyGenerated(..) => None,
136        }
137    }
138}
139
/// The client-side state of the session, tagged according to whether it has
/// been modified while processing the current request.
#[derive(Debug, Clone)]
enum ClientState {
    /// The state as it was received from the client (empty for a brand-new session).
    Unchanged { state: State },
    /// The state has been modified. NOTE(review): presumably set by the mutating
    /// methods of `ClientSessionStateMut`, which are not fully visible here.
    Updated { state: State },
}
145
/// The in-memory view of the server-side state of the session.
///
/// `Session` stores it inside a `OnceCell`: an empty cell means that the state
/// hasn't been loaded from the store yet.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ServerState {
    /// The state hasn't been modified since it was loaded from, or last
    /// written to, the store.
    Unchanged {
        state: State,
        /// The remaining time-to-live of the stored record, compared against the
        /// configured TTL to decide whether to extend it on sync.
        ttl: std::time::Duration,
    },
    /// There is no record for this session in the store.
    DoesNotExist,
    /// The record must be deleted from the store on the next sync.
    MarkedForDeletion,
    /// The state has been modified in memory and must be written to the store
    /// on the next sync.
    Changed {
        state: State,
    },
}
158
159#[methods]
160impl<'store> Session<'store> {
161    /// Create a new HTTP session.
162    ///
163    /// It is a continuation of the existing session if there was a valid session cookie
164    /// attached to the request.
165    /// It is a brand-new session otherwise.
166    #[request_scoped]
167    pub fn new(
168        store: &'store SessionStore,
169        config: &'store SessionConfig,
170        incoming_session: Option<IncomingSession>,
171    ) -> Self {
172        let (client_state, previous_session_id) = match incoming_session {
173            Some(s) => (s.client_state, Some(s.id)),
174            None => (Default::default(), None),
175        };
176        let (id, server_state) = match previous_session_id {
177            Some(id) => (CurrentSessionId::Existing(id), None),
178            None => (
179                CurrentSessionId::NewlyGenerated(SessionId::random()),
180                Some(ServerState::DoesNotExist),
181            ),
182        };
183        Self {
184            id,
185            server_state: new_cell_with(server_state),
186            client_state: ClientState::Unchanged {
187                state: client_state,
188            },
189            invalidated: InvalidationFlag::new(),
190            store,
191            config,
192            _unsend: Default::default(),
193        }
194    }
195}
196
197/// All the operations you can perform on the server-side state of your session.
198impl Session<'_> {
199    /// Get the value associated with `key` from the server-side state.
200    ///
201    /// If the value is not found, `None` is returned.\
202    /// If the value is found, but it cannot be deserialized into the expected type,
203    /// an error is returned.
204    ///
205    /// If you don't need to deserialize the value, or you'd like to handle the deserialization
206    /// yourself, use [`get_raw`][Self::get_raw] instead.
207    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, ServerGetError> {
208        self.get_raw(key)
209            .await?
210            .map(|value| serde_json::from_value(value.clone()))
211            .transpose()
212            .map_err(|e| {
213                ValueDeserializationError {
214                    key: key.to_string().into(),
215                    location: ValueLocation::Server,
216                    source: e,
217                }
218                .into()
219            })
220    }
221
222    /// Insert a value for the given key in the server-side state.
223    ///
224    /// If the state didn't have an entry for this key, the value is inserted and `None` is returned.\
225    /// If the state did have an entry for this key, its value is updated and the old
226    /// value is returned in its raw JSON form.
227    ///
228    /// The provided value is serialized to JSON prior to being stored. If
229    /// the serialization fails, an error is returned. If you'd prefer to
230    /// take care of the serialization yourself, use [`insert_raw`][Self::insert_raw] instead.
231    pub async fn insert<T, Key>(
232        &mut self,
233        key: Key,
234        value: T,
235    ) -> Result<Option<Value>, ServerInsertError>
236    where
237        T: Serialize,
238        Key: Into<Cow<'static, str>>,
239    {
240        let key = key.into();
241        let value = match serde_json::to_value(value) {
242            Ok(t) => t,
243            Err(source) => {
244                return Err(ValueSerializationError {
245                    key,
246                    location: ValueLocation::Server,
247                    source,
248                }
249                .into());
250            }
251        };
252        self.insert_raw(key, value).await.map_err(Into::into)
253    }
254
255    /// Remove the value associated with `key` from the server-side state.
256    ///
257    /// If the key doesn't exist, `None` is returned.
258    ///
259    /// If the key exists, the removed value is returned, deserialized into the type you specify as `T`.
260    /// If the removed value cannot be deserialized, an error is returned.
261    ///
262    /// If you're not interested in the removed value, or you don't want to deserialize it,
263    /// use [`remove_raw`][Self::remove_raw] instead.
264    pub async fn remove<T: DeserializeOwned>(
265        &mut self,
266        key: &str,
267    ) -> Result<Option<T>, ServerRemoveError> {
268        self.remove_raw(key)
269            .await?
270            .map(serde_json::from_value)
271            .transpose()
272            .map_err(|source| ValueDeserializationError {
273                key: key.to_string().into(),
274                location: ValueLocation::Server,
275                source,
276            })
277            .map_err(ServerRemoveError::DeserializationError)
278    }
279
280    /// Returns `true` if there are no values in the server-side state.
281    pub async fn is_empty(&self) -> Result<bool, LoadError> {
282        use ServerState::*;
283
284        match force_load_ref(self).await? {
285            Unchanged { state, .. } | Changed { state } => Ok(state.is_empty()),
286            DoesNotExist | MarkedForDeletion => Ok(true),
287        }
288    }
289
290    /// Get the value associated with `key` from the server-side state.
291    pub async fn get_raw<'a>(&'a self, key: &str) -> Result<Option<&'a Value>, LoadError> {
292        use ServerState::*;
293
294        match force_load_ref(self).await? {
295            Unchanged { state, .. } | Changed { state } => Ok(state.get(key)),
296            DoesNotExist => Ok(None),
297            MarkedForDeletion => {
298                tracing::debug!(session.key = %key, "Tried to access a server-side value on a session marked for deletion.");
299                Ok(None)
300            }
301        }
302    }
303
304    /// Insert a value for the given key in the server-side state.
305    ///
306    /// If the state didn't have an entry for this key, the value is inserted and `None` is returned.\
307    /// If the state did have an entry for this key, its value is updated and the old
308    /// value is returned in its raw JSON form.
309    ///
310    /// The provided value must be a JSON value, which will be stored as-is, without any
311    /// further manipulation. If you'd prefer to let `pavex_session` handle the serialization,
312    /// use [`insert`][Self::insert] instead.
313    pub async fn insert_raw<Key>(
314        &mut self,
315        key: Key,
316        value: Value,
317    ) -> Result<Option<Value>, LoadError>
318    where
319        Key: Into<Cow<'static, str>>,
320    {
321        let mut existing_state;
322        let key = key.into();
323
324        use ServerState::*;
325        match force_load_mut(self).await? {
326            MarkedForDeletion => {
327                tracing::debug!(session.key = %key, "Tried to insert a server-side value on a session marked for deletion.");
328                return Ok(None);
329            }
330            Unchanged { state, .. } | Changed { state } => {
331                existing_state = std::mem::take(state);
332            }
333            DoesNotExist => {
334                existing_state = HashMap::new();
335            }
336        };
337        let old_value = existing_state.insert(key, value);
338        self.server_state = new_cell_with(Some(ServerState::Changed {
339            state: existing_state,
340        }));
341        Ok(old_value)
342    }
343
344    /// Remove the value associated with `key` from the server-side state.
345    ///
346    /// If the key exists, the removed value is returned.\
347    /// The value is returned as it was stored in the server-side state, without any deserialization.
348    /// If you want to deserialize the value as a specific type, use [`remove`][Self::remove] instead.
349    pub async fn remove_raw(&mut self, key: &str) -> Result<Option<Value>, LoadError> {
350        use ServerState::*;
351        match force_load_mut(self).await? {
352            MarkedForDeletion => {
353                tracing::debug!(session.key = %key, "Tried to delete a server-side value on a session marked for deletion.");
354                Ok(None)
355            }
356            DoesNotExist => Ok(None),
357            Unchanged { state, .. } | Changed { state } => Ok(state.remove(key)),
358        }
359    }
360
361    /// Delete the session record from the store.
362    ///
363    /// This doesn't destroy the whole session—you must invoke [`Session::invalidate`]
364    /// if that's your goal.
365    pub fn delete(&mut self) {
366        self.server_state = new_cell_with(Some(ServerState::MarkedForDeletion));
367    }
368
369    /// Remove all key-value pairs from the server-side state.
370    ///
371    /// This doesn't delete the session record from the store—you must invoke
372    /// [`Session::delete`][Self::delete] if you want to delete the record altogether.
373    ///
374    /// This doesn't invalidate the session—you must invoke [`Session::invalidate`]
375    /// if you want to delete the session altogether.
376    pub async fn clear(&mut self) -> Result<(), LoadError> {
377        use ServerState::*;
378        match force_load_mut(self).await? {
379            MarkedForDeletion | DoesNotExist => {}
380            Unchanged { state, .. } => {
381                if !state.is_empty() {
382                    self.server_state = new_cell_with(Some(ServerState::Changed {
383                        state: HashMap::new(),
384                    }));
385                }
386            }
387            Changed { state } => {
388                state.clear();
389            }
390        }
391        Ok(())
392    }
393
394    /// Generate a new session identifier and attach it to this session.
395    /// The session state is preserved on both the client-side and the server-side.
396    ///
397    /// This method is useful for security reasons, as it can help prevent
398    /// [session fixation attacks](https://owasp.org/www-community/attacks/Session_fixation).
399    pub fn cycle_id(&mut self) {
400        let old = match &self.id {
401            CurrentSessionId::Existing(id) => Some(*id),
402            CurrentSessionId::ToBeRenamed { old, .. } => Some(*old),
403            CurrentSessionId::NewlyGenerated(_) => None,
404        };
405
406        static MAX_N_ATTEMPTS: usize = 16;
407
408        let mut i = 0;
409        let new = loop {
410            if i >= MAX_N_ATTEMPTS {
411                panic!(
412                    "Failed to generate a new session ID that doesn't collide with the pre-existing one, \
413                    even though {MAX_N_ATTEMPTS} attempts were carried out. Something seems to be seriously wrong \
414                    with the underlying source of randomness."
415                )
416            }
417
418            let new = SessionId::random();
419            if Some(new) != old {
420                break new;
421            } else {
422                i += 1;
423            }
424        };
425
426        self.id = match old {
427            Some(old) => CurrentSessionId::ToBeRenamed { old, new },
428            None => CurrentSessionId::NewlyGenerated(new),
429        };
430    }
431
432    /// Invalidate the session.
433    ///
434    /// The server-side session state will be marked for deletion.
435    /// The client-side cookie will be removed from the client using a removal cookie.
436    ///
437    /// After calling this method, the session is considered invalid and should not be used anymore.
438    /// All further operations on the session will be no-ops.
439    pub fn invalidate(&mut self) {
440        self.invalidated.invalidate();
441        self.server_state = new_cell_with(Some(ServerState::MarkedForDeletion));
442    }
443
444    /// Check if the session has been invalidated.
445    ///
446    /// See [`Session::invalidate`] for more information.
447    pub fn is_invalidated(&self) -> bool {
448        self.invalidated.is_invalidated()
449    }
450}
451
452/// Control when the server-side state is synchronized with the store.
453impl Session<'_> {
454    /// Sync the in-memory representation of the server-side state
455    /// with the store.
456    ///
457    /// In most cases, you don't need to invoke this method manually: it is
458    /// done for you by [`finalize_session`][`super::finalize_session`],
459    /// the post-processing middleware that attaches the session cookie to
460    /// the response returned to the client.
461    pub async fn sync(&mut self) -> Result<(), SyncError> {
462        let state_config = &self.config.state;
463        let fresh_ttl = state_config.ttl;
464        let create_if_empty = {
465            let has_client_side = self.id.old_id().is_some()
466                || matches!(self.client_state, ClientState::Updated { .. });
467            has_client_side && state_config.server_state_creation == ServerStateCreation::NeverSkip
468        };
469        use ServerState::*;
470        match self.server_state.get() {
471            Some(DoesNotExist) => match self.id {
472                CurrentSessionId::NewlyGenerated(id) | CurrentSessionId::Existing(id) => {
473                    if create_if_empty {
474                        self.store
475                            .create(&id, SessionRecordRef::empty(fresh_ttl))
476                            .await?;
477                    }
478                }
479                CurrentSessionId::ToBeRenamed { .. } => {
480                    // Nothing to do.
481                }
482            },
483            None => {
484                match self.id {
485                    CurrentSessionId::Existing(_) => {
486                        // Nothing to do.
487                    }
488                    CurrentSessionId::ToBeRenamed { old, new } => {
489                        self.store.change_id(&old, &new).await?;
490                    }
491                    CurrentSessionId::NewlyGenerated(..) => {
492                        unreachable!(
493                            "A newly generated session cannot have a 'NotLoaded' server state. It must be set to 'DoesNotExist'."
494                        )
495                    }
496                };
497            }
498            Some(Unchanged {
499                state,
500                ttl: remaining_ttl,
501            }) => {
502                match self.id {
503                    CurrentSessionId::Existing(old) => {
504                        if state_config.extend_ttl == TtlExtensionTrigger::OnStateLoadsAndChanges {
505                            let extend = state_config
506                                .ttl_extension_threshold
507                                .map(|ratio| *remaining_ttl < fresh_ttl.mul_f32(ratio.inner()))
508                                .unwrap_or(true);
509                            if extend {
510                                self.store.update_ttl(&old, fresh_ttl).await?;
511                            }
512                        }
513                    }
514                    CurrentSessionId::ToBeRenamed { old, new } => {
515                        match self.store.change_id(&old, &new).await {
516                            Ok(_) => {}
517                            Err(ChangeIdError::UnknownId(_)) => {
518                                // The old state is no longer in the store—e.g. it may have
519                                // expired while we were processing. Rare, but possible.
520                                // We know what the new state needs to be though, so we
521                                // can handle this edge case gracefully.
522                                let record = SessionRecordRef {
523                                    state: Cow::Borrowed(state),
524                                    ttl: fresh_ttl,
525                                };
526                                self.store.create(&new, record).await?;
527                            }
528                            Err(e) => {
529                                return Err(e.into());
530                            }
531                        }
532                    }
533                    CurrentSessionId::NewlyGenerated(new) => {
534                        if create_if_empty {
535                            self.store
536                                .create(&new, SessionRecordRef::empty(fresh_ttl))
537                                .await?;
538                        }
539
540                        // Integrity check.
541                        assert!(
542                            state.is_empty(),
543                            "Server state is not empty on a new session, \
544                                    but the state is marked as 'unchanged'. This is a bug in `pavex_session`"
545                        );
546                    }
547                };
548            }
549            Some(MarkedForDeletion) => match self.id.old_id() {
550                Some(id) => {
551                    if let Err(e) = self.store.delete(&id).await {
552                        match e {
553                            // We're good as long as we made sure that no server-side
554                            // state is stored against this id, we're good.
555                            DeleteError::UnknownId(_) => {}
556                            _ => return Err(e.into()),
557                        }
558                    }
559                }
560                None => {
561                    tracing::trace!(
562                        "The server session state was marked for deletion, but there was no session to delete. This is a no-op."
563                    )
564                }
565            },
566            Some(Changed { state }) => {
567                let record = SessionRecordRef {
568                    state: Cow::Borrowed(state),
569                    ttl: fresh_ttl,
570                };
571                match self.id {
572                    CurrentSessionId::Existing(id) => {
573                        self.store.update(&id, record).await?;
574                    }
575                    CurrentSessionId::ToBeRenamed { old, new } => {
576                        if let Err(e) = self.store.delete(&old).await {
577                            match e {
578                                DeleteError::UnknownId(_) => {
579                                    // The record may have expired between this
580                                    // delete operation and the first (successful)
581                                    // load we performed at the beginning of this
582                                    // request processing task.
583                                    // Since we already have the value in memory,
584                                    // this is not an issue.
585                                }
586                                _ => {
587                                    return Err(e.into());
588                                }
589                            }
590                        }
591                        self.store.create(&new, record).await?;
592                    }
593                    CurrentSessionId::NewlyGenerated(id) => {
594                        self.store.create(&id, record).await?;
595                    }
596                }
597            }
598        };
599
600        self.server_state = {
601            let old_state = self.server_state.take();
602            let new_state = old_state.map(|state| match state {
603                Changed { state } => Unchanged {
604                    state,
605                    ttl: fresh_ttl,
606                },
607                Unchanged { state, ttl } => Unchanged { state, ttl },
608                MarkedForDeletion => {
609                    if self.is_invalidated() {
610                        MarkedForDeletion
611                    } else {
612                        DoesNotExist
613                    }
614                }
615                DoesNotExist => {
616                    if create_if_empty {
617                        Unchanged {
618                            state: HashMap::new(),
619                            ttl: fresh_ttl,
620                        }
621                    } else {
622                        DoesNotExist
623                    }
624                }
625            });
626            new_cell_with(new_state)
627        };
628        Ok(())
629    }
630
631    /// Load the server-side state from the store.
632    /// This method does nothing if the server-side state has already been loaded.
633    ///
634    /// After calling this method, the server-side state will be loaded
635    /// and cached in memory, so that subsequent calls to [`get_raw`](#method.get_raw),
636    /// [`insert_raw`](#method.insert_raw), and [`remove_raw`](#method.remove_raw)
637    /// will operate on the in-memory state.
638    pub async fn force_load(&self) -> Result<(), LoadError> {
639        force_load(self).await
640    }
641
642    /// Sync the current server-side state with the chosen storage backend.
643    /// If necessary, it returns a cookie to be attached to the outgoing response
644    /// in order to sync the client-side state.
645    #[must_use = "The cookie returned by `finalize` must be attached to the outgoing HTTP response. \
646        Failing to do so will push the session into an invalid state."]
647    pub async fn finalize(&mut self) -> Result<Option<ResponseCookie<'static>>, FinalizeError> {
648        self.sync().await?;
649
650        let cookie_config = &self.config.cookie;
651        let cookie_name = &cookie_config.name;
652
653        if self.invalidated.is_invalidated() {
654            if self.id.old_id().is_none() {
655                // This is a new session, so there's nothing on the client-side
656                // to be removed.
657                return Ok(None);
658            }
659            let mut cookie = RemovalCookie::new(cookie_name.clone());
660            if let Some(domain) = cookie_config.domain.as_deref() {
661                cookie = cookie.set_domain(domain.to_owned());
662            }
663            if let Some(path) = cookie_config.path.as_deref() {
664                cookie = cookie.set_path(path.to_owned());
665            }
666            Ok(Some(cookie.into()))
667        } else {
668            match &self.client_state {
669                ClientState::Updated {
670                    state: client_state,
671                }
672                | ClientState::Unchanged {
673                    state: client_state,
674                } => {
675                    let server_record_exists = match &self.server_state.get() {
676                        None => None,
677                        Some(ServerState::Unchanged { .. }) => Some(true),
678                        Some(ServerState::DoesNotExist) => Some(false),
679                        Some(ServerState::MarkedForDeletion)
680                        | Some(ServerState::Changed { .. }) => {
681                            unreachable!("The server state has just been synchronized.")
682                        }
683                    };
684                    // The session is new, we don't have a server-side record, and the client state is empty.
685                    // We don't need to create a session cookie in this case.
686                    if client_state.is_empty()
687                        && self.id.old_id().is_none()
688                        && !server_record_exists.unwrap_or(true)
689                    {
690                        return Ok(None);
691                    }
692                    let value = WireClientState {
693                        session_id: self.id.new_id(),
694                        user_values: Cow::Borrowed(client_state),
695                    };
696                    let value = serde_json::to_string(&value)?;
697                    let mut cookie = ResponseCookie::new(cookie_name.clone(), value);
698                    if let Some(domain) = cookie_config.domain.as_deref() {
699                        cookie = cookie.set_domain(domain.to_owned());
700                    }
701                    if let Some(path) = cookie_config.path.as_deref() {
702                        cookie = cookie.set_path(path.to_owned());
703                    }
704                    if let Some(same_site) = cookie_config.same_site {
705                        cookie = cookie.set_same_site(same_site);
706                    }
707                    if cookie_config.secure {
708                        cookie = cookie.set_secure(true);
709                    }
710                    if cookie_config.http_only {
711                        cookie = cookie.set_http_only(true);
712                    }
713                    if cookie_config.kind == SessionCookieKind::Persistent {
714                        let max_age: SignedDuration = self
715                            .config
716                            .state
717                            .ttl
718                            .try_into()
719                            .unwrap_or(SignedDuration::MAX);
720                        cookie = cookie.set_max_age(max_age);
721                    }
722                    Ok(Some(cookie))
723                }
724            }
725        }
726    }
727}
728
729/// APIs for manipulating the client-side session state.
730impl Session<'_> {
731    /// Read values from the client-side state attached to this session.
732    pub fn client(&self) -> ClientSessionState<'_> {
733        ClientSessionState(&self.client_state, &self.invalidated)
734    }
735
736    /// Read or mutate the client-side state attached to this session.
737    pub fn client_mut(&mut self) -> ClientSessionStateMut<'_> {
738        ClientSessionStateMut(&mut self.client_state, &self.invalidated)
739    }
740}
741
742/// A read-only reference to the client-side state of a session.
743pub struct ClientSessionState<'session>(&'session ClientState, &'session InvalidationFlag);
744
745impl<'session> ClientSessionState<'session> {
746    /// Get the value associated with `key` from the client-side state.
747    ///
748    /// If the value is not found, `None` is returned.
749    /// If the value is found, but it cannot be deserialized into the expected type, an error is returned.
750    pub fn get<T: DeserializeOwned>(
751        &self,
752        key: &str,
753    ) -> Result<Option<T>, ValueDeserializationError> {
754        client_get(self.0, self.1, key)
755    }
756
757    /// Get the raw JSON value associated with `key` from the client-side state.
758    pub fn get_raw(&self, key: &str) -> Option<&'session Value> {
759        client_get_raw(self.0, self.1, key)
760    }
761
762    /// Returns true if there are no values in the client-side state.
763    pub fn is_empty(&self) -> bool {
764        client_is_empty(self.0, self.1)
765    }
766}
767
768/// A mutable reference to the client-side state of a session.
769pub struct ClientSessionStateMut<'session>(&'session mut ClientState, &'session InvalidationFlag);
770
771impl ClientSessionStateMut<'_> {
772    /// Get the value associated with `key` from the client-side state.
773    ///
774    /// If the value is not found, `None` is returned.
775    /// If the value is found, but it cannot be deserialized into the expected type, an error is returned.
776    pub fn get<T: DeserializeOwned>(
777        &self,
778        key: &str,
779    ) -> Result<Option<T>, ValueDeserializationError> {
780        client_get(self.0, self.1, key)
781    }
782
783    /// Get the raw JSON value associated with `key` from the client-side state.
784    pub fn get_raw<'a>(&'a self, key: &str) -> Option<&'a Value> {
785        client_get_raw(&*self.0, self.1, key)
786    }
787
788    /// Returns true if there are no values in the client-side state.
789    pub fn is_empty(&self) -> bool {
790        client_is_empty(self.0, self.1)
791    }
792
793    /// Insert a value in the client-side state for the given key.
794    ///
795    /// If the key already exists, the value is updated and the old raw value is returned.
796    /// If the value cannot be serialized, an error is returned.
797    pub fn insert<T, Key>(
798        &mut self,
799        key: Key,
800        value: T,
801    ) -> Result<Option<Value>, ValueSerializationError>
802    where
803        T: Serialize,
804        Key: Into<Cow<'static, str>>,
805    {
806        let key = key.into();
807        let value = match serde_json::to_value(value) {
808            Ok(t) => t,
809            Err(e) => {
810                return Err(ValueSerializationError {
811                    key,
812                    location: ValueLocation::Client,
813                    source: e,
814                });
815            }
816        };
817        Ok(self.insert_raw(key, value))
818    }
819
820    /// Insert a value in the client-side state for the given key.
821    ///
822    /// If the key already exists, the value is updated and the old value is returned.
823    pub fn insert_raw<Key>(&mut self, key: Key, value: Value) -> Option<Value>
824    where
825        Key: Into<Cow<'static, str>>,
826    {
827        if self.1.is_invalidated() {
828            tracing::trace!(
829                "Attempted to insert a client-side value on a session that's been invalidated."
830            );
831            return None;
832        }
833        let key = key.into();
834        match &mut self.0 {
835            ClientState::Updated { state } => state.insert(key, value),
836            ClientState::Unchanged { state } => {
837                let value = state.insert(key, value);
838                *self.0 = ClientState::Updated {
839                    state: std::mem::take(state),
840                };
841                value
842            }
843        }
844    }
845
846    /// Remove the value associated with `key` from the client-side state.
847    ///
848    /// If the key exists, the removed value is returned.
849    /// If the removed value cannot be serialized, an error is returned.
850    pub fn remove<T: DeserializeOwned>(
851        &mut self,
852        key: &str,
853    ) -> Result<Option<T>, ValueDeserializationError> {
854        self.remove_raw(key)
855            .map(|value| serde_json::from_value(value))
856            .transpose()
857            .map_err(|source| ValueDeserializationError {
858                key: key.to_string().into(),
859                location: ValueLocation::Client,
860                source,
861            })
862    }
863
864    /// Remove the value associated with `key` from the client-side state.
865    ///
866    /// If the key exists, the removed value is returned.
867    pub fn remove_raw(&mut self, key: &str) -> Option<Value> {
868        if self.1.is_invalidated() {
869            return None;
870        }
871        match &mut self.0 {
872            ClientState::Updated { state } => state.remove(key),
873            ClientState::Unchanged { state } => {
874                let value = state.remove(key)?;
875                *self.0 = ClientState::Updated {
876                    state: std::mem::take(state),
877                };
878                Some(value)
879            }
880        }
881    }
882
883    /// Remove all key-value pairs from the client-side state.
884    ///
885    /// This doesn't invalidate the session—you must invoke [`Session::invalidate`]
886    /// if you want to delete the session altogether.
887    pub fn clear(&mut self) {
888        if self.1.is_invalidated() {
889            return;
890        }
891        match &mut self.0 {
892            ClientState::Updated { state } => state.clear(),
893            ClientState::Unchanged { state } => {
894                if !state.is_empty() {
895                    *self.0 = ClientState::Updated {
896                        state: HashMap::new(),
897                    };
898                }
899            }
900        }
901    }
902}
903
904/// Get the value associated with `key` from the client-side state.
905///
906/// If the value is not found, `None` is returned.
907/// If the value is found, but it cannot be deserialized into the expected type, an error is returned.
908fn client_get<T: DeserializeOwned>(
909    state: &ClientState,
910    flag: &InvalidationFlag,
911    key: &str,
912) -> Result<Option<T>, ValueDeserializationError> {
913    client_get_raw(state, flag, key)
914        .map(|value| serde_json::from_value(value.clone()))
915        .transpose()
916        .map_err(|source| ValueDeserializationError {
917            location: ValueLocation::Client,
918            key: key.to_string().into(),
919            source,
920        })
921}
922
923/// Get the raw JSON value associated with `key` from the client-side state.
924fn client_get_raw<'session>(
925    state: &'session ClientState,
926    flag: &'session InvalidationFlag,
927    key: &str,
928) -> Option<&'session Value> {
929    if flag.is_invalidated() {
930        tracing::trace!(
931            "Attempted to get a client-side value on a session that's been invalidated."
932        );
933        return None;
934    }
935    match state {
936        ClientState::Unchanged { state } | ClientState::Updated { state } => state.get(key),
937    }
938}
939
940fn client_is_empty(state: &ClientState, flag: &InvalidationFlag) -> bool {
941    if flag.is_invalidated() {
942        return true;
943    }
944    match state {
945        ClientState::Updated { state } | ClientState::Unchanged { state } => state.is_empty(),
946    }
947}
948
/// Build a `OnceCell` that is pre-filled with `value` when it is `Some`,
/// and left uninitialized otherwise.
fn new_cell_with<T>(value: Option<T>) -> OnceCell<T> {
    value.map(OnceCell::from).unwrap_or_default()
}
956
957/// Load the server-side state from the store, then return a mutable reference to it.
958async fn force_load_mut<'a>(
959    session: &'a mut Session<'_>,
960) -> Result<&'a mut ServerState, LoadError> {
961    force_load(session).await?;
962    let Some(state) = session.server_state.get_mut() else {
963        unreachable!("Server-side state should have been loaded by now!")
964    };
965    Ok(state)
966}
967
968/// Load the server-side state from the store, then return an immutable reference to it.
969async fn force_load_ref<'a>(session: &'a Session<'_>) -> Result<&'a ServerState, LoadError> {
970    force_load(session).await?;
971    let Some(state) = session.server_state.get() else {
972        unreachable!("Server-side state should have been loaded by now!")
973    };
974    Ok(state)
975}
976
977/// Load the server-side state from the store.
978/// This method does nothing if the server-side state has already been loaded.
979///
980/// After calling this method, the server-side state will be loaded
981/// and cached in memory, so that subsequent calls to [`get_raw`](#method.get_raw),
982/// [`insert_raw`](#method.insert_raw), and [`remove_raw`](#method.remove_raw)
983/// will operate on the in-memory state.
984async fn force_load(session: &Session<'_>) -> Result<(), LoadError> {
985    // All other cases either imply that we've already loaded the
986    // server state or that we don't need to (e.g. delete).
987    let Some(session_id) = session.id.old_id() else {
988        return Ok(());
989    };
990    if session.server_state.get().is_some() {
991        return Ok(());
992    }
993    let record = session.store.load(&session_id).await?;
994    let mut must_invalidate = false;
995    let server_state = match record {
996        Some(r) => ServerState::Unchanged {
997            state: r.state,
998            ttl: r.ttl,
999        },
1000        None => {
1001            match session.config.state.missing_server_state {
1002                MissingServerState::Allow => ServerState::DoesNotExist,
1003                MissingServerState::Reject => {
1004                    // This can happen in some edge cases—e.g. the state expired between
1005                    // the time the server received the request and the time it tried to load
1006                    // the state.
1007                    must_invalidate = true;
1008                    ServerState::MarkedForDeletion
1009                }
1010            }
1011        }
1012    };
1013    if session.server_state.set(server_state).is_err() {
1014        tracing::warn!(
1015            "There were multiple concurrent attempts to load the server-side state for the same session.
1016            The state loaded by this one will be discarded."
1017        );
1018    } else {
1019        // We invalidate the session here, rather than doing above, because we want to make
1020        // sure we succeeded in setting the state.
1021        // If someone else beat us to it, we want to let them make a decision
1022        // based on the state they loaded.
1023        // Race conditions all the way down.
1024        if must_invalidate {
1025            tracing::warn!(
1026                "There is no server-side state for the current session, \
1027                even though one was expected. Invalidating the current session."
1028            );
1029            session.invalidated.invalidate();
1030        }
1031    }
1032    Ok(())
1033}
1034
1035/// Errors that can occur when interacting with the session state.
pub mod errors {
    use std::borrow::Cow;

    use pavex::{Response, methods};

    use crate::store::errors::{
        ChangeIdError, CreateError, DeleteError, LoadError, UpdateError, UpdateTtlError,
    };

    #[derive(Debug, thiserror::Error)]
    #[non_exhaustive]
    /// The error returned by [`Session::sync`][super::Session::sync].
    ///
    /// Each variant wraps the store error raised by the corresponding store operation.
    pub enum SyncError {
        /// The store failed to create a new session record.
        #[error("Failed to create a new session record")]
        CreateError(#[from] CreateError),
        /// The store failed to update an existing session record.
        #[error("Failed to update a session record")]
        UpdateError(#[from] UpdateError),
        /// The store failed to delete a session record.
        #[error("Failed to delete a session record")]
        DeleteError(#[from] DeleteError),
        /// The store failed to update the TTL of a session record.
        #[error("Failed to update the TTL for a session record")]
        UpdateTtlError(#[from] UpdateTtlError),
        /// The store failed to change the ID of a session record.
        #[error("Failed to change the session id for a session record")]
        ChangeIdError(#[from] ChangeIdError),
    }

    #[derive(Debug, thiserror::Error)]
    #[non_exhaustive]
    /// The error returned by [`Session::get`][super::Session::get].
    pub enum ServerGetError {
        /// The session record could not be loaded from the store.
        #[error("Failed to load the session record")]
        LoadError(#[from] LoadError),
        /// The stored value could not be deserialized into the requested type.
        #[error(transparent)]
        DeserializationError(#[from] ValueDeserializationError),
    }

    #[derive(Debug, thiserror::Error)]
    #[non_exhaustive]
    /// The error returned by [`Session::remove`][super::Session::remove].
    pub enum ServerRemoveError {
        /// The session record could not be loaded from the store.
        #[error("Failed to load the session record")]
        LoadError(#[from] LoadError),
        /// The removed value could not be deserialized into the requested type.
        #[error(transparent)]
        DeserializationError(#[from] ValueDeserializationError),
    }

    #[derive(Debug, thiserror::Error)]
    #[non_exhaustive]
    /// The error returned by [`Session::insert`][super::Session::insert].
    pub enum ServerInsertError {
        /// The session record could not be loaded from the store.
        #[error("Failed to load the session record")]
        LoadError(#[from] LoadError),
        /// The value to be inserted could not be serialized.
        #[error(transparent)]
        SerializationError(#[from] ValueSerializationError),
    }

    #[derive(Debug, thiserror::Error)]
    #[non_exhaustive]
    #[error(
        "Failed to deserialize the value associated with `{key}` in the {location}-side session state"
    )]
    /// Returned when we fail to deserialize a value stored in either the server or the client
    /// session state.
    pub struct ValueDeserializationError {
        /// The key of the value that we failed to deserialize.
        pub key: Cow<'static, str>,
        /// Whether the value was read from the client-side or the server-side state.
        pub(crate) location: ValueLocation,
        #[source]
        /// The underlying deserialization error.
        pub(crate) source: serde_json::Error,
    }

    #[derive(Debug, thiserror::Error)]
    #[non_exhaustive]
    #[error(
        "Failed to serialize the value that would have been associated with `{key}` in the {location}-side session state"
    )]
    /// Returned when we fail to serialize a value to be stored in either the server or the client
    /// session state.
    pub struct ValueSerializationError {
        /// The key of the value that we failed to serialize.
        pub key: Cow<'static, str>,
        /// Whether the value was meant for the client-side or the server-side state.
        pub(crate) location: ValueLocation,
        #[source]
        /// The underlying serialization error.
        pub(crate) source: serde_json::Error,
    }

    /// Where the value was stored.
    ///
    /// Its `Display` output ("server" / "client") is interpolated into the
    /// (de)serialization error messages above.
    #[derive(Debug)]
    pub(crate) enum ValueLocation {
        /// The server-side session state.
        Server,
        /// The client-side session state.
        Client,
    }

    impl std::fmt::Display for ValueLocation {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let s = match self {
                ValueLocation::Server => "server",
                ValueLocation::Client => "client",
            };
            write!(f, "{s}")
        }
    }

    /// The error returned by [`finalize_session`][crate::finalize_session].
    #[derive(Debug, thiserror::Error)]
    #[non_exhaustive]
    pub enum FinalizeError {
        /// The client-side session state could not be serialized.
        #[error("Failed to serialize the client-side session state")]
        SerializationError(#[from] serde_json::Error),
        /// Syncing the server-side session state failed.
        #[error("Failed to sync the server-side session state")]
        SyncErr(#[from] SyncError),
        /// The client-side state is not empty, but the session cookie (named by
        /// `cookie_name`) is not configured to be encrypted.
        #[error(
            "The client-side session state is not empty, but the session cookie (`{cookie_name}`) is not configured to be encrypted. \
            This may be a security risk, as the client-side session state may be intercepted and read by an attacker. \
            Configure the cookie processor to encrypt the session cookie; check out \
            https://docs.rs/biscotti/latest/biscotti/struct.ProcessorConfig.html#structfield.crypto_rules \
            for more information."
        )]
        EncryptionRequired { cookie_name: String },
        /// The session cookie (named by `cookie_name`) is configured to be neither
        /// signed nor encrypted.
        #[error(
            "The session cookie (`{cookie_name}`) is not configured to be signed nor encrypted. \
            This is a security risk, as the client-side session state may be intercepted and manipulated by an attacker. \
            Configure the cookie processor to sign or encrypt the session cookie; check out \
            https://docs.rs/biscotti/latest/biscotti/struct.ProcessorConfig.html#structfield.crypto_rules \
            for more information."
        )]
        CryptoRequired { cookie_name: String },
    }

    #[methods]
    impl FinalizeError {
        /// Convert the error into a response.
        ///
        /// Every variant is mapped to a `500 Internal Server Error` response.
        #[error_handler]
        pub fn into_response(&self) -> Response {
            Response::internal_server_error()
        }
    }
}
1175
#[cfg(test)]
mod tests {
    use super::Session;
    use static_assertions::assert_not_impl_any;

    // `Session` is meant to live on a single worker thread for the duration of
    // one request: it must implement neither `Sync` nor `Send`.
    assert_not_impl_any!(Session<'static>: Sync, Send);
}