Optimized shutdown_gracefully() (#2041)

This commit is contained in:
小明 2023-10-05 16:00:11 +08:00 committed by GitHub
parent 3268444793
commit 7c78e740ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 40 additions and 79 deletions

11
Cargo.lock generated
View File

@ -463,15 +463,6 @@ version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946"
[[package]]
name = "drain"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f1a0abf3fcefad9b4dd0e414207a7408e12b68414a01e6bb19b897d5bd7632d"
dependencies = [
"tokio",
]
[[package]]
name = "encode_unicode"
version = "0.3.6"
@ -2031,7 +2022,6 @@ dependencies = [
"basic-toml",
"bytes",
"cfg-if",
"drain",
"enum-as-inner",
"futures-executor",
"futures-util",
@ -2046,6 +2036,7 @@ dependencies = [
"tokio",
"tokio-openssl",
"tokio-rustls",
"tokio-util",
"tracing",
"tracing-subscriber",
"trust-dns-proto",

View File

@ -54,6 +54,7 @@ tokio = "1.21"
tokio-native-tls = "0.3.0"
tokio-openssl = "0.6.0"
tokio-rustls = "0.24.0"
tokio-util = "0.7.9"
parking_lot = "0.12"
@ -81,7 +82,6 @@ cfg-if = "1"
clap = { version = "4.0", default-features = false }
console = "0.15.0"
data-encoding = "2.2.0"
drain = "0.1.1"
enum-as-inner = "0.6"
idna = "0.4.0"
ipconfig = "0.3.0"

View File

@ -76,7 +76,6 @@ async-trait.workspace = true
basic-toml = { workspace = true, optional = true }
bytes.workspace = true
cfg-if.workspace = true
drain.workspace = true
enum-as-inner.workspace = true
futures-util = { workspace = true, default-features = false, features = ["std"] }
h2 = { workspace = true, features = ["stream"], optional = true }
@ -91,6 +90,7 @@ tracing.workspace = true
tokio = { workspace = true, features = ["macros", "net", "sync"] }
tokio-openssl = { workspace = true, optional = true }
tokio-rustls = { workspace = true, optional = true }
tokio-util.workspace = true
trust-dns-proto = { workspace = true, features = ["text-parsing", "tokio-runtime"] }
trust-dns-recursor = { workspace = true, features = ["serde-config"], optional = true }
trust-dns-resolver = { workspace = true, features = ["serde-config", "system-config", "tokio-runtime"], optional = true }

View File

@ -8,10 +8,10 @@
use std::{io, net::SocketAddr, sync::Arc};
use bytes::{Bytes, BytesMut};
use drain::Watch;
use futures_util::lock::Mutex;
use h2::server;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use trust_dns_proto::rr::Record;
@ -29,7 +29,7 @@ pub(crate) async fn h2_handler<T, I>(
io: I,
src_addr: SocketAddr,
dns_hostname: Option<Arc<str>>,
shutdown: Watch,
shutdown: CancellationToken,
) where
T: RequestHandler,
I: AsyncRead + AsyncWrite + Unpin,
@ -59,7 +59,7 @@ pub(crate) async fn h2_handler<T, I>(
return;
}
},
_ = shutdown.clone().signaled() => {
_ = shutdown.cancelled() => {
// A graceful shutdown was initiated.
return
},

View File

@ -8,8 +8,8 @@
use std::{io, net::SocketAddr, sync::Arc};
use bytes::{Bytes, BytesMut};
use drain::Watch;
use futures_util::lock::Mutex;
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use trust_dns_proto::{
error::ProtoError,
@ -31,7 +31,7 @@ pub(crate) async fn quic_handler<T>(
mut quic_streams: QuicStreams,
src_addr: SocketAddr,
_dns_hostname: Option<Arc<str>>,
shutdown: Watch,
shutdown: CancellationToken,
) -> Result<(), ProtoError>
where
T: RequestHandler,
@ -52,7 +52,7 @@ where
break;
}
},
_ = shutdown.clone().signaled() => {
_ = shutdown.cancelled() => {
// A graceful shutdown was initiated.
break;
},

View File

@ -4,7 +4,6 @@
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
use std::future::Future;
use std::{
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
@ -12,11 +11,11 @@ use std::{
time::Duration,
};
use drain::{Signal, Watch};
use futures_util::{FutureExt, StreamExt};
#[cfg(feature = "dns-over-rustls")]
use rustls::{Certificate, PrivateKey, ServerConfig};
use tokio::{net, task::JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use trust_dns_proto::{op::MessageType, rr::Record};
@ -42,19 +41,16 @@ use crate::{
pub struct ServerFuture<T: RequestHandler> {
handler: Arc<T>,
join_set: JoinSet<Result<(), ProtoError>>,
shutdown_signal: ShutdownSignal,
shutdown_watch: Watch,
shutdown_token: CancellationToken,
}
impl<T: RequestHandler> ServerFuture<T> {
/// Creates a new ServerFuture with the specified Handler.
pub fn new(handler: T) -> Self {
let (signal, watch) = drain::channel();
Self {
handler: Arc::new(handler),
join_set: JoinSet::new(),
shutdown_signal: ShutdownSignal::new(signal),
shutdown_watch: watch,
shutdown_token: CancellationToken::new(),
}
}
@ -64,17 +60,24 @@ impl<T: RequestHandler> ServerFuture<T> {
// create the new UdpStream, the IP address isn't relevant, and ideally goes essentially no where.
// the address used is acquired from the inbound queries
let (stream, stream_handle) =
let (mut stream, stream_handle) =
UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into());
let shutdown = self.shutdown_watch.clone();
let mut stream = stream.take_until(Box::pin(shutdown.signaled()));
let shutdown = self.shutdown_token.clone();
let handler = self.handler.clone();
// this spawns a ForEach future which handles all the requests into a Handler.
self.join_set.spawn({
async move {
let mut inner_join_set = JoinSet::new();
while let Some(message) = stream.next().await {
loop {
let message = tokio::select! {
message = stream.next() => match message {
None => break,
Some(message) => message,
},
_ = shutdown.cancelled() => break,
};
let message = match message {
Err(e) => {
warn!("error receiving message on udp_socket: {}", e);
@ -106,7 +109,7 @@ impl<T: RequestHandler> ServerFuture<T> {
reap_tasks(&mut inner_join_set);
}
if stream.is_stopped() {
if shutdown.is_cancelled() {
Ok(())
} else {
// TODO: let's consider capturing all the initial configuration details so that the socket could be recreated...
@ -140,7 +143,7 @@ impl<T: RequestHandler> ServerFuture<T> {
let handler = self.handler.clone();
// for each incoming request...
let shutdown = self.shutdown_watch.clone();
let shutdown = self.shutdown_token.clone();
self.join_set.spawn(async move {
let mut inner_join_set = JoinSet::new();
loop {
@ -152,7 +155,7 @@ impl<T: RequestHandler> ServerFuture<T> {
continue;
},
},
_ = shutdown.clone().signaled() => {
_ = shutdown.cancelled() => {
// A graceful shutdown was initiated. Break out of the loop.
break;
},
@ -418,7 +421,7 @@ impl<T: RequestHandler> ServerFuture<T> {
let tls_acceptor = TlsAcceptor::from(tls_config);
// for each incoming request...
let shutdown = self.shutdown_watch.clone();
let shutdown = self.shutdown_token.clone();
self.join_set.spawn(async move {
let mut inner_join_set = JoinSet::new();
loop {
@ -430,7 +433,7 @@ impl<T: RequestHandler> ServerFuture<T> {
continue;
},
},
_ = shutdown.clone().signaled() => {
_ = shutdown.cancelled() => {
// A graceful shutdown was initiated. Break out of the loop.
break;
},
@ -610,7 +613,7 @@ impl<T: RequestHandler> ServerFuture<T> {
// for each incoming request...
let dns_hostname = dns_hostname;
let shutdown = self.shutdown_watch.clone();
let shutdown = self.shutdown_token.clone();
self.join_set.spawn(async move {
let mut inner_join_set = JoinSet::new();
let dns_hostname = dns_hostname;
@ -624,7 +627,7 @@ impl<T: RequestHandler> ServerFuture<T> {
continue;
},
},
_ = shutdown.clone().signaled() => {
_ = shutdown.cancelled() => {
// A graceful shutdown was initiated. Break out of the loop.
break;
},
@ -711,7 +714,7 @@ impl<T: RequestHandler> ServerFuture<T> {
// for each incoming request...
let dns_hostname = dns_hostname;
let shutdown = self.shutdown_watch.clone();
let shutdown = self.shutdown_token.clone();
self.join_set.spawn(async move {
let mut inner_join_set = JoinSet::new();
let dns_hostname = dns_hostname;
@ -726,7 +729,7 @@ impl<T: RequestHandler> ServerFuture<T> {
continue;
}
},
_ = shutdown.clone().signaled() => {
_ = shutdown.cancelled() => {
// A graceful shutdown was initiated. Break out of the loop.
break;
},
@ -768,58 +771,25 @@ impl<T: RequestHandler> ServerFuture<T> {
Ok(())
}
/// Returns a signal used for initiating a graceful shutdown of the server and the future
/// used for awaiting completion of the shutdown.
///
/// This allows the application to have separate code paths that are responsible for
/// triggering shutdown and awaiting application completion.
pub fn graceful(self) -> (ShutdownSignal, impl Future<Output = Result<(), ProtoError>>) {
let signal = self.shutdown_signal;
let join_set = self.join_set;
(signal, block_until_done(join_set))
}
/// Triggers a graceful shutdown the server. All background tasks will stop accepting
/// new connections and the returned future will complete once all tasks have terminated.
///
/// This is equivalent to calling [Self::graceful], then triggering the graceful
/// shutdown (via [ShutdownSignal::shutdown]) and awaiting completion of the server.
pub async fn shutdown_gracefully(self) -> Result<(), ProtoError> {
let (signal, fut) = self.graceful();
// Trigger shutdown.
signal.shutdown().await;
pub async fn shutdown_gracefully(&mut self) -> Result<(), ProtoError> {
self.shutdown_token.cancel();
// Wait for the server to complete.
fut.await
block_until_done(&mut self.join_set).await
}
/// This will run until all background tasks complete. If one or more tasks return an error,
/// one will be chosen as the returned error for this future.
pub async fn block_until_done(self) -> Result<(), ProtoError> {
block_until_done(self.join_set).await
pub async fn block_until_done(&mut self) -> Result<(), ProtoError> {
block_until_done(&mut self.join_set).await
}
}
/// Signals the start of a graceful shutdown.
#[derive(Debug)]
pub struct ShutdownSignal {
signal: Signal,
}
impl ShutdownSignal {
fn new(signal: Signal) -> Self {
Self { signal }
}
/// Asynchronously sends the shutdown command to all server threads and
/// waits for them to complete.
pub async fn shutdown(self) {
self.signal.drain().await
}
}
async fn block_until_done(mut join_set: JoinSet<Result<(), ProtoError>>) -> Result<(), ProtoError> {
async fn block_until_done(
join_set: &mut JoinSet<Result<(), ProtoError>>,
) -> Result<(), ProtoError> {
if join_set.is_empty() {
warn!("block_until_done called with no pending tasks");
return Ok(());