Optimized shutdown_gracefully() (#2041)
This commit is contained in:
parent
3268444793
commit
7c78e740ef
11
Cargo.lock
generated
11
Cargo.lock
generated
@ -463,15 +463,6 @@ version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946"
|
||||
|
||||
[[package]]
|
||||
name = "drain"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f1a0abf3fcefad9b4dd0e414207a7408e12b68414a01e6bb19b897d5bd7632d"
|
||||
dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "encode_unicode"
|
||||
version = "0.3.6"
|
||||
@ -2031,7 +2022,6 @@ dependencies = [
|
||||
"basic-toml",
|
||||
"bytes",
|
||||
"cfg-if",
|
||||
"drain",
|
||||
"enum-as-inner",
|
||||
"futures-executor",
|
||||
"futures-util",
|
||||
@ -2046,6 +2036,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-openssl",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"trust-dns-proto",
|
||||
|
@ -54,6 +54,7 @@ tokio = "1.21"
|
||||
tokio-native-tls = "0.3.0"
|
||||
tokio-openssl = "0.6.0"
|
||||
tokio-rustls = "0.24.0"
|
||||
tokio-util = "0.7.9"
|
||||
parking_lot = "0.12"
|
||||
|
||||
|
||||
@ -81,7 +82,6 @@ cfg-if = "1"
|
||||
clap = { version = "4.0", default-features = false }
|
||||
console = "0.15.0"
|
||||
data-encoding = "2.2.0"
|
||||
drain = "0.1.1"
|
||||
enum-as-inner = "0.6"
|
||||
idna = "0.4.0"
|
||||
ipconfig = "0.3.0"
|
||||
|
@ -76,7 +76,6 @@ async-trait.workspace = true
|
||||
basic-toml = { workspace = true, optional = true }
|
||||
bytes.workspace = true
|
||||
cfg-if.workspace = true
|
||||
drain.workspace = true
|
||||
enum-as-inner.workspace = true
|
||||
futures-util = { workspace = true, default-features = false, features = ["std"] }
|
||||
h2 = { workspace = true, features = ["stream"], optional = true }
|
||||
@ -91,6 +90,7 @@ tracing.workspace = true
|
||||
tokio = { workspace = true, features = ["macros", "net", "sync"] }
|
||||
tokio-openssl = { workspace = true, optional = true }
|
||||
tokio-rustls = { workspace = true, optional = true }
|
||||
tokio-util.workspace = true
|
||||
trust-dns-proto = { workspace = true, features = ["text-parsing", "tokio-runtime"] }
|
||||
trust-dns-recursor = { workspace = true, features = ["serde-config"], optional = true }
|
||||
trust-dns-resolver = { workspace = true, features = ["serde-config", "system-config", "tokio-runtime"], optional = true }
|
||||
|
@ -8,10 +8,10 @@
|
||||
use std::{io, net::SocketAddr, sync::Arc};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use drain::Watch;
|
||||
use futures_util::lock::Mutex;
|
||||
use h2::server;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, warn};
|
||||
use trust_dns_proto::rr::Record;
|
||||
|
||||
@ -29,7 +29,7 @@ pub(crate) async fn h2_handler<T, I>(
|
||||
io: I,
|
||||
src_addr: SocketAddr,
|
||||
dns_hostname: Option<Arc<str>>,
|
||||
shutdown: Watch,
|
||||
shutdown: CancellationToken,
|
||||
) where
|
||||
T: RequestHandler,
|
||||
I: AsyncRead + AsyncWrite + Unpin,
|
||||
@ -59,7 +59,7 @@ pub(crate) async fn h2_handler<T, I>(
|
||||
return;
|
||||
}
|
||||
},
|
||||
_ = shutdown.clone().signaled() => {
|
||||
_ = shutdown.cancelled() => {
|
||||
// A graceful shutdown was initiated.
|
||||
return
|
||||
},
|
||||
|
@ -8,8 +8,8 @@
|
||||
use std::{io, net::SocketAddr, sync::Arc};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use drain::Watch;
|
||||
use futures_util::lock::Mutex;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, warn};
|
||||
use trust_dns_proto::{
|
||||
error::ProtoError,
|
||||
@ -31,7 +31,7 @@ pub(crate) async fn quic_handler<T>(
|
||||
mut quic_streams: QuicStreams,
|
||||
src_addr: SocketAddr,
|
||||
_dns_hostname: Option<Arc<str>>,
|
||||
shutdown: Watch,
|
||||
shutdown: CancellationToken,
|
||||
) -> Result<(), ProtoError>
|
||||
where
|
||||
T: RequestHandler,
|
||||
@ -52,7 +52,7 @@ where
|
||||
break;
|
||||
}
|
||||
},
|
||||
_ = shutdown.clone().signaled() => {
|
||||
_ = shutdown.cancelled() => {
|
||||
// A graceful shutdown was initiated.
|
||||
break;
|
||||
},
|
||||
|
@ -4,7 +4,6 @@
|
||||
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
|
||||
// http://opensource.org/licenses/MIT>, at your option. This file may not be
|
||||
// copied, modified, or distributed except according to those terms.
|
||||
use std::future::Future;
|
||||
use std::{
|
||||
io,
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
@ -12,11 +11,11 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use drain::{Signal, Watch};
|
||||
use futures_util::{FutureExt, StreamExt};
|
||||
#[cfg(feature = "dns-over-rustls")]
|
||||
use rustls::{Certificate, PrivateKey, ServerConfig};
|
||||
use tokio::{net, task::JoinSet};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, warn};
|
||||
use trust_dns_proto::{op::MessageType, rr::Record};
|
||||
|
||||
@ -42,19 +41,16 @@ use crate::{
|
||||
pub struct ServerFuture<T: RequestHandler> {
|
||||
handler: Arc<T>,
|
||||
join_set: JoinSet<Result<(), ProtoError>>,
|
||||
shutdown_signal: ShutdownSignal,
|
||||
shutdown_watch: Watch,
|
||||
shutdown_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl<T: RequestHandler> ServerFuture<T> {
|
||||
/// Creates a new ServerFuture with the specified Handler.
|
||||
pub fn new(handler: T) -> Self {
|
||||
let (signal, watch) = drain::channel();
|
||||
Self {
|
||||
handler: Arc::new(handler),
|
||||
join_set: JoinSet::new(),
|
||||
shutdown_signal: ShutdownSignal::new(signal),
|
||||
shutdown_watch: watch,
|
||||
shutdown_token: CancellationToken::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,17 +60,24 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
|
||||
// create the new UdpStream, the IP address isn't relevant, and ideally goes essentially no where.
|
||||
// the address used is acquired from the inbound queries
|
||||
let (stream, stream_handle) =
|
||||
let (mut stream, stream_handle) =
|
||||
UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into());
|
||||
let shutdown = self.shutdown_watch.clone();
|
||||
let mut stream = stream.take_until(Box::pin(shutdown.signaled()));
|
||||
let shutdown = self.shutdown_token.clone();
|
||||
let handler = self.handler.clone();
|
||||
|
||||
// this spawns a ForEach future which handles all the requests into a Handler.
|
||||
self.join_set.spawn({
|
||||
async move {
|
||||
let mut inner_join_set = JoinSet::new();
|
||||
while let Some(message) = stream.next().await {
|
||||
loop {
|
||||
let message = tokio::select! {
|
||||
message = stream.next() => match message {
|
||||
None => break,
|
||||
Some(message) => message,
|
||||
},
|
||||
_ = shutdown.cancelled() => break,
|
||||
};
|
||||
|
||||
let message = match message {
|
||||
Err(e) => {
|
||||
warn!("error receiving message on udp_socket: {}", e);
|
||||
@ -106,7 +109,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
reap_tasks(&mut inner_join_set);
|
||||
}
|
||||
|
||||
if stream.is_stopped() {
|
||||
if shutdown.is_cancelled() {
|
||||
Ok(())
|
||||
} else {
|
||||
// TODO: let's consider capturing all the initial configuration details so that the socket could be recreated...
|
||||
@ -140,7 +143,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
let handler = self.handler.clone();
|
||||
|
||||
// for each incoming request...
|
||||
let shutdown = self.shutdown_watch.clone();
|
||||
let shutdown = self.shutdown_token.clone();
|
||||
self.join_set.spawn(async move {
|
||||
let mut inner_join_set = JoinSet::new();
|
||||
loop {
|
||||
@ -152,7 +155,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
continue;
|
||||
},
|
||||
},
|
||||
_ = shutdown.clone().signaled() => {
|
||||
_ = shutdown.cancelled() => {
|
||||
// A graceful shutdown was initiated. Break out of the loop.
|
||||
break;
|
||||
},
|
||||
@ -418,7 +421,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
let tls_acceptor = TlsAcceptor::from(tls_config);
|
||||
|
||||
// for each incoming request...
|
||||
let shutdown = self.shutdown_watch.clone();
|
||||
let shutdown = self.shutdown_token.clone();
|
||||
self.join_set.spawn(async move {
|
||||
let mut inner_join_set = JoinSet::new();
|
||||
loop {
|
||||
@ -430,7 +433,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
continue;
|
||||
},
|
||||
},
|
||||
_ = shutdown.clone().signaled() => {
|
||||
_ = shutdown.cancelled() => {
|
||||
// A graceful shutdown was initiated. Break out of the loop.
|
||||
break;
|
||||
},
|
||||
@ -610,7 +613,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
|
||||
// for each incoming request...
|
||||
let dns_hostname = dns_hostname;
|
||||
let shutdown = self.shutdown_watch.clone();
|
||||
let shutdown = self.shutdown_token.clone();
|
||||
self.join_set.spawn(async move {
|
||||
let mut inner_join_set = JoinSet::new();
|
||||
let dns_hostname = dns_hostname;
|
||||
@ -624,7 +627,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
continue;
|
||||
},
|
||||
},
|
||||
_ = shutdown.clone().signaled() => {
|
||||
_ = shutdown.cancelled() => {
|
||||
// A graceful shutdown was initiated. Break out of the loop.
|
||||
break;
|
||||
},
|
||||
@ -711,7 +714,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
|
||||
// for each incoming request...
|
||||
let dns_hostname = dns_hostname;
|
||||
let shutdown = self.shutdown_watch.clone();
|
||||
let shutdown = self.shutdown_token.clone();
|
||||
self.join_set.spawn(async move {
|
||||
let mut inner_join_set = JoinSet::new();
|
||||
let dns_hostname = dns_hostname;
|
||||
@ -726,7 +729,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
continue;
|
||||
}
|
||||
},
|
||||
_ = shutdown.clone().signaled() => {
|
||||
_ = shutdown.cancelled() => {
|
||||
// A graceful shutdown was initiated. Break out of the loop.
|
||||
break;
|
||||
},
|
||||
@ -768,58 +771,25 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns a signal used for initiating a graceful shutdown of the server and the future
|
||||
/// used for awaiting completion of the shutdown.
|
||||
///
|
||||
/// This allows the application to have separate code paths that are responsible for
|
||||
/// triggering shutdown and awaiting application completion.
|
||||
pub fn graceful(self) -> (ShutdownSignal, impl Future<Output = Result<(), ProtoError>>) {
|
||||
let signal = self.shutdown_signal;
|
||||
let join_set = self.join_set;
|
||||
(signal, block_until_done(join_set))
|
||||
}
|
||||
|
||||
/// Triggers a graceful shutdown the server. All background tasks will stop accepting
|
||||
/// new connections and the returned future will complete once all tasks have terminated.
|
||||
///
|
||||
/// This is equivalent to calling [Self::graceful], then triggering the graceful
|
||||
/// shutdown (via [ShutdownSignal::shutdown]) and awaiting completion of the server.
|
||||
pub async fn shutdown_gracefully(self) -> Result<(), ProtoError> {
|
||||
let (signal, fut) = self.graceful();
|
||||
|
||||
// Trigger shutdown.
|
||||
signal.shutdown().await;
|
||||
pub async fn shutdown_gracefully(&mut self) -> Result<(), ProtoError> {
|
||||
self.shutdown_token.cancel();
|
||||
|
||||
// Wait for the server to complete.
|
||||
fut.await
|
||||
block_until_done(&mut self.join_set).await
|
||||
}
|
||||
|
||||
/// This will run until all background tasks complete. If one or more tasks return an error,
|
||||
/// one will be chosen as the returned error for this future.
|
||||
pub async fn block_until_done(self) -> Result<(), ProtoError> {
|
||||
block_until_done(self.join_set).await
|
||||
pub async fn block_until_done(&mut self) -> Result<(), ProtoError> {
|
||||
block_until_done(&mut self.join_set).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Signals the start of a graceful shutdown.
|
||||
#[derive(Debug)]
|
||||
pub struct ShutdownSignal {
|
||||
signal: Signal,
|
||||
}
|
||||
|
||||
impl ShutdownSignal {
|
||||
fn new(signal: Signal) -> Self {
|
||||
Self { signal }
|
||||
}
|
||||
|
||||
/// Asynchronously sends the shutdown command to all server threads and
|
||||
/// waits for them to complete.
|
||||
pub async fn shutdown(self) {
|
||||
self.signal.drain().await
|
||||
}
|
||||
}
|
||||
|
||||
async fn block_until_done(mut join_set: JoinSet<Result<(), ProtoError>>) -> Result<(), ProtoError> {
|
||||
async fn block_until_done(
|
||||
join_set: &mut JoinSet<Result<(), ProtoError>>,
|
||||
) -> Result<(), ProtoError> {
|
||||
if join_set.is_empty() {
|
||||
warn!("block_until_done called with no pending tasks");
|
||||
return Ok(());
|
||||
|
Loading…
Reference in New Issue
Block a user