Adding graceful shutdown to server.

Fixes #1976.
This commit is contained in:
Nathan Mittler
2023-06-25 07:39:33 -07:00
committed by Benjamin Fry
parent 9e64710324
commit 578ce5a497
8 changed files with 601 additions and 305 deletions

10
Cargo.lock generated
View File

@@ -375,6 +375,15 @@ version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23d8666cb01533c39dde32bcbab8e227b4ed6679b2c925eba05feabea39508fb"
[[package]]
name = "drain"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f1a0abf3fcefad9b4dd0e414207a7408e12b68414a01e6bb19b897d5bd7632d"
dependencies = [
"tokio",
]
[[package]]
name = "encode_unicode"
version = "0.3.6"
@@ -1963,6 +1972,7 @@ dependencies = [
"async-trait",
"bytes",
"cfg-if",
"drain",
"enum-as-inner",
"futures-executor",
"futures-util",

View File

@@ -80,6 +80,7 @@ cfg-if = "1"
clap = { version = "4.0", default-features = false }
console = "0.15.0"
data-encoding = "2.2.0"
drain = "0.1.1"
enum-as-inner = "0.6"
idna = "0.4.0"
ipconfig = "0.3.0"

View File

@@ -74,6 +74,7 @@ path = "src/lib.rs"
async-trait.workspace = true
bytes.workspace = true
cfg-if.workspace = true
drain.workspace = true
enum-as-inner.workspace = true
futures-executor = { workspace = true, default-features = false, features = ["std"] }
futures-util = { workspace = true, default-features = false, features = ["std"] }
@@ -86,7 +87,7 @@ serde = { workspace = true, features = ["derive"] }
thiserror.workspace = true
time.workspace = true
tracing.workspace = true
tokio = { workspace = true, features = ["net", "sync"] }
tokio = { workspace = true, features = ["macros", "net", "sync"] }
tokio-openssl = { workspace = true, optional = true }
tokio-rustls = { workspace = true, optional = true }
toml.workspace = true

View File

@@ -8,6 +8,7 @@
use std::{io, net::SocketAddr, sync::Arc};
use bytes::{Bytes, BytesMut};
use drain::Watch;
use futures_util::lock::Mutex;
use h2::server;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -28,6 +29,7 @@ pub(crate) async fn h2_handler<T, I>(
io: I,
src_addr: SocketAddr,
dns_hostname: Option<Arc<str>>,
shutdown: Watch,
) where
T: RequestHandler,
I: AsyncRead + AsyncWrite + Unpin,
@@ -45,13 +47,22 @@ pub(crate) async fn h2_handler<T, I>(
// Accept all inbound HTTP/2.0 streams sent over the
// connection.
while let Some(next_request) = h2.accept().await {
let (request, respond) = match next_request {
Ok(next_request) => next_request,
Err(err) => {
warn!("error accepting request {}: {}", src_addr, err);
return;
}
loop {
let (request, respond) = tokio::select! {
result = h2.accept() => match result {
Some(Ok(next_request)) => next_request,
Some(Err(err)) => {
warn!("error accepting request {}: {}", src_addr, err);
return;
}
None => {
return;
}
},
_ = shutdown.clone().signaled() => {
// A graceful shutdown was initiated.
return
},
};
debug!("Received request: {:#?}", request);
@@ -80,7 +91,7 @@ async fn handle_request<T>(
}
#[derive(Clone)]
struct HttpsResponseHandle(Arc<Mutex<::h2::server::SendResponse<Bytes>>>);
struct HttpsResponseHandle(Arc<Mutex<server::SendResponse<Bytes>>>);
#[async_trait::async_trait]
impl ResponseHandler for HttpsResponseHandle {

View File

@@ -8,6 +8,7 @@
use std::{io, net::SocketAddr, sync::Arc};
use bytes::{Bytes, BytesMut};
use drain::Watch;
use futures_util::lock::Mutex;
use tracing::{debug, warn};
use trust_dns_proto::{
@@ -30,6 +31,7 @@ pub(crate) async fn quic_handler<T>(
mut quic_streams: QuicStreams,
src_addr: SocketAddr,
_dns_hostname: Option<Arc<str>>,
shutdown: Watch,
) -> Result<(), ProtoError>
where
T: RequestHandler,
@@ -38,13 +40,22 @@ where
let mut max_requests = 100u32;
// Accept all inbound quic streams sent over the connection.
while let Some(next_request) = quic_streams.next().await {
let mut request_stream = match next_request {
Ok(next_request) => next_request,
Err(err) => {
warn!("error accepting request {}: {}", src_addr, err);
return Err(err);
}
loop {
let mut request_stream = tokio::select! {
result = quic_streams.next() => match result {
Some(Ok(next_request)) => next_request,
Some(Err(err)) => {
warn!("error accepting request {}: {}", src_addr, err);
return Err(err);
}
None => {
break;
}
},
_ = shutdown.clone().signaled() => {
// A graceful shutdown was initiated.
break;
},
};
let request = request_stream.receive_bytes().await?;

View File

@@ -4,6 +4,7 @@
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
use std::future::Future;
use std::{
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
@@ -11,6 +12,7 @@ use std::{
time::Duration,
};
use drain::{Signal, Watch};
use futures_util::{FutureExt, StreamExt};
#[cfg(feature = "dns-over-rustls")]
use rustls::{Certificate, PrivateKey};
@@ -40,14 +42,19 @@ use crate::{
pub struct ServerFuture<T: RequestHandler> {
handler: Arc<T>,
join_set: JoinSet<Result<(), ProtoError>>,
shutdown_signal: ShutdownSignal,
shutdown_watch: Watch,
}
impl<T: RequestHandler> ServerFuture<T> {
/// Creates a new ServerFuture with the specified Handler.
pub fn new(handler: T) -> Self {
let (signal, watch) = drain::channel();
Self {
handler: Arc::new(handler),
join_set: JoinSet::new(),
shutdown_signal: ShutdownSignal::new(signal),
shutdown_watch: watch,
}
}
@@ -57,16 +64,17 @@ impl<T: RequestHandler> ServerFuture<T> {
// create the new UdpStream, the IP address isn't relevant, and ideally goes essentially no where.
// the address used is acquired from the inbound queries
let (mut buf_stream, stream_handle) =
let (stream, stream_handle) =
UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into());
//let request_stream = RequestStream::new(buf_stream, stream_handle);
let shutdown = self.shutdown_watch.clone();
let mut stream = stream.take_until(Box::pin(shutdown.signaled()));
let handler = self.handler.clone();
// this spawns a ForEach future which handles all the requests into a Handler.
self.join_set.spawn({
async move {
let mut inner_join_set = JoinSet::new();
while let Some(message) = buf_stream.next().await {
while let Some(message) = stream.next().await {
let message = match message {
Err(e) => {
warn!("error receiving message on udp_socket: {}", e);
@@ -92,15 +100,18 @@ impl<T: RequestHandler> ServerFuture<T> {
let stream_handle = stream_handle.with_remote_addr(src_addr);
inner_join_set.spawn(async move {
self::handle_raw_request(message, Protocol::Udp, handler, stream_handle)
.await;
handle_raw_request(message, Protocol::Udp, handler, stream_handle).await;
});
reap_tasks(&mut inner_join_set);
}
// TODO: let's consider capturing all the initial configuration details so that the socket could be recreated...
Err(ProtoError::from("unexpected close of UDP socket"))
if stream.is_stopped() {
Ok(())
} else {
// TODO: let's consider capturing all the initial configuration details so that the socket could be recreated...
Err(ProtoError::from("unexpected close of UDP socket"))
}
}
});
}
@@ -129,67 +140,72 @@ impl<T: RequestHandler> ServerFuture<T> {
let handler = self.handler.clone();
// for each incoming request...
self.join_set.spawn({
async move {
let mut inner_join_set = JoinSet::new();
loop {
let tcp_stream = listener.accept().await;
let (tcp_stream, src_addr) = match tcp_stream {
let shutdown = self.shutdown_watch.clone();
self.join_set.spawn(async move {
let mut inner_join_set = JoinSet::new();
loop {
let (tcp_stream, src_addr) = tokio::select! {
tcp_stream = listener.accept() => match tcp_stream {
Ok((t, s)) => (t, s),
Err(e) => {
debug!("error receiving TCP tcp_stream error: {}", e);
continue;
}
};
},
},
_ = shutdown.clone().signaled() => {
// A graceful shutdown was initiated. Break out of the loop.
break;
},
};
// verify that the src address is safe for responses
if let Err(e) = sanitize_src_address(src_addr) {
warn!(
"address can not be responded to {src_addr}: {e}",
src_addr = src_addr,
e = e
);
continue;
}
let handler = handler.clone();
// and spawn to the io_loop
inner_join_set.spawn(async move {
debug!("accepted request from: {}", src_addr);
// take the created stream...
let (buf_stream, stream_handle) =
TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr);
let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
//let request_stream = RequestStream::new(timeout_stream, stream_handle);
while let Some(message) = timeout_stream.next().await {
let message = match message {
Ok(message) => message,
Err(e) => {
debug!(
"error in TCP request_stream src: {} error: {}",
src_addr, e
);
// we're going to bail on this connection...
return;
}
};
// we don't spawn here to limit clients from getting too many resources
self::handle_raw_request(
message,
Protocol::Tcp,
handler.clone(),
stream_handle.clone(),
)
.await;
}
});
reap_tasks(&mut inner_join_set);
// verify that the src address is safe for responses
if let Err(e) = sanitize_src_address(src_addr) {
warn!(
"address can not be responded to {src_addr}: {e}",
src_addr = src_addr,
e = e
);
continue;
}
let handler = handler.clone();
// and spawn to the io_loop
inner_join_set.spawn(async move {
debug!("accepted request from: {}", src_addr);
// take the created stream...
let (buf_stream, stream_handle) =
TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr);
let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
while let Some(message) = timeout_stream.next().await {
let message = match message {
Ok(message) => message,
Err(e) => {
debug!(
"error in TCP request_stream src: {} error: {}",
src_addr, e
);
// we're going to bail on this connection...
return;
}
};
// we don't spawn here to limit clients from getting too many resources
handle_raw_request(
message,
Protocol::Tcp,
handler.clone(),
stream_handle.clone(),
)
.await;
}
});
reap_tasks(&mut inner_join_set);
}
Ok(())
});
}
@@ -251,84 +267,90 @@ impl<T: RequestHandler> ServerFuture<T> {
let tls_acceptor = Box::pin(tls_server::new_acceptor(cert, chain, key)?);
// for each incoming request...
self.join_set.spawn({
async move {
let mut inner_join_set = JoinSet::new();
loop {
let tcp_stream = listener.accept().await;
let (tcp_stream, src_addr) = match tcp_stream {
let shutdown = self.shutdown_watch.clone();
self.join_set.spawn(async move {
let mut inner_join_set = JoinSet::new();
loop {
let (tcp_stream, src_addr) = tokio::select! {
tcp_stream = listener.accept() => match tcp_stream {
Ok((t, s)) => (t, s),
Err(e) => {
debug!("error receiving TLS tcp_stream error: {}", e);
continue;
},
},
_ = shutdown.clone().signaled() => {
// A graceful shutdown was initiated. Break out of the loop.
break;
},
};
// verify that the src address is safe for responses
if let Err(e) = sanitize_src_address(src_addr) {
warn!(
"address can not be responded to {src_addr}: {e}",
src_addr = src_addr,
e = e
);
continue;
}
let handler = handler.clone();
let tls_acceptor = tls_acceptor.clone();
// kick out to a different task immediately, let them do the TLS handshake
inner_join_set.spawn(async move {
debug!("starting TLS request from: {}", src_addr);
// perform the TLS
let mut tls_stream = match Ssl::new(tls_acceptor.context())
.and_then(|ssl| TokioSslStream::new(ssl, tcp_stream))
{
Ok(tls_stream) => tls_stream,
Err(e) => {
debug!("tls handshake src: {} error: {}", src_addr, e);
return ();
}
};
// verify that the src address is safe for responses
if let Err(e) = sanitize_src_address(src_addr) {
warn!(
"address can not be responded to {src_addr}: {e}",
src_addr = src_addr,
e = e
);
continue;
}
let handler = handler.clone();
let tls_acceptor = tls_acceptor.clone();
// kick out to a different task immediately, let them do the TLS handshake
inner_join_set.spawn(async move {
debug!("starting TLS request from: {}", src_addr);
// perform the TLS
let mut tls_stream = match Ssl::new(tls_acceptor.context())
.and_then(|ssl| TokioSslStream::new(ssl, tcp_stream))
{
Ok(tls_stream) => tls_stream,
Err(e) => {
debug!("tls handshake src: {} error: {}", src_addr, e);
return ();
}
};
match Pin::new(&mut tls_stream).accept().await {
Ok(()) => {}
Err(e) => {
debug!("tls handshake src: {} error: {}", src_addr, e);
return ();
}
};
debug!("accepted TLS request from: {}", src_addr);
let (buf_stream, stream_handle) =
TlsStream::from_stream(AsyncIoTokioAsStd(tls_stream), src_addr);
let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
while let Some(message) = timeout_stream.next().await {
let message = match message {
Ok(message) => message,
Err(e) => {
debug!(
"error in TLS request_stream src: {:?} error: {}",
src_addr, e
);
// kill this connection
return ();
}
};
self::handle_raw_request(
message,
Protocol::Tls,
handler.clone(),
stream_handle.clone(),
)
.await;
match Pin::new(&mut tls_stream).accept().await {
Ok(()) => {}
Err(e) => {
debug!("tls handshake src: {} error: {}", src_addr, e);
return ();
}
});
};
debug!("accepted TLS request from: {}", src_addr);
let (buf_stream, stream_handle) =
TlsStream::from_stream(AsyncIoTokioAsStd(tls_stream), src_addr);
let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
while let Some(message) = timeout_stream.next().await {
let message = match message {
Ok(message) => message,
Err(e) => {
debug!(
"error in TLS request_stream src: {:?} error: {}",
src_addr, e
);
reap_tasks(&mut inner_join_set);
}
// kill this connection
return ();
}
};
self::handle_raw_request(
message,
Protocol::Tls,
handler.clone(),
stream_handle.clone(),
)
.await;
}
});
reap_tasks(&mut inner_join_set);
}
Ok(())
});
Ok(())
@@ -403,76 +425,82 @@ impl<T: RequestHandler> ServerFuture<T> {
let tls_acceptor = TlsAcceptor::from(Arc::new(tls_acceptor));
// for each incoming request...
self.join_set.spawn({
async move {
let mut inner_join_set = JoinSet::new();
loop {
let tcp_stream = listener.accept().await;
let (tcp_stream, src_addr) = match tcp_stream {
let shutdown = self.shutdown_watch.clone();
self.join_set.spawn(async move {
let mut inner_join_set = JoinSet::new();
loop {
let (tcp_stream, src_addr) = tokio::select! {
tcp_stream = listener.accept() => match tcp_stream {
Ok((t, s)) => (t, s),
Err(e) => {
debug!("error receiving TLS tcp_stream error: {}", e);
continue;
},
},
_ = shutdown.clone().signaled() => {
// A graceful shutdown was initiated. Break out of the loop.
break;
},
};
// verify that the src address is safe for responses
if let Err(e) = sanitize_src_address(src_addr) {
warn!(
"address can not be responded to {src_addr}: {e}",
src_addr = src_addr,
e = e
);
continue;
}
let handler = handler.clone();
let tls_acceptor = tls_acceptor.clone();
// kick out to a different task immediately, let them do the TLS handshake
inner_join_set.spawn(async move {
debug!("starting TLS request from: {}", src_addr);
// perform the TLS
let tls_stream = tls_acceptor.accept(tcp_stream).await;
let tls_stream = match tls_stream {
Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream),
Err(e) => {
debug!("tls handshake src: {} error: {}", src_addr, e);
return;
}
};
// verify that the src address is safe for responses
if let Err(e) = sanitize_src_address(src_addr) {
warn!(
"address can not be responded to {src_addr}: {e}",
src_addr = src_addr,
e = e
);
continue;
}
let handler = handler.clone();
let tls_acceptor = tls_acceptor.clone();
// kick out to a different task immediately, let them do the TLS handshake
inner_join_set.spawn(async move {
debug!("starting TLS request from: {}", src_addr);
// perform the TLS
let tls_stream = tls_acceptor.accept(tcp_stream).await;
let tls_stream = match tls_stream {
Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream),
debug!("accepted TLS request from: {}", src_addr);
let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
while let Some(message) = timeout_stream.next().await {
let message = match message {
Ok(message) => message,
Err(e) => {
debug!("tls handshake src: {} error: {}", src_addr, e);
debug!(
"error in TLS request_stream src: {:?} error: {}",
src_addr, e
);
// kill this connection
return;
}
};
debug!("accepted TLS request from: {}", src_addr);
let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
while let Some(message) = timeout_stream.next().await {
let message = match message {
Ok(message) => message,
Err(e) => {
debug!(
"error in TLS request_stream src: {:?} error: {}",
src_addr, e
);
// kill this connection
return;
}
};
handle_raw_request(
message,
Protocol::Tls,
handler.clone(),
stream_handle.clone(),
)
.await;
}
});
self::handle_raw_request(
message,
Protocol::Tls,
handler.clone(),
stream_handle.clone(),
)
.await;
}
});
reap_tasks(&mut inner_join_set);
}
reap_tasks(&mut inner_join_set);
}
Ok(())
});
Ok(())
@@ -555,52 +583,66 @@ impl<T: RequestHandler> ServerFuture<T> {
// for each incoming request...
let dns_hostname = dns_hostname;
self.join_set.spawn({
async move {
let mut inner_join_set = JoinSet::new();
let dns_hostname = dns_hostname;
loop {
let tcp_stream = listener.accept().await;
let (tcp_stream, src_addr) = match tcp_stream {
let shutdown = self.shutdown_watch.clone();
self.join_set.spawn(async move {
let mut inner_join_set = JoinSet::new();
let dns_hostname = dns_hostname;
loop {
let shutdown = shutdown.clone();
let (tcp_stream, src_addr) = tokio::select! {
tcp_stream = listener.accept() => match tcp_stream {
Ok((t, s)) => (t, s),
Err(e) => {
debug!("error receiving HTTPS tcp_stream error: {e}");
debug!("error receiving HTTPS tcp_stream error: {}", e);
continue;
},
},
_ = shutdown.clone().signaled() => {
// A graceful shutdown was initiated. Break out of the loop.
break;
},
};
// verify that the src address is safe for responses
if let Err(e) = sanitize_src_address(src_addr) {
warn!("address can not be responded to {src_addr}: {e}");
continue;
}
let handler = handler.clone();
let tls_acceptor = tls_acceptor.clone();
let dns_hostname = dns_hostname.clone();
inner_join_set.spawn(async move {
debug!("starting HTTPS request from: {src_addr}");
// TODO: need to consider timeout of total connect...
// take the created stream...
let tls_stream = tls_acceptor.accept(tcp_stream).await;
let tls_stream = match tls_stream {
Ok(tls_stream) => tls_stream,
Err(e) => {
debug!("https handshake src: {src_addr} error: {e}");
return;
}
};
debug!("accepted HTTPS request from: {src_addr}");
// verify that the src address is safe for responses
if let Err(e) = sanitize_src_address(src_addr) {
warn!("address can not be responded to {src_addr}: {e}");
continue;
}
h2_handler(
handler,
tls_stream,
src_addr,
dns_hostname,
shutdown.clone(),
)
.await;
});
let handler = handler.clone();
let tls_acceptor = tls_acceptor.clone();
let dns_hostname = dns_hostname.clone();
inner_join_set.spawn(async move {
debug!("starting HTTPS request from: {src_addr}");
// TODO: need to consider timeout of total connect...
// take the created stream...
let tls_stream = tls_acceptor.accept(tcp_stream).await;
let tls_stream = match tls_stream {
Ok(tls_stream) => tls_stream,
Err(e) => {
debug!("https handshake src: {src_addr} error: {e}");
return;
}
};
debug!("accepted HTTPS request from: {src_addr}");
h2_handler(handler, tls_stream, src_addr, dns_hostname).await;
});
reap_tasks(&mut inner_join_set);
}
reap_tasks(&mut inner_join_set);
}
Ok(())
});
Ok(())
@@ -642,66 +684,137 @@ impl<T: RequestHandler> ServerFuture<T> {
// for each incoming request...
let dns_hostname = dns_hostname;
self.join_set.spawn({
async move {
let mut inner_join_set = JoinSet::new();
let dns_hostname = dns_hostname;
loop {
let (streams, src_addr) = match server.next().await {
let shutdown = self.shutdown_watch.clone();
self.join_set.spawn(async move {
let mut inner_join_set = JoinSet::new();
let dns_hostname = dns_hostname;
loop {
let shutdown = shutdown.clone();
let (streams, src_addr) = tokio::select! {
result = server.next() => match result {
Ok(Some(c)) => c,
Ok(None) => continue,
Err(e) => {
debug!("error receiving quic connection: {e}");
continue;
}
};
},
_ = shutdown.clone().signaled() => {
// A graceful shutdown was initiated. Break out of the loop.
break;
},
};
// verify that the src address is safe for responses
// TODO: we're relying the quinn library to actually validate responses before we get here, but this check is still worth doing
if let Err(e) = sanitize_src_address(src_addr) {
warn!(
"address can not be responded to {src_addr}: {e}",
src_addr = src_addr,
e = e
);
continue;
}
let handler = handler.clone();
let dns_hostname = dns_hostname.clone();
inner_join_set.spawn(async move {
debug!("starting quic stream request from: {src_addr}");
// TODO: need to consider timeout of total connect...
let result = quic_handler(handler, streams, src_addr, dns_hostname).await;
if let Err(e) = result {
warn!("quic stream processing failed from {src_addr}: {e}")
}
});
reap_tasks(&mut inner_join_set);
// verify that the src address is safe for responses
// TODO: we're relying the quinn library to actually validate responses before we get here, but this check is still worth doing
if let Err(e) = sanitize_src_address(src_addr) {
warn!(
"address can not be responded to {src_addr}: {e}",
src_addr = src_addr,
e = e
);
continue;
}
let handler = handler.clone();
let dns_hostname = dns_hostname.clone();
inner_join_set.spawn(async move {
debug!("starting quic stream request from: {src_addr}");
// TODO: need to consider timeout of total connect...
let result =
quic_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
.await;
if let Err(e) = result {
warn!("quic stream processing failed from {src_addr}: {e}")
}
});
reap_tasks(&mut inner_join_set);
}
Ok(())
});
Ok(())
}
/// This will run until a background task of the trust_dns_server ends.
pub async fn block_until_done(mut self) -> Result<(), ProtoError> {
let result = self.join_set.join_next().await;
/// Returns a signal used for initiating a graceful shutdown of the server and the future
/// used for awaiting completion of the shutdown.
///
/// This allows the application to have separate code paths that are responsible for
/// triggering shutdown and awaiting application completion.
pub fn graceful(self) -> (ShutdownSignal, impl Future<Output = Result<(), ProtoError>>) {
let signal = self.shutdown_signal;
let join_set = self.join_set;
(signal, block_until_done(join_set))
}
match result {
None => {
tracing::warn!("block_until_done called with no pending tasks");
Ok(())
/// Triggers a graceful shutdown the server. All background tasks will stop accepting
/// new connections and the returned future will complete once all tasks have terminated.
///
/// This is equivalent to calling [Self::graceful], then triggering the graceful
/// shutdown (via [ShutdownSignal::shutdown]) and awaiting completion of the server.
pub async fn shutdown_gracefully(self) -> Result<(), ProtoError> {
let (signal, fut) = self.graceful();
// Trigger shutdown.
signal.shutdown().await;
// Wait for the server to complete.
fut.await
}
/// This will run until all background tasks complete. If one or more tasks return an error,
/// one will be chosen as the returned error for this future.
pub async fn block_until_done(self) -> Result<(), ProtoError> {
block_until_done(self.join_set).await
}
}
/// Signals the start of a graceful shutdown.
#[derive(Debug)]
pub struct ShutdownSignal {
signal: Signal,
}
impl ShutdownSignal {
fn new(signal: Signal) -> Self {
Self { signal }
}
/// Asynchronously sends the shutdown command to all server threads and
/// waits for them to complete.
pub async fn shutdown(self) {
self.signal.drain().await
}
}
async fn block_until_done(mut join_set: JoinSet<Result<(), ProtoError>>) -> Result<(), ProtoError> {
if join_set.is_empty() {
warn!("block_until_done called with no pending tasks");
return Ok(());
}
// Now wait for all of the tasks to complete.
let mut out = Ok(());
while let Some(join_result) = join_set.join_next().await {
match join_result {
Ok(result) => {
match result {
Ok(_) => (),
Err(e) => {
// Save the last error.
out = Err(e);
}
}
}
Some(Ok(x)) => x,
Some(Err(e)) => Err(ProtoError::from(format!("Internal error in spawn: {e}"))),
Err(e) => return Err(ProtoError::from(format!("Internal error in spawn: {e}"))),
}
}
out
}
/// Reap finished tasks from a `JoinSet`, without awaiting or blocking.
@@ -718,7 +831,7 @@ pub(crate) async fn handle_raw_request<T: RequestHandler>(
let src_addr = message.addr();
let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol);
self::handle_request(
handle_request(
message.bytes(),
src_addr,
protocol,
@@ -935,35 +1048,44 @@ fn sanitize_src_address(src: SocketAddr) -> Result<(), String> {
#[cfg(test)]
mod tests {
use super::*;
use crate::authority::Catalog;
use futures_util::future;
use std::net::{Ipv4Addr, SocketAddr, UdpSocket};
#[cfg(feature = "dns-over-rustls")]
use rustls::{Certificate, PrivateKey};
use std::net::SocketAddr;
use tokio::net::{TcpListener, UdpSocket};
use tokio::time::timeout;
#[test]
fn cleanup_after_shutdown() {
let runtime = tokio::runtime::Runtime::new().unwrap();
let random_port = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0))
.unwrap()
.local_addr()
.unwrap()
.port();
let bind_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), random_port);
#[tokio::test]
async fn abort() {
let endpoints = Endpoints::new().await;
let (server_future, abort_handle) = future::abortable(async move {
let endpoints2 = endpoints.clone();
let (abortable, abort_handle) = future::abortable(async move {
let mut server_future = ServerFuture::new(Catalog::new());
let udp_socket = tokio::net::UdpSocket::bind(bind_addr).await.unwrap();
server_future.register_socket(udp_socket);
endpoints2.register(&mut server_future).await;
server_future.block_until_done().await
});
abort_handle.abort();
runtime.block_on(async move {
let _ = server_future.await;
});
abortable.await.expect_err("expected abort");
UdpSocket::bind(bind_addr).unwrap();
endpoints.rebind_all().await;
}
#[tokio::test]
async fn graceful_shutdown() {
let mut server_future = ServerFuture::new(Catalog::new());
let endpoints = Endpoints::new().await;
endpoints.register(&mut server_future).await;
timeout(Duration::from_secs(2), server_future.shutdown_gracefully())
.await
.expect("timed out waiting for the server to complete")
.expect("error while awaiting tasks");
endpoints.rebind_all().await;
}
#[test]
@@ -989,4 +1111,137 @@ mod tests {
sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0))).is_err()
);
}
#[derive(Clone)]
struct Endpoints {
udp_addr: SocketAddr,
udp_std_addr: SocketAddr,
tcp_addr: SocketAddr,
tcp_std_addr: SocketAddr,
#[cfg(feature = "dns-over-rustls")]
rustls_addr: SocketAddr,
#[cfg(feature = "dns-over-https-rustls")]
https_rustls_addr: SocketAddr,
#[cfg(feature = "dns-over-quic")]
quic_addr: SocketAddr,
}
impl Endpoints {
async fn new() -> Self {
let udp = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let udp_std = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
let tcp_std = TcpListener::bind("127.0.0.1:0").await.unwrap();
#[cfg(feature = "dns-over-rustls")]
let rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
#[cfg(feature = "dns-over-https-rustls")]
let https_rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
#[cfg(feature = "dns-over-quic")]
let quic = UdpSocket::bind("127.0.0.1:0").await.unwrap();
Self {
udp_addr: udp.local_addr().unwrap(),
udp_std_addr: udp_std.local_addr().unwrap(),
tcp_addr: tcp.local_addr().unwrap(),
tcp_std_addr: tcp_std.local_addr().unwrap(),
#[cfg(feature = "dns-over-rustls")]
rustls_addr: rustls.local_addr().unwrap(),
#[cfg(feature = "dns-over-https-rustls")]
https_rustls_addr: https_rustls.local_addr().unwrap(),
#[cfg(feature = "dns-over-quic")]
quic_addr: quic.local_addr().unwrap(),
}
}
async fn register<T: RequestHandler>(&self, server: &mut ServerFuture<T>) {
server.register_socket(UdpSocket::bind(self.udp_addr).await.unwrap());
server
.register_socket_std(std::net::UdpSocket::bind(self.udp_std_addr).unwrap())
.unwrap();
server.register_listener(
TcpListener::bind(self.tcp_addr).await.unwrap(),
Duration::from_secs(1),
);
server
.register_listener_std(
std::net::TcpListener::bind(self.tcp_std_addr).unwrap(),
Duration::from_secs(1),
)
.unwrap();
#[cfg(feature = "dns-over-rustls")]
{
let cert_key = rustls_cert_key();
server
.register_tls_listener(
TcpListener::bind(self.rustls_addr).await.unwrap(),
Duration::from_secs(30),
cert_key,
)
.unwrap();
}
#[cfg(feature = "dns-over-https-rustls")]
{
let cert_key = rustls_cert_key();
server
.register_https_listener(
TcpListener::bind(self.https_rustls_addr).await.unwrap(),
Duration::from_secs(1),
cert_key,
None,
)
.unwrap();
}
#[cfg(feature = "dns-over-quic")]
{
let cert_key = rustls_cert_key();
server
.register_quic_listener(
UdpSocket::bind(self.quic_addr).await.unwrap(),
Duration::from_secs(1),
cert_key,
None,
)
.unwrap();
}
}
async fn rebind_all(&self) {
UdpSocket::bind(self.udp_addr).await.unwrap();
UdpSocket::bind(self.udp_std_addr).await.unwrap();
TcpListener::bind(self.tcp_addr).await.unwrap();
TcpListener::bind(self.tcp_std_addr).await.unwrap();
#[cfg(feature = "dns-over-rustls")]
TcpListener::bind(self.rustls_addr).await.unwrap();
#[cfg(feature = "dns-over-https-rustls")]
TcpListener::bind(self.https_rustls_addr).await.unwrap();
#[cfg(feature = "dns-over-quic")]
UdpSocket::bind(self.quic_addr).await.unwrap();
}
}
#[cfg(feature = "dns-over-rustls")]
fn rustls_cert_key() -> (Vec<Certificate>, PrivateKey) {
use std::env;
use std::path::Path;
use trust_dns_proto::rustls::tls_server;
let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());
let cert = tls_server::read_cert(Path::new(&format!(
"{}/tests/test-data/cert.pem",
server_path
)))
.map_err(|e| format!("error reading cert: {e}"))
.unwrap();
let key = tls_server::read_key_from_pem(Path::new(&format!(
"{}/tests/test-data/cert.key",
server_path
)))
.unwrap();
(cert, key)
}
}

View File

@@ -337,8 +337,9 @@ where
.query(&name, DNSClass::IN, RecordType::A)
.expect("error querying");
assert!(
response.response_code() == ResponseCode::NoError,
assert_eq!(
response.response_code(),
ResponseCode::NoError,
"got an error: {:?}",
response.response_code()
);
@@ -377,6 +378,7 @@ fn server_thread_udp(io_loop: Runtime, udp_socket: UdpSocket, server_continue: A
io_loop.block_on(future::lazy(|_| tokio::time::sleep(Duration::from_millis(10))).flatten());
}
_ = io_loop.block_on(server.shutdown_gracefully());
drop(io_loop);
}
@@ -394,6 +396,8 @@ fn server_thread_tcp(
while server_continue.load(Ordering::Relaxed) {
io_loop.block_on(future::lazy(|_| tokio::time::sleep(Duration::from_millis(10))).flatten());
}
_ = io_loop.block_on(server.shutdown_gracefully());
}
// TODO: need a rustls option
@@ -425,4 +429,6 @@ fn server_thread_tls(
while server_continue.load(Ordering::Relaxed) {
io_loop.block_on(future::lazy(|_| tokio::time::sleep(Duration::from_millis(10))).flatten());
}
_ = io_loop.block_on(server.shutdown_gracefully());
}

View File

@@ -27,7 +27,6 @@ async fn test_truncation() {
// Create and start the server.
let mut server = ServerFuture::new(new_large_catalog(128));
server.register_socket(udp_socket);
tokio::spawn(server.block_until_done());
// Create the UDP client.
let stream = UdpClientStream::<UdpSocket>::new(nameserver);
@@ -58,6 +57,8 @@ async fn test_truncation() {
assert!(result.truncated());
assert_eq!(max_payload, result.max_payload());
server.shutdown_gracefully().await.unwrap();
}
// TODO: should we do this for all of the integration tests?