pavex/server/server_handle.rs

use std::future::{Future, IntoFuture, poll_fn};
use std::io::Error;
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::Poll;
use std::thread;

use tokio::net::TcpStream;
use tokio::sync::mpsc::error::TrySendError;
use tokio::task::{JoinError, JoinSet, LocalSet};
use tracing_log_error::log_error;

use crate::connection::ConnectionInfo;
use crate::server::configuration::ServerConfiguration;
use crate::server::worker::{ConnectionMessage, Worker, WorkerHandle};

use super::{IncomingStream, ShutdownMode};

/// A handle to a running [`Server`](super::Server).
///
/// # Example: waiting for the server to shut down
///
/// You can just `.await` the [`ServerHandle`] to wait for the server to shut down:
///
/// ```rust
/// use std::net::SocketAddr;
/// use pavex::server::Server;
///
/// # #[derive(Clone)] struct ApplicationState;
/// # async fn router(_req: hyper::Request<hyper::body::Incoming>, _conn_info: Option<pavex::connection::ConnectionInfo>, _state: ApplicationState) -> pavex::Response { todo!() }
/// # async fn t() -> std::io::Result<()> {
/// # let application_state = ApplicationState;
/// let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
///
/// let server_handle = Server::new()
///     .bind(addr)
///     .await?
///     .serve(router, application_state);
/// // Wait until the server shuts down.
/// server_handle.await;
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct ServerHandle {
    command_outbox: tokio::sync::mpsc::Sender<ServerCommand>,
}

impl ServerHandle {
    pub(super) fn new<HandlerFuture, ApplicationState>(
        config: ServerConfiguration,
        incoming: Vec<IncomingStream>,
        handler: fn(
            http::Request<hyper::body::Incoming>,
            Option<ConnectionInfo>,
            ApplicationState,
        ) -> HandlerFuture,
        application_state: ApplicationState,
    ) -> Self
    where
        HandlerFuture: Future<Output = crate::Response> + 'static,
        ApplicationState: Clone + Send + Sync + 'static,
    {
        let (command_outbox, command_inbox) = tokio::sync::mpsc::channel(32);
        let acceptor = Acceptor::new(config, incoming, handler, application_state, command_inbox);
        let _ = acceptor.spawn();
        Self { command_outbox }
    }

    /// Instruct the [`Server`](super::Server) to stop accepting new connections and
    /// shut down, honoring the chosen [`ShutdownMode`].
    ///
    /// It returns once the shutdown procedure has completed.
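    ///
    /// # Example: graceful shutdown
    ///
    /// A sketch, mirroring the example on [`ServerHandle`]: ask the server to stop and give
    /// in-flight requests up to ten seconds to complete (the router and application state
    /// below are placeholders).
    ///
    /// ```rust
    /// use std::net::SocketAddr;
    /// use std::time::Duration;
    /// use pavex::server::{Server, ShutdownMode};
    ///
    /// # #[derive(Clone)] struct ApplicationState;
    /// # async fn router(_req: hyper::Request<hyper::body::Incoming>, _conn_info: Option<pavex::connection::ConnectionInfo>, _state: ApplicationState) -> pavex::Response { todo!() }
    /// # async fn t() -> std::io::Result<()> {
    /// # let application_state = ApplicationState;
    /// let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
    ///
    /// let server_handle = Server::new()
    ///     .bind(addr)
    ///     .await?
    ///     .serve(router, application_state);
    /// // Stop accepting new connections and wait (up to 10 seconds) for the workers
    /// // to finish handling the requests that are already in flight.
    /// server_handle
    ///     .shutdown(ShutdownMode::Graceful {
    ///         timeout: Duration::from_secs(10),
    ///     })
    ///     .await;
    /// # Ok(())
    /// # }
    /// ```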
    #[doc(alias("stop"))]
    pub async fn shutdown(self, mode: ShutdownMode) {
        let (completion_notifier, completion) = tokio::sync::oneshot::channel();
        if self
            .command_outbox
            .send(ServerCommand::Shutdown {
                completion_notifier,
                mode,
            })
            .await
            .is_ok()
        {
            // What if sending fails?
            // It only happens if the other end of the channel has already been dropped, which
            // implies that the acceptor thread has already shut down—nothing to do!
            let _ = completion.await;
        }
    }
}

impl IntoFuture for ServerHandle {
    type Output = ();
    type IntoFuture = Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>;

    fn into_future(self) -> Self::IntoFuture {
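        // The acceptor owns the receiving end of `command_outbox`: when the acceptor
        // exits (after a shutdown or a fatal accept error) the receiver is dropped,
        // the channel closes and `closed()` resolves, so awaiting the handle
        // completes once the server has stopped.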
        Box::pin(async move { self.command_outbox.closed().await })
    }
}

enum ServerCommand {
    Shutdown {
        completion_notifier: tokio::sync::oneshot::Sender<()>,
        mode: ShutdownMode,
    },
}

#[must_use]
struct Acceptor<HandlerFuture, ApplicationState> {
    command_inbox: tokio::sync::mpsc::Receiver<ServerCommand>,
    incoming: Vec<IncomingStream>,
    worker_handles: Vec<WorkerHandle>,
    #[allow(dead_code)]
    config: ServerConfiguration,
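    // Index of the worker that will be offered the next incoming connection
    // (connections are dispatched to workers in round-robin order).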
    next_worker: usize,
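    // Maximum number of connections that can be waiting in each worker's inbox.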
    max_queue_length: usize,
    handler: fn(
        http::Request<hyper::body::Incoming>,
        Option<ConnectionInfo>,
        ApplicationState,
    ) -> HandlerFuture,
    application_state: ApplicationState,
    // We stash the generic type behind a `PhantomData<fn() -> HandlerFuture>` rather than
    // a `PhantomData<HandlerFuture>` because `Acceptor` needs to be `Send` and `Sync`,
    // while `HandlerFuture` may be neither.
    // In the end, we just need to record the generic type *somewhere*.
    handler_output_future: PhantomData<fn() -> HandlerFuture>,
}

enum AcceptorInboxMessage {
    ServerCommand(ServerCommand),
    Connection(Option<Result<(IncomingStream, TcpStream, SocketAddr), JoinError>>),
}

impl<HandlerFuture, ApplicationState> Acceptor<HandlerFuture, ApplicationState>
where
    HandlerFuture: Future<Output = crate::Response> + 'static,
    ApplicationState: Clone + Send + Sync + 'static,
{
    fn new(
        config: ServerConfiguration,
        incoming: Vec<IncomingStream>,
        handler: fn(
            http::Request<hyper::body::Incoming>,
            Option<ConnectionInfo>,
            ApplicationState,
        ) -> HandlerFuture,
        application_state: ApplicationState,
        command_inbox: tokio::sync::mpsc::Receiver<ServerCommand>,
    ) -> Self {
        // TODO: make this configurable
        let max_queue_length = 15;
        let n_workers = config.n_workers.get();
        let mut worker_handles = Vec::with_capacity(n_workers);
        for i in 0..n_workers {
            let (worker, handle) =
                Worker::new(i, max_queue_length, handler, application_state.clone());
            worker_handles.push(handle);
            // TODO: should we panic here?
            worker.spawn().expect("Failed to spawn worker thread");
        }
        Self {
            command_inbox,
            incoming,
            worker_handles,
            config,
            max_queue_length,
            handler,
            handler_output_future: Default::default(),
            next_worker: 0,
            application_state,
        }
    }

    /// Run the acceptor: accept incoming connections and dispatch them to workers.
    ///
    /// Constraint: this method **must not panic**.
    async fn run(self) {
        /// Accept a connection from the given [`IncomingStream`].
        /// If accepting a connection fails, log the error and try again with the next one.
        async fn accept_connection(
            incoming: IncomingStream,
        ) -> (IncomingStream, TcpStream, SocketAddr) {
            #[allow(deprecated)]
            // This has been inlined from `tokio`'s codebase, since it's not public API.
            fn is_rt_shutdown_err(err: &Error) -> bool {
                const RT_SHUTDOWN_ERR: &str =
                    "A Tokio 1.x context was found, but it is being shutdown.";
                if err.kind() != std::io::ErrorKind::Other {
                    return false;
                }
                let Some(inner) = err.get_ref() else {
                    return false;
                };
                // Using `Error::description()` is more efficient than `format!("{inner}")`,
                // so we use it here even if it is deprecated.
                inner.source().is_none() && inner.description() == RT_SHUTDOWN_ERR
            }

            loop {
                match incoming.accept().await {
                    Ok((connection, remote_peer)) => return (incoming, connection, remote_peer),
                    Err(e) => {
                        if is_rt_shutdown_err(&e) {
                            log_error!(e, level: tracing::Level::DEBUG, "Failed to accept connection");
                        } else {
                            log_error!(e, level: tracing::Level::INFO, "Failed to accept connection");
                        }
                        continue;
                    }
                }
            }
        }

        let Self {
            mut command_inbox,
            mut next_worker,
            mut worker_handles,
            incoming,
            config: _,
            max_queue_length,
            handler,
            application_state,
            handler_output_future: _,
        } = self;

        let n_workers = worker_handles.len();

        let mut incoming_join_set = JoinSet::new();
        for incoming in incoming.into_iter() {
            incoming_join_set.spawn(accept_connection(incoming));
        }

        let error = 'event_loop: loop {
            // Check if there is work to be done.
            let message =
                poll_fn(|cx| Self::poll_inboxes(cx, &mut command_inbox, &mut incoming_join_set))
                    .await;
            match message {
                AcceptorInboxMessage::ServerCommand(command) => match command {
                    ServerCommand::Shutdown {
                        completion_notifier,
                        mode,
                    } => {
                        Self::shutdown(
                            completion_notifier,
                            mode,
                            incoming_join_set,
                            worker_handles,
                        )
                        .await;
                        return;
                    }
                },
                AcceptorInboxMessage::Connection(msg) => {
                    let (incoming, connection, remote_peer) = match msg {
                        Some(Ok((incoming, connection, remote_peer))) => {
                            (incoming, connection, remote_peer)
                        }
                        Some(Err(e)) => {
                            // This only happens if the task accepting connections panicked
                            // or was somehow cancelled. Neither should ever occur, but we
                            // handle the error anyway so that the failure gets logged if we
                            // ever introduce such a fatal bug.
                            break 'event_loop e;
                        }
                        None => {
                            // When we succeed in accepting a connection, we always spawn a new task to
                            // accept the next connection from the same socket.
                            // If we fail to accept a connection, we exit the acceptor thread.
                            // Therefore, the JoinSet should never be empty.
                            unreachable!(
                                "The JoinSet for incoming connections cannot ever be empty"
                            )
                        }
                    };
                    // Re-spawn the task to keep accepting connections from the same socket.
                    incoming_join_set.spawn(accept_connection(incoming));

                    // A flag to track if the connection has been successfully sent to a worker.
                    let mut has_been_handled = false;
                    // We try to send the connection to a worker (`ConnectionMessage`).
                    // If the worker's inbox is full, we try the next worker until we find one that can
                    // accept the connection or we've tried all workers.
                    let mut connection_message = ConnectionMessage {
                        connection,
                        peer_addr: remote_peer,
                    };
                    for _ in 0..n_workers {
                        // Track if the worker has crashed.
                        let mut has_crashed: Option<usize> = None;
                        let worker_handle = &worker_handles[next_worker];
                        match worker_handle.dispatch(connection_message) {
                            Err(e) => {
                                connection_message = match e {
                                    TrySendError::Full(message) => message,
                                    // A closed channel implies that the worker thread is no longer running,
                                    // therefore we need to restart it.
                                    TrySendError::Closed(conn) => {
                                        has_crashed = Some(worker_handle.id());
                                        conn
                                    }
                                };
                                next_worker = (next_worker + 1) % n_workers;
                            }
                            _ => {
                                // We've successfully sent the connection to a worker, so we can stop trying
                                // to send it to other workers.
                                has_been_handled = true;
                                break;
                            }
                        }

                        // Restart the crashed worker thread.
                        if let Some(worker_id) = has_crashed {
                            tracing::warn!(worker_id = worker_id, "Worker crashed, restarting it");
                            let (worker, worker_handle) = Worker::new(
                                worker_id,
                                max_queue_length,
                                handler,
                                application_state.clone(),
                            );
                            // TODO: what if we fail to spawn the worker thread? We don't want to panic here!
                            worker.spawn().expect("Failed to spawn worker thread");
                            worker_handles[worker_id] = worker_handle;
                        }
                    }

                    if !has_been_handled {
                        tracing::error!(
                            remote_peer = %remote_peer,
                            "All workers are busy, dropping connection",
                        );
                    }
                }
            }
        };

        log_error!(
            error,
            "Failed to accept new connections. The acceptor thread will exit now."
        );
    }

    /// Check if there is work to be done.
    fn poll_inboxes(
        cx: &mut std::task::Context<'_>,
        server_command_inbox: &mut tokio::sync::mpsc::Receiver<ServerCommand>,
        incoming_join_set: &mut JoinSet<(IncomingStream, TcpStream, SocketAddr)>,
    ) -> Poll<AcceptorInboxMessage> {
        // Order matters here: we want to prioritize shutdown messages over incoming connections.
        if let Poll::Ready(Some(message)) = server_command_inbox.poll_recv(cx) {
            return Poll::Ready(AcceptorInboxMessage::ServerCommand(message));
        }
        if let Poll::Ready(message) = incoming_join_set.poll_join_next(cx) {
            return Poll::Ready(AcceptorInboxMessage::Connection(message));
        }
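        // Neither inbox has anything for us right now. Both `poll_recv` and
        // `poll_join_next` register the waker when they return `Pending`, so we will
        // be polled again as soon as there is new work.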
        Poll::Pending
    }

    fn spawn(self) -> thread::JoinHandle<()> {
        thread::Builder::new()
            .name("pavex-acceptor".to_string())
            .spawn(move || {
                let rt = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .expect("Failed to build single-threaded Tokio runtime for acceptor thread");
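                // A `LocalSet` is needed here because `Self::shutdown` relies on
                // `spawn_local` to drive the workers' shutdown futures.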
                LocalSet::new().block_on(&rt, self.run());
            })
            .expect("Failed to spawn acceptor thread")
    }

    async fn shutdown(
        completion_notifier: tokio::sync::oneshot::Sender<()>,
        mode: ShutdownMode,
        incoming_join_set: JoinSet<(IncomingStream, TcpStream, SocketAddr)>,
        worker_handles: Vec<WorkerHandle>,
    ) {
        // Dropping the `JoinSet` cancels all the accept tasks that are still running.
        // That, in turn, drops each `IncomingStream` and its underlying `TcpListener`,
        // closing the sockets and stopping the acceptance of new connections.
        drop(incoming_join_set);

        let mut shutdown_join_set = JoinSet::new();
        for worker_handle in worker_handles {
            let mode2 = mode.clone();
            // The shutdown command is enqueued immediately, before the future is polled for the
            // first time.
            let future = worker_handle.shutdown(mode2);
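            // If the shutdown is not graceful, enqueuing the command is all we need:
            // the future is dropped without being awaited, i.e. we don't wait for the
            // worker to wind down.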
            if mode.is_graceful() {
                shutdown_join_set.spawn_local(future);
            }
        }

        if let ShutdownMode::Graceful { timeout } = mode {
            // Wait for all workers to shut down, or for the timeout to expire,
            // whichever happens first.
            let _ = tokio::time::timeout(timeout, async move {
                while shutdown_join_set.join_next().await.is_some() {}
            })
            .await;
        }

        // Notify the caller that the server has shut down.
        let _ = completion_notifier.send(());
    }
}