Fix truncation for UDP
This fixes a couple of issues for UDP on both the client and server: * Previously, the UdpClientStream was using a fixed `2048` for the size of the receive buffer. This can cause problems on interfaces with a larger MTU. #1096 adjusted this value on the server side to 4096 (the maximum as recommended by RFC6891). This PR sets a constant that is shared by the UDP client and server. Additionally, the client uses EDNS in the request to further trim down the buffer size. * The Server previously was not setting a maximum for the `BinEncoder`, which defaults to `u16::MAX` (i.e. effectively no truncation for UDP). This PR sets an appropriate maximum for the `BinEncoder` based on the response EDNS and protocol being used. Fixes: #1973
This commit is contained in:
parent
16c7f987d7
commit
0a1306ba8f
@ -132,7 +132,7 @@ impl<'a> BinEncoder<'a> {
|
||||
BinEncoder {
|
||||
offset: offset as usize,
|
||||
// TODO: add max_size to signature
|
||||
buffer: private::MaximalBuf::new(u16::max_value(), buf),
|
||||
buffer: private::MaximalBuf::new(u16::MAX, buf),
|
||||
name_pointers: Vec::new(),
|
||||
mode,
|
||||
canonical_names: false,
|
||||
@ -240,8 +240,8 @@ impl<'a> BinEncoder<'a> {
|
||||
/// The location is the current position in the buffer
|
||||
/// implicitly, it is expected that the name will be written to the stream after the current index.
|
||||
pub fn store_label_pointer(&mut self, start: usize, end: usize) {
|
||||
assert!(start <= (u16::max_value() as usize));
|
||||
assert!(end <= (u16::max_value() as usize));
|
||||
assert!(start <= (u16::MAX as usize));
|
||||
assert!(end <= (u16::MAX as usize));
|
||||
assert!(start <= end);
|
||||
if self.offset < 0x3FFF_usize {
|
||||
self.name_pointers
|
||||
@ -255,7 +255,7 @@ impl<'a> BinEncoder<'a> {
|
||||
|
||||
for (match_start, matcher) in &self.name_pointers {
|
||||
if matcher.as_slice() == search {
|
||||
assert!(match_start <= &(u16::max_value() as usize));
|
||||
assert!(match_start <= &(u16::MAX as usize));
|
||||
return Some(*match_start as u16);
|
||||
}
|
||||
}
|
||||
|
@ -21,3 +21,7 @@ mod udp_stream;
|
||||
|
||||
pub use self::udp_client_stream::{UdpClientConnect, UdpClientStream};
|
||||
pub use self::udp_stream::{DnsUdpSocket, QuicLocalAddr, UdpSocket, UdpStream};
|
||||
|
||||
/// Max size for the UDP receive buffer as recommended by
|
||||
/// [RFC6891](https://datatracker.ietf.org/doc/html/rfc6891#section-6.2.5).
|
||||
pub const MAX_RECEIVE_BUFFER_SIZE: usize = 4096;
|
||||
|
@ -15,13 +15,13 @@ use std::task::{Context, Poll};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use futures_util::{future::Future, stream::Stream};
|
||||
use tracing::{debug, warn};
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
use crate::error::ProtoError;
|
||||
use crate::op::message::NoopMessageFinalizer;
|
||||
use crate::op::{Message, MessageFinalizer, MessageVerifier};
|
||||
use crate::udp::udp_stream::{NextRandomUdpSocket, UdpCreator, UdpSocket};
|
||||
use crate::udp::DnsUdpSocket;
|
||||
use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE};
|
||||
use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
|
||||
use crate::Time;
|
||||
|
||||
@ -212,6 +212,9 @@ impl<S: DnsUdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
|
||||
}
|
||||
}
|
||||
|
||||
// Get an appropriate read buffer size.
|
||||
let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(message.max_payload() as usize);
|
||||
|
||||
let bytes = match message.to_vec() {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
@ -235,7 +238,8 @@ impl<S: DnsUdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
|
||||
self.timeout,
|
||||
Box::pin(async move {
|
||||
let socket: S = NextRandomUdpSocket::new_with_closure(&addr, creator).await?;
|
||||
send_serial_message_inner(message, message_id, verifier, socket).await
|
||||
send_serial_message_inner(message, message_id, verifier, socket, recv_buf_size)
|
||||
.await
|
||||
}),
|
||||
)
|
||||
.into()
|
||||
@ -298,6 +302,7 @@ async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
|
||||
msg_id: u16,
|
||||
verifier: Option<MessageVerifier>,
|
||||
socket: S,
|
||||
recv_buf_size: usize,
|
||||
) -> Result<DnsResponse, ProtoError> {
|
||||
let bytes = msg.bytes();
|
||||
let addr = msg.addr();
|
||||
@ -311,13 +316,16 @@ async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
|
||||
)));
|
||||
}
|
||||
|
||||
// Create the receive buffer.
|
||||
trace!("creating UDP receive buffer with size {recv_buf_size}");
|
||||
let mut recv_buf = vec![0; recv_buf_size];
|
||||
|
||||
// TODO: limit the max number of attempted messages? this relies on a timeout to die...
|
||||
loop {
|
||||
// TODO: consider making this heap based? need to verify it matches EDNS settings
|
||||
let mut recv_buf = [0u8; 2048];
|
||||
|
||||
let (len, src) = socket.recv_from(&mut recv_buf).await?;
|
||||
let buffer: Vec<_> = recv_buf.iter().take(len).cloned().collect();
|
||||
|
||||
// Copy the slice of read bytes.
|
||||
let buffer: Vec<_> = Vec::from(&recv_buf[0..len]);
|
||||
|
||||
// compare expected src to received packet
|
||||
let request_target = msg.addr();
|
||||
|
@ -19,6 +19,7 @@ use rand;
|
||||
use rand::distributions::{uniform::Uniform, Distribution};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
|
||||
use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver};
|
||||
use crate::Time;
|
||||
|
||||
@ -220,7 +221,7 @@ impl<S: DnsUdpSocket + Send + 'static> Stream for UdpStream<S> {
|
||||
// receive all inbound messages
|
||||
|
||||
// TODO: this should match edns settings
|
||||
let mut buf = [0u8; 4096];
|
||||
let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
|
||||
let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
|
||||
|
||||
let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
|
||||
|
@ -47,22 +47,23 @@ async fn send_response<'a, R: ResponseHandler>(
|
||||
>,
|
||||
mut response_handle: R,
|
||||
) -> io::Result<ResponseInfo> {
|
||||
#[cfg(feature = "dnssec")]
|
||||
if let Some(mut resp_edns) = response_edns {
|
||||
// set edns DAU and DHU
|
||||
// send along the algorithms which are supported by this authority
|
||||
let mut algorithms = SupportedAlgorithms::default();
|
||||
algorithms.set(Algorithm::RSASHA256);
|
||||
algorithms.set(Algorithm::ECDSAP256SHA256);
|
||||
algorithms.set(Algorithm::ECDSAP384SHA384);
|
||||
algorithms.set(Algorithm::ED25519);
|
||||
#[cfg(feature = "dnssec")]
|
||||
{
|
||||
// set edns DAU and DHU
|
||||
// send along the algorithms which are supported by this authority
|
||||
let mut algorithms = SupportedAlgorithms::default();
|
||||
algorithms.set(Algorithm::RSASHA256);
|
||||
algorithms.set(Algorithm::ECDSAP256SHA256);
|
||||
algorithms.set(Algorithm::ECDSAP384SHA384);
|
||||
algorithms.set(Algorithm::ED25519);
|
||||
|
||||
let dau = EdnsOption::DAU(algorithms);
|
||||
let dhu = EdnsOption::DHU(algorithms);
|
||||
|
||||
resp_edns.options_mut().insert(dau);
|
||||
resp_edns.options_mut().insert(dhu);
|
||||
let dau = EdnsOption::DAU(algorithms);
|
||||
let dhu = EdnsOption::DHU(algorithms);
|
||||
|
||||
resp_edns.options_mut().insert(dau);
|
||||
resp_edns.options_mut().insert(dhu);
|
||||
}
|
||||
response.set_edns(resp_edns);
|
||||
}
|
||||
|
||||
|
@ -96,6 +96,11 @@ where
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets a reference to the EDNS options for the Response.
|
||||
pub fn get_edns(&self) -> &Option<Edns> {
|
||||
&self.edns
|
||||
}
|
||||
|
||||
/// Consumes self, and emits to the encoder.
|
||||
pub fn destructive_emit(mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<ResponseInfo> {
|
||||
// soa records are part of the nameserver section
|
||||
|
@ -7,9 +7,10 @@
|
||||
|
||||
use std::{io, net::SocketAddr};
|
||||
|
||||
use tracing::debug;
|
||||
use tracing::{debug, trace};
|
||||
use trust_dns_proto::rr::Record;
|
||||
|
||||
use crate::server::Protocol;
|
||||
use crate::{
|
||||
authority::MessageResponse,
|
||||
proto::{
|
||||
@ -46,12 +47,43 @@ pub trait ResponseHandler: Clone + Send + Sync + Unpin + 'static {
|
||||
pub struct ResponseHandle {
|
||||
dst: SocketAddr,
|
||||
stream_handle: BufDnsStreamHandle,
|
||||
protocol: Protocol,
|
||||
}
|
||||
|
||||
impl ResponseHandle {
|
||||
/// Returns a new `ResponseHandle` for sending a response message
|
||||
pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle) -> Self {
|
||||
Self { dst, stream_handle }
|
||||
pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle, protocol: Protocol) -> Self {
|
||||
Self {
|
||||
dst,
|
||||
stream_handle,
|
||||
protocol,
|
||||
}
|
||||
}
|
||||
|
||||
/// Selects an appropriate maximum serialized size for the given response.
|
||||
fn max_size_for_response<'a>(
|
||||
&self,
|
||||
response: &MessageResponse<
|
||||
'_,
|
||||
'a,
|
||||
impl Iterator<Item = &'a Record> + Send + 'a,
|
||||
impl Iterator<Item = &'a Record> + Send + 'a,
|
||||
impl Iterator<Item = &'a Record> + Send + 'a,
|
||||
impl Iterator<Item = &'a Record> + Send + 'a,
|
||||
>,
|
||||
) -> u16 {
|
||||
match self.protocol {
|
||||
Protocol::Udp => {
|
||||
// Use EDNS, if available.
|
||||
if let Some(edns) = response.get_edns() {
|
||||
edns.max_payload()
|
||||
} else {
|
||||
// No EDNS, use the recommended max from RFC6891.
|
||||
trust_dns_proto::udp::MAX_RECEIVE_BUFFER_SIZE as u16
|
||||
}
|
||||
}
|
||||
_ => u16::MAX,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,6 +111,12 @@ impl ResponseHandler for ResponseHandle {
|
||||
let mut buffer = Vec::with_capacity(512);
|
||||
let encode_result = {
|
||||
let mut encoder = BinEncoder::new(&mut buffer);
|
||||
|
||||
// Set an appropriate maximum on the encoder.
|
||||
let max_size = self.max_size_for_response(&response);
|
||||
trace!("setting response max size: {max_size} for protocol: {:?}", self.protocol);
|
||||
encoder.set_max_size(max_size);
|
||||
|
||||
response.destructive_emit(&mut encoder)
|
||||
};
|
||||
|
||||
|
@ -716,7 +716,7 @@ pub(crate) async fn handle_raw_request<T: RequestHandler>(
|
||||
response_handler: BufDnsStreamHandle,
|
||||
) {
|
||||
let src_addr = message.addr();
|
||||
let response_handler = ResponseHandle::new(message.addr(), response_handler);
|
||||
let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol);
|
||||
|
||||
self::handle_request(
|
||||
message.bytes(),
|
||||
|
118
tests/integration-tests/tests/truncation_tests.rs
Normal file
118
tests/integration-tests/tests/truncation_tests.rs
Normal file
@ -0,0 +1,118 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::UdpSocket;
|
||||
use trust_dns_client::client::AsyncClient;
|
||||
use trust_dns_proto::op::{Edns, Message, MessageType, OpCode, Query};
|
||||
use trust_dns_proto::rr::rdata::{A, SOA};
|
||||
use trust_dns_proto::rr::{DNSClass, Name, RData, Record, RecordSet, RecordType, RrKey};
|
||||
use trust_dns_proto::udp::UdpClientStream;
|
||||
use trust_dns_proto::xfer::FirstAnswer;
|
||||
use trust_dns_proto::DnsHandle;
|
||||
use trust_dns_server::authority::{Catalog, ZoneType};
|
||||
use trust_dns_server::store::in_memory::InMemoryAuthority;
|
||||
use trust_dns_server::ServerFuture;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncation() {
|
||||
let _guard = subscribe();
|
||||
|
||||
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0));
|
||||
let udp_socket = UdpSocket::bind(&addr).await.unwrap();
|
||||
|
||||
let nameserver = udp_socket.local_addr().unwrap();
|
||||
println!("udp_socket on port: {nameserver}");
|
||||
|
||||
// Create and start the server.
|
||||
let mut server = ServerFuture::new(new_large_catalog(128));
|
||||
server.register_socket(udp_socket);
|
||||
tokio::spawn(server.block_until_done());
|
||||
|
||||
// Create the UDP client.
|
||||
let stream = UdpClientStream::<UdpSocket>::new(nameserver);
|
||||
let (mut client, bg) = AsyncClient::connect(stream).await.unwrap();
|
||||
|
||||
// Run the client exchange in the background.
|
||||
tokio::spawn(bg);
|
||||
|
||||
// Build the query.
|
||||
let max_payload = 512;
|
||||
let mut msg = Message::new();
|
||||
msg.add_query({
|
||||
let mut query = Query::query(large_name(), RecordType::A);
|
||||
query.set_query_class(DNSClass::IN);
|
||||
query
|
||||
})
|
||||
.set_id(rand::random::<u16>())
|
||||
.set_message_type(MessageType::Query)
|
||||
.set_op_code(OpCode::Query)
|
||||
.set_recursion_desired(true)
|
||||
.set_edns({
|
||||
let mut edns = Edns::new();
|
||||
edns.set_max_payload(max_payload).set_version(0);
|
||||
edns
|
||||
});
|
||||
|
||||
let result = client.send(msg).first_answer().await.expect("query failed");
|
||||
|
||||
assert!(result.truncated());
|
||||
assert_eq!(max_payload, result.max_payload());
|
||||
}
|
||||
|
||||
// TODO: should we do this for all of the integration tests?
|
||||
fn subscribe() -> tracing::subscriber::DefaultGuard {
|
||||
let sub = tracing_subscriber::FmtSubscriber::builder()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.finish();
|
||||
tracing::subscriber::set_default(sub)
|
||||
}
|
||||
|
||||
pub fn new_large_catalog(num_records: u32) -> Catalog {
|
||||
// Create a large record set.
|
||||
let name = large_name();
|
||||
let mut record_set = RecordSet::new(&name, RecordType::A, 0);
|
||||
for i in 1..num_records + 1 {
|
||||
let ip = Ipv4Addr::from(i);
|
||||
let rdata = RData::A(A(ip));
|
||||
record_set.insert(Record::from_rdata(name.clone(), 86400, rdata), 0);
|
||||
}
|
||||
|
||||
let mut soa_record_set = RecordSet::new(&name, RecordType::SOA, 0);
|
||||
soa_record_set.insert(
|
||||
Record::from_rdata(
|
||||
name.clone(),
|
||||
86400,
|
||||
RData::SOA(SOA::new(
|
||||
n("sns.dns.icann.org."),
|
||||
n("noc.dns.icann.org."),
|
||||
2015082403,
|
||||
7200,
|
||||
3600,
|
||||
1209600,
|
||||
3600,
|
||||
)),
|
||||
),
|
||||
0,
|
||||
);
|
||||
|
||||
let mut records = BTreeMap::new();
|
||||
records.insert(RrKey::new(name.clone().into(), RecordType::A), record_set);
|
||||
records.insert(RrKey::new(name.into(), RecordType::SOA), soa_record_set);
|
||||
let authority =
|
||||
InMemoryAuthority::new(Name::root(), records, ZoneType::Primary, false).unwrap();
|
||||
|
||||
let mut catalog: Catalog = Catalog::new();
|
||||
catalog.upsert(Name::root().into(), Box::new(Arc::new(authority)));
|
||||
catalog
|
||||
}
|
||||
|
||||
const LARGE_NAME: &str = "large.com";
|
||||
|
||||
fn large_name() -> Name {
|
||||
n(LARGE_NAME)
|
||||
}
|
||||
|
||||
pub fn n<S: AsRef<str>>(name: S) -> Name {
|
||||
Name::from_str(name.as_ref()).unwrap()
|
||||
}
|
Loading…
Reference in New Issue
Block a user