1use std::future::{Future, IntoFuture, poll_fn};
2use std::io::Error;
3use std::marker::PhantomData;
4use std::net::SocketAddr;
5use std::pin::Pin;
6use std::task::Poll;
7use std::thread;
8
9use tokio::net::TcpStream;
10use tokio::sync::mpsc::error::TrySendError;
11use tokio::task::{JoinError, JoinSet, LocalSet};
12use tracing_log_error::log_error;
13
14use crate::connection::ConnectionInfo;
15use crate::server::configuration::ServerConfiguration;
16use crate::server::worker::{ConnectionMessage, Worker, WorkerHandle};
17
18use super::{IncomingStream, ShutdownMode};
19
20#[derive(Clone)]
46pub struct ServerHandle {
47 command_outbox: tokio::sync::mpsc::Sender<ServerCommand>,
48}
49
50impl ServerHandle {
51 pub(super) fn new<HandlerFuture, ApplicationState>(
52 config: ServerConfiguration,
53 incoming: Vec<IncomingStream>,
54 handler: fn(
55 http::Request<hyper::body::Incoming>,
56 Option<ConnectionInfo>,
57 ApplicationState,
58 ) -> HandlerFuture,
59 application_state: ApplicationState,
60 ) -> Self
61 where
62 HandlerFuture: Future<Output = crate::Response> + 'static,
63 ApplicationState: Clone + Send + Sync + 'static,
64 {
65 let (command_outbox, command_inbox) = tokio::sync::mpsc::channel(32);
66 let acceptor = Acceptor::new(config, incoming, handler, application_state, command_inbox);
67 let _ = acceptor.spawn();
68 Self { command_outbox }
69 }
70
71 #[doc(alias("stop"))]
73 pub async fn shutdown(self, mode: ShutdownMode) {
74 let (completion_notifier, completion) = tokio::sync::oneshot::channel();
75 if self
76 .command_outbox
77 .send(ServerCommand::Shutdown {
78 completion_notifier,
79 mode,
80 })
81 .await
82 .is_ok()
83 {
84 let _ = completion.await;
88 }
89 }
90}
91
92impl IntoFuture for ServerHandle {
93 type Output = ();
94 type IntoFuture = Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>;
95
96 fn into_future(self) -> Self::IntoFuture {
97 Box::pin(async move { self.command_outbox.closed().await })
98 }
99}
100
101enum ServerCommand {
102 Shutdown {
103 completion_notifier: tokio::sync::oneshot::Sender<()>,
104 mode: ShutdownMode,
105 },
106}
107
108#[must_use]
109struct Acceptor<HandlerFuture, ApplicationState> {
110 command_inbox: tokio::sync::mpsc::Receiver<ServerCommand>,
111 incoming: Vec<IncomingStream>,
112 worker_handles: Vec<WorkerHandle>,
113 #[allow(dead_code)]
114 config: ServerConfiguration,
115 next_worker: usize,
116 max_queue_length: usize,
117 handler: fn(
118 http::Request<hyper::body::Incoming>,
119 Option<ConnectionInfo>,
120 ApplicationState,
121 ) -> HandlerFuture,
122 application_state: ApplicationState,
123 handler_output_future: PhantomData<fn() -> HandlerFuture>,
127}
128
129enum AcceptorInboxMessage {
130 ServerCommand(ServerCommand),
131 Connection(Option<Result<(IncomingStream, TcpStream, SocketAddr), JoinError>>),
132}
133
134impl<HandlerFuture, ApplicationState> Acceptor<HandlerFuture, ApplicationState>
135where
136 HandlerFuture: Future<Output = crate::Response> + 'static,
137 ApplicationState: Clone + Send + Sync + 'static,
138{
139 fn new(
140 config: ServerConfiguration,
141 incoming: Vec<IncomingStream>,
142 handler: fn(
143 http::Request<hyper::body::Incoming>,
144 Option<ConnectionInfo>,
145 ApplicationState,
146 ) -> HandlerFuture,
147 application_state: ApplicationState,
148 command_inbox: tokio::sync::mpsc::Receiver<ServerCommand>,
149 ) -> Self {
150 let max_queue_length = 15;
152 let n_workers = config.n_workers.get();
153 let mut worker_handles = Vec::with_capacity(n_workers);
154 for i in 0..n_workers {
155 let (worker, handle) =
156 Worker::new(i, max_queue_length, handler, application_state.clone());
157 worker_handles.push(handle);
158 worker.spawn().expect("Failed to spawn worker thread");
160 }
161 Self {
162 command_inbox,
163 incoming,
164 worker_handles,
165 config,
166 max_queue_length,
167 handler,
168 handler_output_future: Default::default(),
169 next_worker: 0,
170 application_state,
171 }
172 }
173
174 async fn run(self) {
178 async fn accept_connection(
181 incoming: IncomingStream,
182 ) -> (IncomingStream, TcpStream, SocketAddr) {
183 #[allow(deprecated)]
184 fn is_rt_shutdown_err(err: &Error) -> bool {
186 const RT_SHUTDOWN_ERR: &str =
187 "A Tokio 1.x context was found, but it is being shutdown.";
188 if err.kind() != std::io::ErrorKind::Other {
189 return false;
190 }
191 let Some(inner) = err.get_ref() else {
192 return false;
193 };
194 inner.source().is_none() && inner.description() == RT_SHUTDOWN_ERR
197 }
198
199 loop {
200 match incoming.accept().await {
201 Ok((connection, remote_peer)) => return (incoming, connection, remote_peer),
202 Err(e) => {
203 if is_rt_shutdown_err(&e) {
204 log_error!(e, level: tracing::Level::DEBUG, "Failed to accept connection");
205 } else {
206 log_error!(e, level: tracing::Level::INFO, "Failed to accept connection");
207 }
208 continue;
209 }
210 }
211 }
212 }
213
214 let Self {
215 mut command_inbox,
216 mut next_worker,
217 mut worker_handles,
218 incoming,
219 config: _,
220 max_queue_length,
221 handler,
222 application_state,
223 handler_output_future: _,
224 } = self;
225
226 let n_workers = worker_handles.len();
227
228 let mut incoming_join_set = JoinSet::new();
229 for incoming in incoming.into_iter() {
230 incoming_join_set.spawn(accept_connection(incoming));
231 }
232
233 let error = 'event_loop: loop {
234 let message =
236 poll_fn(|cx| Self::poll_inboxes(cx, &mut command_inbox, &mut incoming_join_set))
237 .await;
238 match message {
239 AcceptorInboxMessage::ServerCommand(command) => match command {
240 ServerCommand::Shutdown {
241 completion_notifier,
242 mode,
243 } => {
244 Self::shutdown(
245 completion_notifier,
246 mode,
247 incoming_join_set,
248 worker_handles,
249 )
250 .await;
251 return;
252 }
253 },
254 AcceptorInboxMessage::Connection(msg) => {
255 let (incoming, connection, remote_peer) = match msg {
256 Some(Ok((incoming, connection, remote_peer))) => {
257 (incoming, connection, remote_peer)
258 }
259 Some(Err(e)) => {
260 break 'event_loop e;
265 }
266 None => {
267 unreachable!(
272 "The JoinSet for incoming connections cannot ever be empty"
273 )
274 }
275 };
276 incoming_join_set.spawn(accept_connection(incoming));
278
279 let mut has_been_handled = false;
281 let mut connection_message = ConnectionMessage {
285 connection,
286 peer_addr: remote_peer,
287 };
288 for _ in 0..n_workers {
289 let mut has_crashed: Option<usize> = None;
291 let worker_handle = &worker_handles[next_worker];
292 match worker_handle.dispatch(connection_message) {
293 Err(e) => {
294 connection_message = match e {
295 TrySendError::Full(message) => message,
296 TrySendError::Closed(conn) => {
299 has_crashed = Some(worker_handle.id());
300 conn
301 }
302 };
303 next_worker = (next_worker + 1) % n_workers;
304 }
305 _ => {
306 has_been_handled = true;
309 break;
310 }
311 }
312
313 if let Some(worker_id) = has_crashed {
315 tracing::warn!(worker_id = worker_id, "Worker crashed, restarting it");
316 let (worker, worker_handle) = Worker::new(
317 worker_id,
318 max_queue_length,
319 handler,
320 application_state.clone(),
321 );
322 worker.spawn().expect("Failed to spawn worker thread");
324 worker_handles[worker_id] = worker_handle;
325 }
326 }
327
328 if !has_been_handled {
329 tracing::error!(
330 remote_peer = %remote_peer,
331 "All workers are busy, dropping connection",
332 );
333 }
334 }
335 }
336 };
337
338 log_error!(
339 error,
340 "Failed to accept new connections. The acceptor thread will exit now."
341 );
342 }
343
344 fn poll_inboxes(
346 cx: &mut std::task::Context<'_>,
347 server_command_inbox: &mut tokio::sync::mpsc::Receiver<ServerCommand>,
348 incoming_join_set: &mut JoinSet<(IncomingStream, TcpStream, SocketAddr)>,
349 ) -> Poll<AcceptorInboxMessage> {
350 if let Poll::Ready(Some(message)) = server_command_inbox.poll_recv(cx) {
352 return Poll::Ready(AcceptorInboxMessage::ServerCommand(message));
353 }
354 if let Poll::Ready(message) = incoming_join_set.poll_join_next(cx) {
355 return Poll::Ready(AcceptorInboxMessage::Connection(message));
356 }
357 Poll::Pending
358 }
359
360 fn spawn(self) -> thread::JoinHandle<()> {
361 thread::Builder::new()
362 .name("pavex-acceptor".to_string())
363 .spawn(move || {
364 let rt = tokio::runtime::Builder::new_current_thread()
365 .enable_all()
366 .build()
367 .expect("Failed to build single-threaded Tokio runtime for acceptor thread");
368 LocalSet::new().block_on(&rt, self.run());
369 })
370 .expect("Failed to spawn acceptor thread")
371 }
372
373 async fn shutdown(
374 completion_notifier: tokio::sync::oneshot::Sender<()>,
375 mode: ShutdownMode,
376 incoming_join_set: JoinSet<(IncomingStream, TcpStream, SocketAddr)>,
377 worker_handles: Vec<WorkerHandle>,
378 ) {
379 drop(incoming_join_set);
384
385 let mut shutdown_join_set = JoinSet::new();
386 for worker_handle in worker_handles {
387 let mode2 = mode.clone();
388 let future = worker_handle.shutdown(mode2);
391 if mode.is_graceful() {
392 shutdown_join_set.spawn_local(future);
393 }
394 }
395
396 if let ShutdownMode::Graceful { timeout } = mode {
397 let _ = tokio::time::timeout(timeout, async move {
400 while shutdown_join_set.join_next().await.is_some() {}
401 })
402 .await;
403 }
404
405 let _ = completion_notifier.send(());
407 }
408}