Make H3ClientStream Clonable
This commit is contained in:
parent
c1f2e9b4de
commit
cffc3fac2a
@ -6,7 +6,7 @@
|
|||||||
// copied, modified, or distributed except according to those terms.
|
// copied, modified, or distributed except according to those terms.
|
||||||
|
|
||||||
use std::fmt::{self, Display};
|
use std::fmt::{self, Display};
|
||||||
use std::future::Future;
|
use std::future::{self, Future};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
@ -16,12 +16,13 @@ use std::task::{Context, Poll};
|
|||||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
use futures_util::future::FutureExt;
|
use futures_util::future::FutureExt;
|
||||||
use futures_util::stream::Stream;
|
use futures_util::stream::Stream;
|
||||||
use h3::client::{Connection, SendRequest};
|
use h3::client::SendRequest;
|
||||||
use h3_quinn::OpenStreams;
|
use h3_quinn::OpenStreams;
|
||||||
use http::header::{self, CONTENT_LENGTH};
|
use http::header::{self, CONTENT_LENGTH};
|
||||||
use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig};
|
use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig};
|
||||||
use rustls::ClientConfig as TlsClientConfig;
|
use rustls::ClientConfig as TlsClientConfig;
|
||||||
use tracing::debug;
|
use tokio::sync::mpsc;
|
||||||
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use crate::error::ProtoError;
|
use crate::error::ProtoError;
|
||||||
use crate::http::Version;
|
use crate::http::Version;
|
||||||
@ -34,13 +35,14 @@ use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
|
|||||||
use super::ALPN_H3;
|
use super::ALPN_H3;
|
||||||
|
|
||||||
/// A DNS client connection for DNS-over-HTTP/3
|
/// A DNS client connection for DNS-over-HTTP/3
|
||||||
|
#[derive(Clone)]
|
||||||
#[must_use = "futures do nothing unless polled"]
|
#[must_use = "futures do nothing unless polled"]
|
||||||
pub struct H3ClientStream {
|
pub struct H3ClientStream {
|
||||||
// Corresponds to the dns-name of the HTTP/3 server
|
// Corresponds to the dns-name of the HTTP/3 server
|
||||||
name_server_name: Arc<str>,
|
name_server_name: Arc<str>,
|
||||||
name_server: SocketAddr,
|
name_server: SocketAddr,
|
||||||
driver: Connection<h3_quinn::Connection, Bytes>,
|
|
||||||
send_request: SendRequest<OpenStreams, Bytes>,
|
send_request: SendRequest<OpenStreams, Bytes>,
|
||||||
|
shutdown_tx: mpsc::Sender<()>,
|
||||||
is_shutdown: bool,
|
is_shutdown: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -264,19 +266,19 @@ impl DnsRequestSender for H3ClientStream {
|
|||||||
impl Stream for H3ClientStream {
|
impl Stream for H3ClientStream {
|
||||||
type Item = Result<(), ProtoError>;
|
type Item = Result<(), ProtoError>;
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
if self.is_shutdown {
|
if self.is_shutdown {
|
||||||
return Poll::Ready(None);
|
return Poll::Ready(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
// just checking if the connection is ok
|
// just checking if the connection is ok
|
||||||
match self.driver.poll_close(cx) {
|
if self.shutdown_tx.is_closed() {
|
||||||
Poll::Ready(Ok(())) => Poll::Ready(None),
|
return Poll::Ready(Some(Err(ProtoError::from(
|
||||||
Poll::Pending => Poll::Pending,
|
"h3 connection is already shutdown",
|
||||||
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
|
))));
|
||||||
"h3 stream errored: {e}",
|
|
||||||
))))),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Poll::Ready(Some(Ok(())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -398,15 +400,31 @@ impl H3ClientStreamBuilder {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let h3_connection = h3_quinn::Connection::new(quic_connection);
|
let h3_connection = h3_quinn::Connection::new(quic_connection);
|
||||||
let (driver, send_request) = h3::client::new(h3_connection)
|
let (mut driver, send_request) = h3::client::new(h3_connection)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;
|
.map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;
|
||||||
|
|
||||||
|
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
|
||||||
|
|
||||||
|
// TODO: hand this back for others to run rather than spawning here?
|
||||||
|
debug!("h3 connection is ready: {}", name_server);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tokio::select! {
|
||||||
|
res = future::poll_fn(|cx| driver.poll_close(cx)) => {
|
||||||
|
res.map_err(|e| warn!("h3 connection failed: {e}"))
|
||||||
|
}
|
||||||
|
_ = shutdown_rx.recv() => {
|
||||||
|
debug!("h3 connection is shutting down: {}", name_server);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
Ok(H3ClientStream {
|
Ok(H3ClientStream {
|
||||||
name_server_name: Arc::from(dns_name),
|
name_server_name: Arc::from(dns_name),
|
||||||
name_server,
|
name_server,
|
||||||
driver,
|
|
||||||
send_request,
|
send_request,
|
||||||
|
shutdown_tx,
|
||||||
is_shutdown: false,
|
is_shutdown: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -453,6 +471,7 @@ mod tests {
|
|||||||
|
|
||||||
use rustls::KeyLogFile;
|
use rustls::KeyLogFile;
|
||||||
use tokio::runtime::Runtime;
|
use tokio::runtime::Runtime;
|
||||||
|
use tokio::task::JoinSet;
|
||||||
|
|
||||||
use crate::op::{Message, Query, ResponseCode};
|
use crate::op::{Message, Query, ResponseCode};
|
||||||
use crate::rr::rdata::{A, AAAA};
|
use crate::rr::rdata::{A, AAAA};
|
||||||
@ -652,4 +671,56 @@ mod tests {
|
|||||||
&AAAA::new(0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c)
|
&AAAA::new(0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[allow(clippy::print_stdout)]
|
||||||
|
fn test_h3_client_stream_clonable() {
|
||||||
|
// use google
|
||||||
|
let google = SocketAddr::from(([8, 8, 8, 8], 443));
|
||||||
|
|
||||||
|
let mut client_config = super::super::client_config_tls13().unwrap();
|
||||||
|
client_config.key_log = Arc::new(KeyLogFile::new());
|
||||||
|
|
||||||
|
let mut h3_builder = H3ClientStream::builder();
|
||||||
|
h3_builder.crypto_config(client_config);
|
||||||
|
let connect = h3_builder.build(google, "dns.google".to_string());
|
||||||
|
|
||||||
|
// tokio runtime stuff...
|
||||||
|
let runtime = Runtime::new().expect("could not start runtime");
|
||||||
|
let h3 = runtime.block_on(connect).expect("h3 connect failed");
|
||||||
|
|
||||||
|
// prepare request
|
||||||
|
let mut request = Message::new();
|
||||||
|
let query = Query::query(
|
||||||
|
Name::from_str("www.example.com.").unwrap(),
|
||||||
|
RecordType::AAAA,
|
||||||
|
);
|
||||||
|
request.add_query(query);
|
||||||
|
let request = DnsRequest::new(request, DnsRequestOptions::default());
|
||||||
|
|
||||||
|
runtime.block_on(async move {
|
||||||
|
let mut join_set = JoinSet::new();
|
||||||
|
|
||||||
|
for i in 0..50 {
|
||||||
|
let mut h3 = h3.clone();
|
||||||
|
let request = request.clone();
|
||||||
|
|
||||||
|
join_set.spawn(async move {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
h3.send_message(request)
|
||||||
|
.first_answer()
|
||||||
|
.await
|
||||||
|
.expect("send_message failed");
|
||||||
|
println!("request[{i}] completed: {:?}", start.elapsed());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let total = join_set.len();
|
||||||
|
let mut idx = 0usize;
|
||||||
|
while join_set.join_next().await.is_some() {
|
||||||
|
println!("join_set completed {idx}/{total}");
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user