Proto client implementation

This commit is contained in:
dAxpeDDa 2023-07-07 17:42:31 +02:00 committed by Benjamin Fry
parent 7c78e740ef
commit e34770c113
19 changed files with 975 additions and 142 deletions

View File

@ -85,6 +85,7 @@ jobs:
dns-over-rustls,
dns-over-https-rustls,
dns-over-quic,
dns-over-h3,
dns-over-native-tls,
dnssec-openssl,
dnssec-ring,

42
Cargo.lock generated
View File

@ -635,6 +635,17 @@ dependencies = [
"waker-fn",
]
[[package]]
name = "futures-macro"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.37",
]
[[package]]
name = "futures-sink"
version = "0.3.28"
@ -656,6 +667,7 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
@ -712,6 +724,34 @@ dependencies = [
"tracing",
]
[[package]]
name = "h3"
version = "0.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6de6ca43eed186fd055214af06967b0a7a68336cefec7e8a4004e96efeaccb9e"
dependencies = [
"bytes",
"fastrand 1.9.0",
"futures-util",
"http",
"tokio",
"tracing",
]
[[package]]
name = "h3-quinn"
version = "0.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d4a1a1763e4f3e82ee9f1ecf2cf862b22cc7316ebe14684e42f94532b5ec64d"
dependencies = [
"bytes",
"futures",
"h3",
"quinn",
"quinn-proto",
"tokio-util",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
@ -1937,6 +1977,8 @@ dependencies = [
"futures-io",
"futures-util",
"h2",
"h3",
"h3-quinn",
"http",
"idna",
"ipnet",

View File

@ -71,6 +71,8 @@ ring = "0.16"
# net proto
quinn = { version = "0.10", default-features = false }
h2 = "0.3.0"
h3 = "0.0.2"
h3-quinn = "0.0.3"
http = "0.2"

View File

@ -42,6 +42,8 @@ dns-over-quic = ["quinn", "rustls/quic", "dns-over-rustls", "bytes", "tokio-runt
native-certs = ["dep:rustls-native-certs"]
dns-over-h3 = ["h3", "h3-quinn", "quinn", "http", "dns-over-quic"]
dnssec-openssl = ["dnssec", "openssl"]
dnssec-ring = ["dnssec", "ring"]
dnssec = []
@ -79,6 +81,8 @@ futures-channel = { workspace = true, default-features = false, features = ["std
futures-io = { workspace = true, default-features = false, features = ["std"] }
futures-util = { workspace = true, default-features = false, features = ["std"] }
h2 = { workspace = true, features = ["stream"], optional = true }
h3 = { workspace = true, optional = true }
h3-quinn = { workspace = true, optional = true }
http = { workspace = true, optional = true }
idna.workspace = true
ipnet.workspace = true

View File

@ -0,0 +1,655 @@
// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
use std::fmt::{self, Display};
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_util::future::FutureExt;
use futures_util::stream::Stream;
use h3::client::{Connection, SendRequest};
use h3_quinn::OpenStreams;
use http::header::{self, CONTENT_LENGTH};
use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig};
use rustls::ClientConfig as TlsClientConfig;
use tracing::debug;
use crate::error::ProtoError;
use crate::http::Version;
use crate::op::Message;
use crate::quic::quic_socket::QuinnAsyncUdpSocketAdapter;
use crate::quic::QuicLocalAddr;
use crate::udp::{DnsUdpSocket, UdpSocket};
use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
use super::ALPN_H3;
/// A DNS client connection for DNS-over-HTTP/3
#[must_use = "futures do nothing unless polled"]
pub struct H3ClientStream {
// Corresponds to the dns-name of the HTTP/3 server
name_server_name: Arc<str>,
name_server: SocketAddr,
driver: Connection<h3_quinn::Connection, Bytes>,
send_request: SendRequest<OpenStreams, Bytes>,
is_shutdown: bool,
}
impl Display for H3ClientStream {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(
formatter,
"H3({},{})",
self.name_server, self.name_server_name
)
}
}
impl H3ClientStream {
/// Builder for H3ClientStream
pub fn builder() -> H3ClientStreamBuilder {
H3ClientStreamBuilder::default()
}
async fn inner_send(
mut h3: SendRequest<OpenStreams, Bytes>,
message: Bytes,
name_server_name: Arc<str>,
) -> Result<DnsResponse, ProtoError> {
// build up the http request
let request =
crate::http::request::new(Version::Http3, &name_server_name, message.remaining());
let request =
request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
debug!("request: {:#?}", request);
// Send the request
let mut stream = h3
.send_request(request)
.await
.map_err(|err| ProtoError::from(format!("h3 send_request error: {err}")))?;
stream
.send_data(message)
.await
.map_err(|e| ProtoError::from(format!("h3 send_data error: {e}")))?;
stream
.finish()
.await
.map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
let response = stream
.recv_response()
.await
.map_err(|err| ProtoError::from(format!("h3 recv_response error: {err}")))?;
debug!("got response: {:#?}", response);
// get the length of packet
let content_length = response
.headers()
.get(CONTENT_LENGTH)
.map(|v| v.to_str())
.transpose()
.map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
.map(usize::from_str)
.transpose()
.map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
// TODO: what is a good max here?
// clamp(512, 4096) says make sure it is at least 512 bytes, and min 4096 says it is at most 4k
// just a little protection from malicious actors.
let mut response_bytes =
BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4096));
while let Some(partial_bytes) = stream
.recv_data()
.await
.map_err(|e| ProtoError::from(format!("h3 recv_data error: {e}")))?
{
debug!("got bytes: {}", partial_bytes.remaining());
response_bytes.put(partial_bytes);
// assert the length
if let Some(content_length) = content_length {
if response_bytes.len() >= content_length {
break;
}
}
}
// assert the length
if let Some(content_length) = content_length {
if response_bytes.len() != content_length {
// TODO: make explicit error type
return Err(ProtoError::from(format!(
"expected byte length: {}, got: {}",
content_length,
response_bytes.len()
)));
}
}
// Was it a successful request?
if !response.status().is_success() {
let error_string = String::from_utf8_lossy(response_bytes.as_ref());
// TODO: make explicit error type
return Err(ProtoError::from(format!(
"http unsuccessful code: {}, message: {}",
response.status(),
error_string
)));
} else {
// verify content type
{
// in the case that the ContentType is not specified, we assume it's the standard DNS format
let content_type = response
.headers()
.get(header::CONTENT_TYPE)
.map(|h| {
h.to_str().map_err(|err| {
// TODO: make explicit error type
ProtoError::from(format!("ContentType header not a string: {err}"))
})
})
.unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
if content_type != crate::http::MIME_APPLICATION_DNS {
return Err(ProtoError::from(format!(
"ContentType unsupported (must be '{}'): '{}'",
crate::http::MIME_APPLICATION_DNS,
content_type
)));
}
}
};
// and finally convert the bytes into a DNS message
let message = Message::from_vec(&response_bytes)?;
Ok(DnsResponse::new(message, response_bytes.to_vec()))
}
}
impl DnsRequestSender for H3ClientStream {
/// This indicates that the HTTP message was successfully sent, and we now have the response.RecvStream
///
/// If the request fails, this will return the error, and it should be assumed that the Stream portion of
/// this will have no date.
///
/// ```text
/// 5.2. The HTTP Response
///
/// An HTTP response with a 2xx status code ([RFC7231] Section 6.3)
/// indicates a valid DNS response to the query made in the HTTP request.
/// A valid DNS response includes both success and failure responses.
/// For example, a DNS failure response such as SERVFAIL or NXDOMAIN will
/// be the message in a successful 2xx HTTP response even though there
/// was a failure at the DNS layer. Responses with non-successful HTTP
/// status codes do not contain DNS answers to the question in the
/// corresponding request. Some of these non-successful HTTP responses
/// (e.g., redirects or authentication failures) could mean that clients
/// need to make new requests to satisfy the original question.
///
/// Different response media types will provide more or less information
/// from a DNS response. For example, one response type might include
/// the information from the DNS header bytes while another might omit
/// it. The amount and type of information that a media type gives is
/// solely up to the format, and not defined in this protocol.
///
/// The only response type defined in this document is "application/dns-
/// message", but it is possible that other response formats will be
/// defined in the future.
///
/// The DNS response for "application/dns-message" in Section 7 MAY have
/// one or more EDNS options [RFC6891], depending on the extension
/// definition of the extensions given in the DNS request.
///
/// Each DNS request-response pair is matched to one HTTP exchange. The
/// responses may be processed and transported in any order using HTTP's
/// multi-streaming functionality ([RFC7540] Section 5).
///
/// Section 6.1 discusses the relationship between DNS and HTTP response
/// caching.
///
/// A DNS API server MUST be able to process application/dns-message
/// request messages.
///
/// A DNS API server SHOULD respond with HTTP status code 415
/// (Unsupported Media Type) upon receiving a media type it is unable to
/// process.
/// ```
fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
if self.is_shutdown {
panic!("can not send messages after stream is shutdown")
}
// per the RFC, a zero id allows for the HTTP packet to be cached better
message.set_id(0);
let bytes = match message.to_vec() {
Ok(bytes) => bytes,
Err(err) => return err.into(),
};
Box::pin(Self::inner_send(
self.send_request.clone(),
Bytes::from(bytes),
Arc::clone(&self.name_server_name),
))
.into()
}
fn shutdown(&mut self) {
self.is_shutdown = true;
}
fn is_shutdown(&self) -> bool {
self.is_shutdown
}
}
impl Stream for H3ClientStream {
type Item = Result<(), ProtoError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.is_shutdown {
return Poll::Ready(None);
}
// just checking if the connection is ok
match self.driver.poll_close(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
"h3 stream errored: {e}",
))))),
}
}
}
/// A H3 connection builder for DNS-over-HTTP/3
#[derive(Clone)]
pub struct H3ClientStreamBuilder {
crypto_config: TlsClientConfig,
transport_config: Arc<TransportConfig>,
bind_addr: Option<SocketAddr>,
}
impl H3ClientStreamBuilder {
/// Constructs a new H3ClientStreamBuilder with the associated ClientConfig
pub fn crypto_config(&mut self, crypto_config: TlsClientConfig) -> &mut Self {
self.crypto_config = crypto_config;
self
}
/// Sets the address to connect from.
pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
self.bind_addr = Some(bind_addr);
}
/// Creates a new H3Stream to the specified name_server
///
/// # Arguments
///
/// * `name_server` - IP and Port for the remote DNS resolver
/// * `dns_name` - The DNS name, Subject Public Key Info (SPKI) name, as associated to a certificate
pub fn build(self, name_server: SocketAddr, dns_name: String) -> H3ClientConnect {
H3ClientConnect(Box::pin(self.connect(name_server, dns_name)) as _)
}
/// Creates a new H3Stream with existing connection
pub fn build_with_future<S, F>(
self,
future: F,
name_server: SocketAddr,
dns_name: String,
) -> H3ClientConnect
where
S: DnsUdpSocket + QuicLocalAddr + 'static,
F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
{
H3ClientConnect(Box::pin(self.connect_with_future(future, name_server, dns_name)) as _)
}
async fn connect_with_future<S, F>(
self,
future: F,
name_server: SocketAddr,
dns_name: String,
) -> Result<H3ClientStream, ProtoError>
where
S: DnsUdpSocket + QuicLocalAddr + 'static,
F: Future<Output = std::io::Result<S>> + Send,
{
let socket = future.await?;
let wrapper = QuinnAsyncUdpSocketAdapter { io: socket };
let endpoint = Endpoint::new_with_abstract_socket(
EndpointConfig::default(),
None,
wrapper,
Arc::new(quinn::TokioRuntime),
)?;
self.connect_inner(endpoint, name_server, dns_name).await
}
async fn connect(
self,
name_server: SocketAddr,
dns_name: String,
) -> Result<H3ClientStream, ProtoError> {
let connect = if let Some(bind_addr) = self.bind_addr {
<tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
} else {
<tokio::net::UdpSocket as UdpSocket>::connect(name_server)
};
let socket = connect.await?;
let socket = socket.into_std()?;
let endpoint = Endpoint::new(
EndpointConfig::default(),
None,
socket,
Arc::new(quinn::TokioRuntime),
)?;
self.connect_inner(endpoint, name_server, dns_name).await
}
async fn connect_inner(
self,
mut endpoint: Endpoint,
name_server: SocketAddr,
dns_name: String,
) -> Result<H3ClientStream, ProtoError> {
let mut crypto_config = self.crypto_config;
// ensure the ALPN protocol is set correctly
if crypto_config.alpn_protocols.is_empty() {
crypto_config.alpn_protocols = vec![ALPN_H3.to_vec()];
}
let early_data_enabled = crypto_config.enable_early_data;
let mut client_config = ClientConfig::new(Arc::new(crypto_config));
client_config.transport_config(self.transport_config.clone());
endpoint.set_default_client_config(client_config);
let connecting = endpoint.connect(name_server, &dns_name)?;
// TODO: for Client/Dynamic update, don't use RTT, for queries, do use it.
let quic_connection = if early_data_enabled {
match connecting.into_0rtt() {
Ok((new_connection, _)) => new_connection,
Err(connecting) => connecting.await?,
}
} else {
connecting.await?
};
let h3_connection = h3_quinn::Connection::new(quic_connection);
let (driver, send_request) = h3::client::new(h3_connection)
.await
.map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;
Ok(H3ClientStream {
name_server_name: Arc::from(dns_name),
name_server,
driver,
send_request,
is_shutdown: false,
})
}
}
impl Default for H3ClientStreamBuilder {
fn default() -> Self {
Self {
crypto_config: super::client_config_tls13().unwrap(),
transport_config: Arc::new(super::transport()),
bind_addr: None,
}
}
}
/// A future that resolves to an H3ClientStream
pub struct H3ClientConnect(
Pin<Box<dyn Future<Output = Result<H3ClientStream, ProtoError>> + Send>>,
);
impl Future for H3ClientConnect {
type Output = Result<H3ClientStream, ProtoError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_unpin(cx)
}
}
/// A future that resolves to
pub struct H3ClientResponse(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>);
impl Future for H3ClientResponse {
type Output = Result<DnsResponse, ProtoError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.as_mut().poll(cx).map_err(ProtoError::from)
}
}
#[cfg(all(test, any(feature = "native-certs", feature = "webpki-roots")))]
mod tests {
use std::net::SocketAddr;
use std::str::FromStr;
use rustls::KeyLogFile;
use tokio::runtime::Runtime;
use crate::op::{Message, Query, ResponseCode};
use crate::rr::rdata::{A, AAAA};
use crate::rr::{Name, RData, RecordType};
use crate::xfer::{DnsRequestOptions, FirstAnswer};
use super::*;
#[test]
fn test_h3_google() {
//env_logger::try_init().ok();
let google = SocketAddr::from(([8, 8, 8, 8], 443));
let mut request = Message::new();
let query = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
request.add_query(query);
let request = DnsRequest::new(request, DnsRequestOptions::default());
let mut client_config = super::super::client_config_tls13().unwrap();
client_config.key_log = Arc::new(KeyLogFile::new());
let mut h3_builder = H3ClientStream::builder();
h3_builder.crypto_config(client_config);
let connect = h3_builder.build(google, "dns.google".to_string());
// tokio runtime stuff...
let runtime = Runtime::new().expect("could not start runtime");
let mut h3 = runtime.block_on(connect).expect("h3 connect failed");
let response = runtime
.block_on(h3.send_message(request).first_answer())
.expect("send_message failed");
let record = &response.answers()[0];
let addr = record
.data()
.and_then(RData::as_a)
.expect("Expected A record");
assert_eq!(addr, &A::new(93, 184, 216, 34));
//
// assert that the connection works for a second query
let mut request = Message::new();
let query = Query::query(
Name::from_str("www.example.com.").unwrap(),
RecordType::AAAA,
);
request.add_query(query);
let request = DnsRequest::new(request, DnsRequestOptions::default());
for _ in 0..3 {
let response = runtime
.block_on(h3.send_message(request.clone()).first_answer())
.expect("send_message failed");
if response.response_code() == ResponseCode::ServFail {
continue;
}
let record = &response.answers()[0];
let addr = record
.data()
.and_then(RData::as_aaaa)
.expect("invalid response, expected A record");
assert_eq!(
addr,
&AAAA::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
);
}
}
#[test]
fn test_h3_google_with_pure_ip_address_server() {
//env_logger::try_init().ok();
let google = SocketAddr::from(([8, 8, 8, 8], 443));
let mut request = Message::new();
let query = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
request.add_query(query);
let request = DnsRequest::new(request, DnsRequestOptions::default());
let mut client_config = super::super::client_config_tls13().unwrap();
client_config.key_log = Arc::new(KeyLogFile::new());
let mut h3_builder = H3ClientStream::builder();
h3_builder.crypto_config(client_config);
let connect = h3_builder.build(google, google.ip().to_string());
// tokio runtime stuff...
let runtime = Runtime::new().expect("could not start runtime");
let mut h3 = runtime.block_on(connect).expect("h3 connect failed");
let response = runtime
.block_on(h3.send_message(request).first_answer())
.expect("send_message failed");
let record = &response.answers()[0];
let addr = record
.data()
.and_then(RData::as_a)
.expect("Expected A record");
assert_eq!(addr, &A::new(93, 184, 216, 34));
//
// assert that the connection works for a second query
let mut request = Message::new();
let query = Query::query(
Name::from_str("www.example.com.").unwrap(),
RecordType::AAAA,
);
request.add_query(query);
let request = DnsRequest::new(request, DnsRequestOptions::default());
for _ in 0..3 {
let response = runtime
.block_on(h3.send_message(request.clone()).first_answer())
.expect("send_message failed");
if response.response_code() == ResponseCode::ServFail {
continue;
}
let record = &response.answers()[0];
let addr = record
.data()
.and_then(RData::as_aaaa)
.expect("invalid response, expected A record");
assert_eq!(
addr,
&AAAA::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
);
}
}
/// Currently fails, see <https://github.com/hyperium/h3/issues/206>.
#[test]
#[ignore] // cloudflare has been unreliable as a public test service.
fn test_h3_cloudflare() {
// self::env_logger::try_init().ok();
let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
let mut request = Message::new();
let query = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
request.add_query(query);
let request = DnsRequest::new(request, DnsRequestOptions::default());
let mut client_config = super::super::client_config_tls13().unwrap();
client_config.key_log = Arc::new(KeyLogFile::new());
let mut h3_builder = H3ClientStream::builder();
h3_builder.crypto_config(client_config);
let connect = h3_builder.build(cloudflare, "cloudflare-dns.com".to_string());
// tokio runtime stuff...
let runtime = Runtime::new().expect("could not start runtime");
let mut h3 = runtime.block_on(connect).expect("h3 connect failed");
let response = runtime
.block_on(h3.send_message(request).first_answer())
.expect("send_message failed");
let record = &response.answers()[0];
let addr = record
.data()
.and_then(RData::as_a)
.expect("invalid response, expected A record");
assert_eq!(addr, &A::new(93, 184, 216, 34));
//
// assert that the connection works for a second query
let mut request = Message::new();
let query = Query::query(
Name::from_str("www.example.com.").unwrap(),
RecordType::AAAA,
);
request.add_query(query);
let request = DnsRequest::new(request, DnsRequestOptions::default());
let response = runtime
.block_on(h3.send_message(request).first_answer())
.expect("send_message failed");
let record = &response.answers()[0];
let addr = record
.data()
.and_then(RData::as_aaaa)
.expect("invalid response, expected A record");
assert_eq!(
addr,
&AAAA::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
);
}
}

View File

@ -0,0 +1,38 @@
// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
//! TLS protocol related components for DNS over HTTP/3 (DoH3)
mod h3_client_stream;
use quinn::{TransportConfig, VarInt};
pub use crate::http::error::{Error as H3Error, Result as H3Result};
pub use crate::quic::client_config_tls13;
pub use self::h3_client_stream::{
H3ClientConnect, H3ClientResponse, H3ClientStream, H3ClientStreamBuilder,
};
const ALPN_H3: &[u8] = b"h3";
/// Returns a default endpoint configuration for DNS-over-QUIC
fn transport() -> TransportConfig {
let mut transport_config = TransportConfig::default();
transport_config.datagram_receive_buffer_size(None);
transport_config.datagram_send_buffer_size(0);
// clients never accept new bidirectional streams
transport_config.max_concurrent_bidi_streams(VarInt::from_u32(3));
// - SETTINGS
// - QPACK encoder
// - QPACK decoder
// - RESERVED (GREASE)
transport_config.max_concurrent_uni_streams(VarInt::from_u32(4));
transport_config
}

View File

@ -9,7 +9,6 @@ use std::num::ParseIntError;
use std::{fmt, io};
use crate::error::ProtoError;
use h2;
use http::header::ToStrError;
use thiserror::Error;
@ -43,7 +42,12 @@ pub enum ErrorKind {
ProtoError(#[from] ProtoError),
#[error("h2: {0}")]
#[cfg(feature = "dns-over-https")]
H2(#[from] h2::Error),
#[error("h3: {0}")]
#[cfg(feature = "dns-over-h3")]
H3(#[from] h3::Error),
}
/// The error type for errors that get returned in the crate
@ -118,12 +122,20 @@ impl From<ProtoError> for Error {
}
}
#[cfg(feature = "dns-over-https")]
impl From<h2::Error> for Error {
fn from(msg: h2::Error) -> Self {
ErrorKind::H2(msg).into()
}
}
#[cfg(feature = "dns-over-h3")]
impl From<h3::Error> for Error {
fn from(msg: h3::Error) -> Self {
ErrorKind::H3(msg).into()
}
}
impl From<Error> for io::Error {
fn from(err: Error) -> Self {
Self::new(io::ErrorKind::Other, format!("https: {err}"))

View File

@ -0,0 +1,37 @@
// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
//! HTTP protocol related components for DNS over HTTP/2 (DoH) and HTTP/3 (DoH3)
pub(crate) const MIME_APPLICATION_DNS: &str = "application/dns-message";
pub(crate) const DNS_QUERY_PATH: &str = "/dns-query";
pub(crate) mod error;
pub mod request;
pub mod response;
/// Represents a version of the HTTP spec.
#[derive(Clone, Copy, Debug)]
pub enum Version {
/// HTTP/2 for DoH.
#[cfg(feature = "dns-over-https")]
Http2,
/// HTTP/3 for DoH3.
#[cfg(feature = "dns-over-h3")]
Http3,
}
impl Version {
fn to_http(self) -> http::Version {
match self {
#[cfg(feature = "dns-over-https")]
Self::Http2 => http::Version::HTTP_2,
#[cfg(feature = "dns-over-h3")]
Self::Http3 => http::Version::HTTP_3,
}
}
}

View File

@ -10,13 +10,14 @@
use std::str::FromStr;
use http::header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE};
use http::{header, uri, Request, Uri, Version};
use http::{header, uri, Request, Uri};
use tracing::debug;
use crate::error::ProtoError;
use crate::https::HttpsResult;
use crate::http::error::Result;
use crate::http::Version;
/// Create a new Request for an http/2 dns-message request
/// Create a new Request for an http dns-message request
///
/// ```text
/// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10#section-5.1
@ -27,7 +28,7 @@ use crate::https::HttpsResult;
/// [RFC4648].
/// ```
#[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527
pub fn new(name_server_name: &str, message_len: usize) -> HttpsResult<Request<()>> {
pub fn new(version: Version, name_server_name: &str, message_len: usize) -> Result<Request<()>> {
// TODO: this is basically the GET version, but it is more expensive than POST
// perhaps add an option if people want better HTTP caching options.
@ -41,7 +42,7 @@ pub fn new(name_server_name: &str, message_len: usize) -> HttpsResult<Request<()
// .body(());
let mut parts = uri::Parts::default();
parts.path_and_query = Some(uri::PathAndQuery::from_static(crate::https::DNS_QUERY_PATH));
parts.path_and_query = Some(uri::PathAndQuery::from_static(crate::http::DNS_QUERY_PATH));
parts.scheme = Some(uri::Scheme::HTTPS);
parts.authority = Some(
uri::Authority::from_str(name_server_name)
@ -55,27 +56,27 @@ pub fn new(name_server_name: &str, message_len: usize) -> HttpsResult<Request<()
let request = Request::builder()
.method("POST")
.uri(url)
.version(Version::HTTP_2)
.header(CONTENT_TYPE, crate::https::MIME_APPLICATION_DNS)
.header(ACCEPT, crate::https::MIME_APPLICATION_DNS)
.version(version.to_http())
.header(CONTENT_TYPE, crate::http::MIME_APPLICATION_DNS)
.header(ACCEPT, crate::http::MIME_APPLICATION_DNS)
.header(CONTENT_LENGTH, message_len)
.body(())
.map_err(|e| ProtoError::from(format!("h2 stream errored: {e}")))?;
.map_err(|e| ProtoError::from(format!("http stream errored: {e}")))?;
Ok(request)
}
/// Verifies the request is something we know what to deal with
pub fn verify<T>(name_server: Option<&str>, request: &Request<T>) -> HttpsResult<()> {
pub fn verify<T>(version: Version, name_server: Option<&str>, request: &Request<T>) -> Result<()> {
// Verify all HTTP parameters
let uri = request.uri();
// validate path
if uri.path() != crate::https::DNS_QUERY_PATH {
if uri.path() != crate::http::DNS_QUERY_PATH {
return Err(format!(
"bad path: {}, expected: {}",
uri.path(),
crate::https::DNS_QUERY_PATH
crate::http::DNS_QUERY_PATH
)
.into());
}
@ -98,7 +99,7 @@ pub fn verify<T>(name_server: Option<&str>, request: &Request<T>) -> HttpsResult
// TODO: switch to mime::APPLICATION_DNS when that stabilizes
match request.headers().get(CONTENT_TYPE).map(|v| v.to_str()) {
Some(Ok(ctype)) if ctype == crate::https::MIME_APPLICATION_DNS => {}
Some(Ok(ctype)) if ctype == crate::http::MIME_APPLICATION_DNS => {}
_ => return Err("unsupported content type".into()),
};
@ -109,7 +110,7 @@ pub fn verify<T>(name_server: Option<&str>, request: &Request<T>) -> HttpsResult
for mime_and_quality in ctype.split(',') {
let mut parts = mime_and_quality.splitn(2, ';');
match parts.next() {
Some(mime) if mime.trim() == crate::https::MIME_APPLICATION_DNS => {
Some(mime) if mime.trim() == crate::http::MIME_APPLICATION_DNS => {
found = true;
break;
}
@ -129,8 +130,14 @@ pub fn verify<T>(name_server: Option<&str>, request: &Request<T>) -> HttpsResult
None => return Err("Accept is unspecified".into()),
};
if request.version() != Version::HTTP_2 {
return Err("only HTTP/2 supported".into());
if request.version() != version.to_http() {
let message = match version {
#[cfg(feature = "dns-over-https")]
Version::Http2 => "only HTTP/2 supported",
#[cfg(feature = "dns-over-h3")]
Version::Http3 => "only HTTP/3 supported",
};
return Err(message.into());
}
debug!(
@ -150,8 +157,16 @@ mod tests {
use super::*;
#[test]
fn test_new_verify() {
let request = new("ns.example.com", 512).expect("error converting to http");
assert!(verify(Some("ns.example.com"), &request).is_ok());
#[cfg(feature = "dns-over-https")]
fn test_new_verify_h2() {
let request = new(Version::Http2, "ns.example.com", 512).expect("error converting to http");
assert!(verify(Version::Http2, Some("ns.example.com"), &request).is_ok());
}
#[test]
#[cfg(feature = "dns-over-h3")]
fn test_new_verify_h3() {
let request = new(Version::Http3, "ns.example.com", 512).expect("error converting to http");
assert!(verify(Version::Http3, Some("ns.example.com"), &request).is_ok());
}
}

View File

@ -8,12 +8,13 @@
//! HTTP request creation and validation
use http::header::{CONTENT_LENGTH, CONTENT_TYPE};
use http::{Response, StatusCode, Version};
use http::{Response, StatusCode};
use crate::error::ProtoError;
use crate::https::HttpsResult;
use crate::http::error::Result;
use crate::http::Version;
/// Create a new Response for an http/2 dns-message request
/// Create a new Response for an http dns-message request
///
/// ```text
/// 4.2.1. Handling DNS and HTTP Errors
@ -38,11 +39,11 @@ use crate::https::HttpsResult;
/// cannot generate a representation suitable for the client (HTTP status
/// code 406, [RFC7231] Section 6.5.6), and so on.
/// ```
pub fn new(message_len: usize) -> HttpsResult<Response<()>> {
pub fn new(version: Version, message_len: usize) -> Result<Response<()>> {
Response::builder()
.status(StatusCode::OK)
.version(Version::HTTP_2)
.header(CONTENT_TYPE, crate::https::MIME_APPLICATION_DNS)
.version(version.to_http())
.header(CONTENT_TYPE, crate::http::MIME_APPLICATION_DNS)
.header(CONTENT_LENGTH, message_len)
.body(())
.map_err(|e| ProtoError::from(format!("invalid response: {e}")).into())

View File

@ -28,6 +28,7 @@ use tokio_rustls::{
use tracing::{debug, warn};
use crate::error::ProtoError;
use crate::http::Version;
use crate::iocompat::AsyncIoStdAsTokio;
use crate::op::Message;
use crate::tcp::{Connect, DnsTcpStream};
@ -71,7 +72,8 @@ impl HttpsClientStream {
};
// build up the http request
let request = crate::https::request::new(&name_server_name, message.remaining());
let request =
crate::http::request::new(Version::Http2, &name_server_name, message.remaining());
let request =
request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
@ -160,12 +162,12 @@ impl HttpsClientStream {
ProtoError::from(format!("ContentType header not a string: {err}"))
})
})
.unwrap_or(Ok(crate::https::MIME_APPLICATION_DNS))?;
.unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
if content_type != crate::https::MIME_APPLICATION_DNS {
if content_type != crate::http::MIME_APPLICATION_DNS {
return Err(ProtoError::from(format!(
"ContentType unsupported (must be '{}'): '{}'",
crate::https::MIME_APPLICATION_DNS,
crate::http::MIME_APPLICATION_DNS,
content_type
)));
}

View File

@ -18,6 +18,7 @@ use http::header::CONTENT_LENGTH;
use http::{Method, Request};
use tracing::debug;
use crate::http::Version;
use crate::https::HttpsError;
/// Given an HTTP request, return a future that will result in the next sequence of bytes.
@ -34,7 +35,7 @@ where
debug!("Received request: {:#?}", request);
let this_server_name = this_server_name.as_deref();
match crate::https::request::verify(this_server_name, &request) {
match crate::http::request::verify(Version::Http2, this_server_name, &request) {
Ok(_) => (),
Err(err) => return Err(err),
}
@ -97,7 +98,7 @@ mod tests {
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::https::request;
use crate::http::request;
use crate::op::Message;
use super::*;
@ -123,7 +124,7 @@ mod tests {
let msg_bytes = message.to_vec().unwrap();
let len = msg_bytes.len();
let stream = TestBytesStream(vec![Ok(Bytes::from(msg_bytes))]);
let request = request::new("ns.example.com", len).unwrap();
let request = request::new(Version::Http2, "ns.example.com", len).unwrap();
let request = request.map(|()| stream);
let from_post = message_from(Some(Arc::from("ns.example.com")), request);

View File

@ -7,16 +7,10 @@
//! TLS protocol related components for DNS over HTTPS (DoH)
const MIME_APPLICATION_DNS: &str = "application/dns-message";
const DNS_QUERY_PATH: &str = "/dns-query";
mod error;
mod https_client_stream;
pub mod https_server;
pub mod request;
pub mod response;
pub use self::error::{Error as HttpsError, Result as HttpsResult};
pub use crate::http::error::{Error as HttpsError, Result as HttpsResult};
pub use self::https_client_stream::{
HttpsClientConnect, HttpsClientResponse, HttpsClientStream, HttpsClientStreamBuilder,

View File

@ -62,6 +62,15 @@ pub fn spawn_bg<F: Future<Output = R> + Send + 'static, R: Send + 'static>(
}
pub mod error;
#[cfg(feature = "dns-over-h3")]
#[cfg_attr(docsrs, doc(cfg(feature = "dns-over-h3")))]
pub mod h3;
#[cfg(any(feature = "dns-over-https", feature = "dns-over-h3"))]
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "dns-over-https", feature = "dns-over-h3")))
)]
pub mod http;
#[cfg(feature = "dns-over-https")]
#[cfg_attr(docsrs, doc(cfg(feature = "dns-over-https")))]
pub mod https;

View File

@ -10,6 +10,7 @@
mod quic_client_stream;
mod quic_config;
mod quic_server;
pub(crate) mod quic_socket;
mod quic_stream;
pub use self::quic_client_stream::{

View File

@ -5,7 +5,6 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
use std::fmt::{Debug, Formatter};
use std::{
fmt::{self, Display},
future::Future,
@ -16,12 +15,13 @@ use std::{
};
use futures_util::{future::FutureExt, stream::Stream};
use quinn::{AsyncUdpSocket, ClientConfig, Connection, Endpoint, TransportConfig, VarInt};
use quinn::{ClientConfig, Connection, Endpoint, TransportConfig, VarInt};
use rustls::{version::TLS13, ClientConfig as TlsClientConfig};
use crate::udp::{DnsUdpSocket, QuicLocalAddr};
use crate::{
error::ProtoError,
quic::quic_socket::QuinnAsyncUdpSocketAdapter,
quic::quic_stream::{DoqErrorCode, QuicStream},
udp::UdpSocket,
xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream},
@ -356,100 +356,3 @@ impl Future for QuicClientResponse {
self.0.as_mut().poll(cx).map_err(ProtoError::from)
}
}
/// Wrapper used for quinn::Endpoint::new_with_abstract_socket
struct QuinnAsyncUdpSocketAdapter<S: DnsUdpSocket + QuicLocalAddr> {
io: S,
}
impl<S: DnsUdpSocket + QuicLocalAddr> Debug for QuinnAsyncUdpSocketAdapter<S> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str("Wrapper for quinn::AsyncUdpSocket")
}
}
/// TODO: Naive implementation. Look forward to future improvements.
impl<S: DnsUdpSocket + QuicLocalAddr + 'static> AsyncUdpSocket for QuinnAsyncUdpSocketAdapter<S> {
fn poll_send(
&self,
_state: &quinn::udp::UdpState,
cx: &mut Context<'_>,
transmits: &[quinn::udp::Transmit],
) -> Poll<std::io::Result<usize>> {
// logics from quinn-udp::fallback.rs
let io = &self.io;
let mut sent = 0;
for transmit in transmits {
match io.poll_send_to(cx, &transmit.contents, transmit.destination) {
Poll::Ready(ready) => match ready {
Ok(_) => {
sent += 1;
}
// We need to report that some packets were sent in this case, so we rely on
// errors being either harmlessly transient (in the case of WouldBlock) or
// recurring on the next call.
Err(_) if sent != 0 => return Poll::Ready(Ok(sent)),
Err(e) => {
if e.kind() == std::io::ErrorKind::WouldBlock {
return Poll::Ready(Err(e));
}
// Other errors are ignored, since they will ususally be handled
// by higher level retransmits and timeouts.
// - PermissionDenied errors have been observed due to iptable rules.
// Those are not fatal errors, since the
// configuration can be dynamically changed.
// - Destination unreachable errors have been observed for other
// log_sendmsg_error(&mut self.last_send_error, e, transmit);
sent += 1;
}
},
Poll::Pending => {
return if sent == 0 {
Poll::Pending
} else {
Poll::Ready(Ok(sent))
}
}
}
}
Poll::Ready(Ok(sent))
}
fn poll_recv(
&self,
cx: &mut Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
meta: &mut [quinn::udp::RecvMeta],
) -> Poll<std::io::Result<usize>> {
// logics from quinn-udp::fallback.rs
let io = &self.io;
let Some(buf) = bufs.get_mut(0) else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"no buf",
)));
};
match io.poll_recv_from(cx, buf.as_mut()) {
Poll::Ready(res) => match res {
Ok((len, addr)) => {
meta[0] = quinn::udp::RecvMeta {
len,
stride: len,
addr,
ecn: None,
dst_ip: None,
};
Poll::Ready(Ok(1))
}
Err(err) => Poll::Ready(Err(err)),
},
Poll::Pending => Poll::Pending,
}
}
fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
self.io.local_addr()
}
}

View File

@ -0,0 +1,113 @@
// Copyright 2015-2022 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
use std::fmt::{Debug, Formatter};
use std::{
fmt,
task::{Context, Poll},
};
use quinn::AsyncUdpSocket;
use crate::udp::{DnsUdpSocket, QuicLocalAddr};
/// Wrapper used for quinn::Endpoint::new_with_abstract_socket
pub(crate) struct QuinnAsyncUdpSocketAdapter<S: DnsUdpSocket + QuicLocalAddr> {
pub(crate) io: S,
}
impl<S: DnsUdpSocket + QuicLocalAddr> Debug for QuinnAsyncUdpSocketAdapter<S> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str("Wrapper for quinn::AsyncUdpSocket")
}
}
/// TODO: Naive implementation. Look forward to future improvements.
impl<S: DnsUdpSocket + QuicLocalAddr + 'static> AsyncUdpSocket for QuinnAsyncUdpSocketAdapter<S> {
fn poll_send(
&self,
_state: &quinn::udp::UdpState,
cx: &mut Context<'_>,
transmits: &[quinn::udp::Transmit],
) -> Poll<std::io::Result<usize>> {
// logics from quinn-udp::fallback.rs
let io = &self.io;
let mut sent = 0;
for transmit in transmits {
match io.poll_send_to(cx, &transmit.contents, transmit.destination) {
Poll::Ready(ready) => match ready {
Ok(_) => {
sent += 1;
}
// We need to report that some packets were sent in this case, so we rely on
// errors being either harmlessly transient (in the case of WouldBlock) or
// recurring on the next call.
Err(_) if sent != 0 => return Poll::Ready(Ok(sent)),
Err(e) => {
if e.kind() == std::io::ErrorKind::WouldBlock {
return Poll::Ready(Err(e));
}
// Other errors are ignored, since they will ususally be handled
// by higher level retransmits and timeouts.
// - PermissionDenied errors have been observed due to iptable rules.
// Those are not fatal errors, since the
// configuration can be dynamically changed.
// - Destination unreachable errors have been observed for other
// log_sendmsg_error(&mut self.last_send_error, e, transmit);
sent += 1;
}
},
Poll::Pending => {
return if sent == 0 {
Poll::Pending
} else {
Poll::Ready(Ok(sent))
}
}
}
}
Poll::Ready(Ok(sent))
}
fn poll_recv(
&self,
cx: &mut Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
meta: &mut [quinn::udp::RecvMeta],
) -> Poll<std::io::Result<usize>> {
// logics from quinn-udp::fallback.rs
let io = &self.io;
let Some(buf) = bufs.get_mut(0) else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"no buf",
)));
};
match io.poll_recv_from(cx, buf.as_mut()) {
Poll::Ready(res) => match res {
Ok((len, addr)) => {
meta[0] = quinn::udp::RecvMeta {
len,
stride: len,
addr,
ecn: None,
dst_ip: None,
};
Poll::Ready(Ok(1))
}
Err(err) => Poll::Ready(Err(err)),
},
Poll::Pending => Poll::Pending,
}
}
fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
self.io.local_addr()
}
}

View File

@ -13,7 +13,7 @@ use h2::server;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use trust_dns_proto::rr::Record;
use trust_dns_proto::{http::Version, rr::Record};
use crate::{
authority::MessageResponse,
@ -108,7 +108,7 @@ impl ResponseHandler for HttpsResponseHandle {
impl Iterator<Item = &'a Record> + Send + 'a,
>,
) -> io::Result<ResponseInfo> {
use crate::proto::https::response;
use crate::proto::http::response;
use crate::proto::https::HttpsError;
use crate::proto::serialize::binary::BinEncoder;
@ -119,7 +119,7 @@ impl ResponseHandler for HttpsResponseHandle {
response.destructive_emit(&mut encoder)?
};
let bytes = Bytes::from(bytes);
let response = response::new(bytes.len())?;
let response = response::new(Version::Http2, bytes.len())?;
debug!("sending response: {:#?}", response);
let mut stream = self

View File

@ -33,6 +33,9 @@ dns-over-https-rustls: (default "--features=dns-over-https-rustls" "--ignore=\\{
# Check, build, and test all crates with dns-over-quic enabled
dns-over-quic: (default "--features=dns-over-quic" "--ignore=\\{async-std-resolver,trust-dns-compatibility\\}")
# Check, build, and test all crates with dns-over-h3 enabled
dns-over-h3: (default "--features=dns-over-h3" "--ignore=\\{async-std-resolver,trust-dns-compatibility,trust-dns-client\\}")
# Check, build, and test all crates with dns-over-native-tls enabled
dns-over-native-tls: (default "--features=dns-over-native-tls" "--ignore=\\{async-std-resolver,trust-dns-compatibility,trust-dns-server,trust-dns,trust-dns-util,trust-dns-integration\\}")