From cffc3fac2a7dc75e6a56b2a348033b5be903fdde Mon Sep 17 00:00:00 2001 From: 0xffffharry <95022881+0xffffharry@users.noreply.github.com> Date: Thu, 2 May 2024 12:10:58 +0000 Subject: [PATCH] Make H3ClientStream Clonable --- crates/proto/src/h3/h3_client_stream.rs | 97 +++++++++++++++++++++---- 1 file changed, 84 insertions(+), 13 deletions(-) diff --git a/crates/proto/src/h3/h3_client_stream.rs b/crates/proto/src/h3/h3_client_stream.rs index 0b6cd394..86c0c2a0 100644 --- a/crates/proto/src/h3/h3_client_stream.rs +++ b/crates/proto/src/h3/h3_client_stream.rs @@ -6,7 +6,7 @@ // copied, modified, or distributed except according to those terms. use std::fmt::{self, Display}; -use std::future::Future; +use std::future::{self, Future}; use std::net::SocketAddr; use std::pin::Pin; use std::str::FromStr; @@ -16,12 +16,13 @@ use std::task::{Context, Poll}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures_util::future::FutureExt; use futures_util::stream::Stream; -use h3::client::{Connection, SendRequest}; +use h3::client::SendRequest; use h3_quinn::OpenStreams; use http::header::{self, CONTENT_LENGTH}; use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig}; use rustls::ClientConfig as TlsClientConfig; -use tracing::debug; +use tokio::sync::mpsc; +use tracing::{debug, warn}; use crate::error::ProtoError; use crate::http::Version; @@ -34,13 +35,14 @@ use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream}; use super::ALPN_H3; /// A DNS client connection for DNS-over-HTTP/3 +#[derive(Clone)] #[must_use = "futures do nothing unless polled"] pub struct H3ClientStream { // Corresponds to the dns-name of the HTTP/3 server name_server_name: Arc, name_server: SocketAddr, - driver: Connection, send_request: SendRequest, + shutdown_tx: mpsc::Sender<()>, is_shutdown: bool, } @@ -264,19 +266,19 @@ impl DnsRequestSender for H3ClientStream { impl Stream for H3ClientStream { type Item = Result<(), ProtoError>; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.is_shutdown { return Poll::Ready(None); } // just checking if the connection is ok - match self.driver.poll_close(cx) { - Poll::Ready(Ok(())) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!( - "h3 stream errored: {e}", - ))))), + if self.shutdown_tx.is_closed() { + return Poll::Ready(Some(Err(ProtoError::from( + "h3 connection is already shutdown", + )))); } + + Poll::Ready(Some(Ok(()))) } } @@ -398,15 +400,31 @@ impl H3ClientStreamBuilder { }; let h3_connection = h3_quinn::Connection::new(quic_connection); - let (driver, send_request) = h3::client::new(h3_connection) + let (mut driver, send_request) = h3::client::new(h3_connection) .await .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?; + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); + + // TODO: hand this back for others to run rather than spawning here? + debug!("h3 connection is ready: {}", name_server); + tokio::spawn(async move { + tokio::select! { + res = future::poll_fn(|cx| driver.poll_close(cx)) => { + res.map_err(|e| warn!("h3 connection failed: {e}")) + } + _ = shutdown_rx.recv() => { + debug!("h3 connection is shutting down: {}", name_server); + Ok(()) + } + } + }); + Ok(H3ClientStream { name_server_name: Arc::from(dns_name), name_server, - driver, send_request, + shutdown_tx, is_shutdown: false, }) } @@ -453,6 +471,7 @@ mod tests { use rustls::KeyLogFile; use tokio::runtime::Runtime; + use tokio::task::JoinSet; use crate::op::{Message, Query, ResponseCode}; use crate::rr::rdata::{A, AAAA}; @@ -652,4 +671,56 @@ mod tests { &AAAA::new(0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c) ); } + + #[test] + #[allow(clippy::print_stdout)] + fn test_h3_client_stream_clonable() { + // use google + let google = SocketAddr::from(([8, 8, 8, 8], 443)); + + let mut client_config = super::super::client_config_tls13().unwrap(); + client_config.key_log = Arc::new(KeyLogFile::new()); + + let mut h3_builder = H3ClientStream::builder(); + h3_builder.crypto_config(client_config); + let connect = h3_builder.build(google, "dns.google".to_string()); + + // tokio runtime stuff... + let runtime = Runtime::new().expect("could not start runtime"); + let h3 = runtime.block_on(connect).expect("h3 connect failed"); + + // prepare request + let mut request = Message::new(); + let query = Query::query( + Name::from_str("www.example.com.").unwrap(), + RecordType::AAAA, + ); + request.add_query(query); + let request = DnsRequest::new(request, DnsRequestOptions::default()); + + runtime.block_on(async move { + let mut join_set = JoinSet::new(); + + for i in 0..50 { + let mut h3 = h3.clone(); + let request = request.clone(); + + join_set.spawn(async move { + let start = std::time::Instant::now(); + h3.send_message(request) + .first_answer() + .await + .expect("send_message failed"); + println!("request[{i}] completed: {:?}", start.elapsed()); + }); + } + + let total = join_set.len(); + let mut idx = 0usize; + while join_set.join_next().await.is_some() { + println!("join_set completed {idx}/{total}"); + idx += 1; + } + }); + } }