use std::future::{poll_fn, Future};
use std::net::SocketAddr;
use std::task::Poll;
use std::thread;
use anyhow::Context;
use hyper_util::rt::TokioIo;
use hyper_util::server::graceful::GracefulShutdown;
use tokio::net::TcpStream;
use tokio::sync::mpsc::error::TrySendError;
use tracing_log_error::log_error;
use crate::connection::ConnectionInfo;
use crate::server::ShutdownMode;
/// An incoming TCP connection, queued for a worker to serve.
pub(super) struct ConnectionMessage {
    /// The raw TCP stream the worker will serve HTTP over.
    pub(super) connection: TcpStream,
    /// The address of the remote peer; it is handed to the request handler as
    /// `ConnectionInfo { peer_addr }`.
    pub(super) peer_addr: SocketAddr,
}
/// The caller-side handle to a `Worker`: used to dispatch connections to it and to
/// request its shutdown.
pub(super) struct WorkerHandle {
    /// Bounded queue of connections waiting to be served by the worker.
    connection_outbox: tokio::sync::mpsc::Sender<ConnectionMessage>,
    /// Channel used to deliver shutdown commands to the worker.
    shutdown_outbox: tokio::sync::mpsc::UnboundedSender<ShutdownWorkerCommand>,
    /// The same identifier as the worker this handle points to.
    id: usize,
}
impl WorkerHandle {
    /// Hands a connection over to the worker without waiting for queue space.
    ///
    /// # Errors
    ///
    /// Fails with `TrySendError::Full` when the worker's connection queue is at
    /// capacity, or `TrySendError::Closed` when the worker no longer accepts
    /// connections. The rejected message is carried inside the error so the caller
    /// can route it elsewhere.
    pub(super) fn dispatch(
        &self,
        connection: ConnectionMessage,
    ) -> Result<(), TrySendError<ConnectionMessage>> {
        let outbox = &self.connection_outbox;
        outbox.try_send(connection)
    }

    /// The identifier of the worker this handle points to.
    pub(super) fn id(&self) -> usize {
        let Self { id, .. } = self;
        *id
    }

    /// Asks the worker to shut down according to `mode`.
    ///
    /// The command is sent eagerly, when this method is called. The returned future
    /// resolves once the worker signals completion, or immediately if the command
    /// could not be delivered (the worker's shutdown inbox is already closed).
    pub(super) fn shutdown(self, mode: ShutdownMode) -> impl Future<Output = ()> {
        let (completion_notifier, completion) = tokio::sync::oneshot::channel();
        let command = ShutdownWorkerCommand {
            completion_notifier,
            mode,
        };
        let pending_completion = match self.shutdown_outbox.send(command) {
            Ok(()) => Some(completion),
            Err(_) => None,
        };
        async move {
            if let Some(completion) = pending_completion {
                // If the worker drops the notifier without sending, this resolves to an
                // error; either way there is nothing left to wait for.
                let _ = completion.await;
            }
        }
    }
}
/// A request for a worker to stop serving connections.
pub(super) struct ShutdownWorkerCommand {
    /// Signalled by the worker once it has finished handling the shutdown.
    completion_notifier: tokio::sync::oneshot::Sender<()>,
    /// How to shut down. `Graceful` serves the queued connections and waits (up to a
    /// timeout) for in-flight ones. `Forced` stops right away.
    mode: ShutdownMode,
}
/// A worker that serves HTTP connections on its own thread, driven by a single-threaded
/// Tokio runtime and a `LocalSet` (see `Worker::spawn`).
#[must_use]
pub(super) struct Worker<HandlerFuture, ApplicationState> {
    /// Connections dispatched to this worker through its `WorkerHandle`.
    connection_inbox: tokio::sync::mpsc::Receiver<ConnectionMessage>,
    /// Shutdown commands sent through `WorkerHandle::shutdown`.
    shutdown_inbox: tokio::sync::mpsc::UnboundedReceiver<ShutdownWorkerCommand>,
    /// The request handler, invoked once per incoming request.
    handler: fn(
        http::Request<hyper::body::Incoming>,
        Option<ConnectionInfo>,
        ApplicationState,
    ) -> HandlerFuture,
    /// State cloned and passed to `handler` for every request.
    application_state: ApplicationState,
    /// Identifier used in the worker thread's name and in log events.
    id: usize,
    /// Tracks in-flight connections so that a graceful shutdown can wait for them.
    shutdown_coordinator: GracefulShutdown,
}
impl<HandlerFuture, ApplicationState> Worker<HandlerFuture, ApplicationState>
where
    HandlerFuture: Future<Output = crate::response::Response> + 'static,
    ApplicationState: Clone + Send + Sync + 'static,
{
    /// Creates a worker together with the `WorkerHandle` used to communicate with it.
    ///
    /// `max_queue_length` bounds how many dispatched connections can wait in the
    /// worker's inbox. Once the inbox is full, `WorkerHandle::dispatch` fails with
    /// `TrySendError::Full`.
    ///
    /// # Panics
    ///
    /// Panics if `max_queue_length` is zero, because Tokio's bounded channels require
    /// a non-zero capacity.
    pub(super) fn new(
        id: usize,
        max_queue_length: usize,
        handler: fn(
            http::Request<hyper::body::Incoming>,
            Option<ConnectionInfo>,
            ApplicationState,
        ) -> HandlerFuture,
        application_state: ApplicationState,
    ) -> (Self, WorkerHandle) {
        let (connection_outbox, connection_inbox) = tokio::sync::mpsc::channel(max_queue_length);
        let (shutdown_outbox, shutdown_inbox) = tokio::sync::mpsc::unbounded_channel();
        let self_ = Self {
            connection_inbox,
            shutdown_inbox,
            handler,
            application_state,
            id,
            shutdown_coordinator: GracefulShutdown::new(),
        };
        let handle = WorkerHandle {
            connection_outbox,
            shutdown_outbox,
            id,
        };
        (self_, handle)
    }

    /// Spawns a dedicated OS thread, named `pavex-worker-{id}`. The thread runs this
    /// worker's event loop on a single-threaded Tokio runtime inside a
    /// `tokio::task::LocalSet`.
    ///
    /// The thread exits once the worker's event loop returns.
    ///
    /// # Errors
    ///
    /// Returns an error if the OS thread cannot be spawned.
    ///
    /// # Panics
    ///
    /// The spawned thread (not the caller) panics if the Tokio runtime cannot be built.
    pub(super) fn spawn(self) -> Result<thread::JoinHandle<()>, anyhow::Error> {
        let name = format!("pavex-worker-{}", self.id);
        thread::Builder::new()
            .name(name.clone())
            .spawn(move || {
                let runtime = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .expect("Failed to build single-threaded Tokio runtime for worker thread");
                // A `LocalSet` is needed because connections are served via `spawn_local`:
                // handler futures are only required to be `'static`, not `Send`.
                let local = tokio::task::LocalSet::new();
                local.block_on(&runtime, self.run());
            })
            .with_context(|| format!("Failed to spawn worker thread `{}`", name))
    }

    /// The worker's event loop.
    ///
    /// Serves incoming connections until a shutdown command arrives; shutdown commands
    /// take priority over queued connections.
    ///
    /// If every `WorkerHandle` is dropped without sending a shutdown command, the loop
    /// exits once the connection queue has been drained. In-flight connections are not
    /// awaited in that case. Exiting avoids waiting forever on inboxes that can no
    /// longer receive messages.
    async fn run(self) {
        let Self {
            mut connection_inbox,
            mut shutdown_inbox,
            handler,
            application_state,
            id,
            shutdown_coordinator,
        } = self;
        'event_loop: loop {
            let message =
                poll_fn(|cx| Self::poll_inboxes(cx, &mut shutdown_inbox, &mut connection_inbox))
                    .await;
            let Some(message) = message else {
                tracing::warn!(
                    worker_id = id,
                    "All handles to the worker were dropped without a shutdown command, shutting down"
                );
                break 'event_loop;
            };
            match message {
                WorkerInboxMessage::Connection(connection) => {
                    Self::handle_connection(
                        connection,
                        handler,
                        application_state.clone(),
                        &shutdown_coordinator,
                    );
                }
                WorkerInboxMessage::Shutdown(shutdown) => {
                    let ShutdownWorkerCommand {
                        completion_notifier,
                        mode,
                    } = shutdown;
                    match mode {
                        ShutdownMode::Graceful { timeout } => {
                            // Stop accepting new connections, but still serve the ones
                            // that were already queued.
                            connection_inbox.close();
                            while let Some(connection) = connection_inbox.recv().await {
                                Self::handle_connection(
                                    connection,
                                    handler,
                                    application_state.clone(),
                                    &shutdown_coordinator,
                                );
                            }
                            // Wait for in-flight connections, but no longer than `timeout`.
                            let _ = tokio::time::timeout(timeout, shutdown_coordinator.shutdown())
                                .await;
                        }
                        ShutdownMode::Forced => {}
                    }
                    // The requester may have stopped waiting; that is not an error.
                    let _ = completion_notifier.send(());
                    break 'event_loop;
                }
            }
        }
        tracing::info!(worker_id = id, "Worker shut down");
    }

    /// Serves `connection_message` on a task spawned on the current `LocalSet`.
    ///
    /// The connection is served with hyper-util's `auto` builder (HTTP/1 or HTTP/2).
    /// It is registered with `shutdown_coordinator` so that a graceful shutdown can
    /// wait for it. Errors while serving are logged at `WARN` level.
    ///
    /// Must be called from within a `LocalSet` context, as required by `spawn_local`.
    fn handle_connection(
        connection_message: ConnectionMessage,
        handler: fn(
            http::Request<hyper::body::Incoming>,
            Option<ConnectionInfo>,
            ApplicationState,
        ) -> HandlerFuture,
        application_state: ApplicationState,
        shutdown_coordinator: &GracefulShutdown,
    ) {
        let ConnectionMessage {
            connection,
            peer_addr,
        } = connection_message;
        let handler = hyper::service::service_fn(move |request| {
            let state = application_state.clone();
            async move {
                let handler = (handler)(request, Some(ConnectionInfo { peer_addr }), state);
                let response = handler.await;
                let response = hyper::Response::from(response);
                Ok::<_, hyper::Error>(response)
            }
        });
        let builder = hyper_util::server::conn::auto::Builder::new(LocalExec);
        let connection = TokioIo::new(connection);
        let connection_future =
            shutdown_coordinator.watch(builder.serve_connection(connection, handler).into_owned());
        tokio::task::spawn_local(async move {
            if let Err(e) = connection_future.await {
                log_error!(*e, level: tracing::Level::WARN, "Failed to serve an incoming connection");
            }
        });
    }

    /// Polls both inboxes, giving shutdown commands priority over connections.
    ///
    /// Returns `Ready(Some(_))` when a message is available. Returns `Ready(None)` once
    /// both inboxes are closed and drained, i.e. every `WorkerHandle` has been dropped
    /// and no further message can ever arrive.
    ///
    /// A closed, empty channel does not register a waker. Returning `Pending` in that
    /// state would therefore stall the event loop forever.
    fn poll_inboxes(
        cx: &mut std::task::Context<'_>,
        shutdown_inbox: &mut tokio::sync::mpsc::UnboundedReceiver<ShutdownWorkerCommand>,
        connection_inbox: &mut tokio::sync::mpsc::Receiver<ConnectionMessage>,
    ) -> Poll<Option<WorkerInboxMessage>> {
        let shutdown_closed = match shutdown_inbox.poll_recv(cx) {
            Poll::Ready(Some(message)) => return Poll::Ready(Some(message.into())),
            Poll::Ready(None) => true,
            Poll::Pending => false,
        };
        match connection_inbox.poll_recv(cx) {
            Poll::Ready(Some(message)) => Poll::Ready(Some(message.into())),
            Poll::Ready(None) if shutdown_closed => Poll::Ready(None),
            // At least one inbox is still open, and it registered a waker when polled.
            _ => Poll::Pending,
        }
    }
}
/// A message received by a worker's event loop, from either of its two inboxes.
enum WorkerInboxMessage {
    /// A new connection to serve.
    Connection(ConnectionMessage),
    /// A request to shut the worker down.
    Shutdown(ShutdownWorkerCommand),
}

impl From<ConnectionMessage> for WorkerInboxMessage {
    fn from(message: ConnectionMessage) -> WorkerInboxMessage {
        WorkerInboxMessage::Connection(message)
    }
}

impl From<ShutdownWorkerCommand> for WorkerInboxMessage {
    fn from(request: ShutdownWorkerCommand) -> WorkerInboxMessage {
        WorkerInboxMessage::Shutdown(request)
    }
}
/// A `hyper::rt::Executor` that runs hyper's background futures on the current
/// thread's `tokio::task::LocalSet`.
///
/// Spawning locally lets it accept futures that are not `Send`. As with
/// `tokio::task::spawn_local`, it must be used from within a `LocalSet` context.
#[derive(Clone, Copy, Debug)]
struct LocalExec;

impl<Fut> hyper::rt::Executor<Fut> for LocalExec
where
    Fut: Future + 'static,
{
    fn execute(&self, fut: Fut) {
        // The join handle is dropped right away: the task runs detached.
        let _detached = tokio::task::spawn_local(fut);
    }
}