Simplify encode and decode of fixed length types

This commit is contained in:
Benjamin Fry 2022-12-21 14:26:35 -08:00
parent 245b01750d
commit afe368b5bd
9 changed files with 256 additions and 201 deletions

View File

@ -22,7 +22,7 @@ use crate::{
rdata::{DNSSECRData, DNSKEY, KEY, SIG},
tbs, Algorithm, KeyPair, Private, TBS,
},
{DNSClass, Name, RData, RecordData, RecordType},
{DNSClass, Name, RData, RecordType},
},
serialize::binary::BinEncoder,
};

View File

@ -3,7 +3,7 @@
use super::rdata::{sig, DNSSECRData, SIG};
use crate::error::*;
use crate::rr::dnssec::Algorithm;
use crate::rr::{DNSClass, Name, RData, Record, RecordData, RecordType};
use crate::rr::{DNSClass, Name, RData, Record, RecordType};
use crate::serialize::binary::{BinEncodable, BinEncoder, EncodeMode};
/// Data To Be Signed.

View File

@ -34,7 +34,7 @@ pub mod type_bit_map;
use std::fmt;
use crate::error::ProtoResult;
use crate::serialize::binary::{BinDecoder, BinEncoder, Restrict};
use crate::serialize::binary::{BinDecodable, BinDecoder, BinEncodable, Restrict};
pub use self::dns_class::DNSClass;
pub use self::domain::{IntoName, Name, TryParseIp};
@ -49,31 +49,43 @@ pub use lower_name::LowerName;
pub use rr_key::RrKey;
/// RecordData that is stored in a DNS Record.
pub trait RecordData: Clone + Sized + PartialEq + Eq + fmt::Display {
pub trait RecordData: Clone + Sized + PartialEq + Eq + fmt::Display + BinEncodable {
/// Attempts to convert to this RecordData from the RData type, if it is not the correct type the original is returned
#[allow(clippy::result_large_err)]
fn try_from_rdata(data: RData) -> Result<Self, RData>;
/// Read the RecordData from the data stream.
///
/// * `decoder` - data stream from which the RData will be read
/// * `record_type` - specifies the RecordType that has already been read from the stream
/// * `length` - the data length that should be read from the stream for this RecordData
fn read(
decoder: &mut BinDecoder<'_>,
record_type: RecordType,
length: Restrict<u16>,
) -> ProtoResult<Self>;
/// Writes this type to the data stream
fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()>;
/// Attempts to borrow this RecordData from the RData type, if it is not the correct type the original is returned
fn try_borrow(data: &RData) -> Result<&Self, &RData>;
// FIXME: make a new AnyRecordType trait
/// Get the associated RecordType for the RData
fn record_type(&self) -> RecordType;
/// Converts this RecordData into generic RData
fn into_rdata(self) -> RData;
}
trait RecordDataDecodable<'r>: Sized {
/// Read the RecordData from the data stream.
///
/// * `decoder` - data stream from which the RData will be read
/// * `record_type` - specifies the RecordType that has already been read from the stream
/// * `length` - the data length that should be read from the stream for this RecordData
fn read_data(
decoder: &mut BinDecoder<'r>,
record_type: RecordType,
length: Restrict<u16>,
) -> ProtoResult<Self>;
}
impl<'r, T> RecordDataDecodable<'r> for T
where
T: 'r + BinDecodable<'r> + Sized,
{
fn read_data(
decoder: &mut BinDecoder<'r>,
_record_type: RecordType,
_length: Restrict<u16>,
) -> ProtoResult<Self> {
T::read(decoder)
}
}

View File

@ -216,32 +216,7 @@ impl SOA {
}
}
impl RecordData for SOA {
fn try_from_rdata(data: RData) -> Result<Self, RData> {
match data {
RData::SOA(soa) => Ok(soa),
_ => Err(data),
}
}
fn read(
decoder: &mut BinDecoder<'_>,
record_type: RecordType,
_length: Restrict<u16>,
) -> ProtoResult<Self> {
assert_eq!(RecordType::SOA, record_type);
Ok(Self {
mname: Name::read(decoder)?,
rname: Name::read(decoder)?,
serial: decoder.read_u32()?.unverified(/*any u32 is valid*/),
refresh: decoder.read_i32()?.unverified(/*any i32 is valid*/),
retry: decoder.read_i32()?.unverified(/*any i32 is valid*/),
expire: decoder.read_i32()?.unverified(/*any i32 is valid*/),
minimum: decoder.read_u32()?.unverified(/*any u32 is valid*/),
})
}
impl BinEncodable for SOA {
/// [RFC 4034](https://tools.ietf.org/html/rfc4034#section-6), DNSSEC Resource Records, March 2005
///
/// This is accurate for all currently known name records.
@ -274,6 +249,29 @@ impl RecordData for SOA {
encoder.emit_u32(self.minimum)?;
Ok(())
}
}
impl<'r> BinDecodable<'r> for SOA {
fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
Ok(Self {
mname: Name::read(decoder)?,
rname: Name::read(decoder)?,
serial: decoder.read_u32()?.unverified(/*any u32 is valid*/),
refresh: decoder.read_i32()?.unverified(/*any i32 is valid*/),
retry: decoder.read_i32()?.unverified(/*any i32 is valid*/),
expire: decoder.read_i32()?.unverified(/*any i32 is valid*/),
minimum: decoder.read_u32()?.unverified(/*any u32 is valid*/),
})
}
}
impl RecordData for SOA {
fn try_from_rdata(data: RData) -> Result<Self, RData> {
match data {
RData::SOA(soa) => Ok(soa),
_ => Err(data),
}
}
fn try_borrow(data: &RData) -> Result<&Self, &RData> {
match data {
@ -364,6 +362,8 @@ impl fmt::Display for SOA {
mod tests {
#![allow(clippy::dbg_macro, clippy::print_stdout)]
use crate::rr::RecordDataDecodable;
use super::*;
#[test]
@ -389,8 +389,8 @@ mod tests {
println!("bytes: {bytes:?}");
let mut decoder: BinDecoder<'_> = BinDecoder::new(bytes);
let read_rdata =
SOA::read(&mut decoder, RecordType::SOA, Restrict::new(len)).expect("Decoding error");
let read_rdata = SOA::read_data(&mut decoder, RecordType::SOA, Restrict::new(len))
.expect("Decoding error");
assert_eq!(rdata, read_rdata);
}
}

View File

@ -23,7 +23,7 @@ use super::rdata::{
CAA, CSYNC, HINFO, MX, NAPTR, NULL, OPENPGPKEY, OPT, SOA, SRV, SSHFP, SVCB, TLSA, TXT,
};
use super::record_type::RecordType;
use super::{rdata, RecordData};
use super::{rdata, RecordData, RecordDataDecodable};
use crate::error::*;
use crate::serialize::binary::*;
@ -748,142 +748,7 @@ impl RData {
}
}
impl RecordData for RData {
fn try_from_rdata(data: RData) -> Result<Self, RData> {
Ok(data)
}
/// Read the RData from the given Decoder
fn read(
decoder: &mut BinDecoder<'_>,
record_type: RecordType,
rdata_length: Restrict<u16>,
) -> ProtoResult<Self> {
let start_idx = decoder.index();
let result = match record_type {
RecordType::A => {
trace!("reading A");
rdata::a::read(decoder).map(Self::A)
}
RecordType::AAAA => {
trace!("reading AAAA");
rdata::aaaa::read(decoder).map(Self::AAAA)
}
RecordType::ANAME => {
trace!("reading ANAME");
rdata::name::read(decoder).map(Self::ANAME)
}
rt @ RecordType::ANY | rt @ RecordType::AXFR | rt @ RecordType::IXFR => {
return Err(ProtoErrorKind::UnknownRecordTypeValue(rt.into()).into());
}
RecordType::CAA => {
trace!("reading CAA");
rdata::caa::read(decoder, rdata_length).map(Self::CAA)
}
RecordType::CNAME => {
trace!("reading CNAME");
rdata::name::read(decoder).map(Self::CNAME)
}
RecordType::CSYNC => {
trace!("reading CSYNC");
rdata::csync::read(decoder, rdata_length).map(Self::CSYNC)
}
RecordType::HINFO => {
trace!("reading HINFO");
rdata::hinfo::read(decoder).map(Self::HINFO)
}
RecordType::HTTPS => {
trace!("reading HTTPS");
rdata::svcb::read(decoder, rdata_length).map(Self::HTTPS)
}
RecordType::ZERO => {
trace!("reading EMPTY");
// we should never get here, since ZERO should be 0 length, and None in the Record.
// this invariant is verified below, and the decoding will fail with an err.
#[allow(deprecated)]
Ok(Self::ZERO)
}
RecordType::MX => {
trace!("reading MX");
rdata::mx::read(decoder).map(Self::MX)
}
RecordType::NAPTR => {
trace!("reading NAPTR");
rdata::naptr::read(decoder).map(Self::NAPTR)
}
RecordType::NULL => {
trace!("reading NULL");
rdata::null::read(decoder, rdata_length).map(Self::NULL)
}
RecordType::NS => {
trace!("reading NS");
rdata::name::read(decoder).map(Self::NS)
}
RecordType::OPENPGPKEY => {
trace!("reading OPENPGPKEY");
rdata::openpgpkey::read(decoder, rdata_length).map(Self::OPENPGPKEY)
}
RecordType::OPT => {
trace!("reading OPT");
rdata::opt::read(decoder, rdata_length).map(Self::OPT)
}
RecordType::PTR => {
trace!("reading PTR");
rdata::name::read(decoder).map(Self::PTR)
}
RecordType::SOA => {
trace!("reading SOA");
SOA::read(decoder, record_type, rdata_length).map(Self::SOA)
}
RecordType::SRV => {
trace!("reading SRV");
rdata::srv::read(decoder).map(Self::SRV)
}
RecordType::SSHFP => {
trace!("reading SSHFP");
rdata::sshfp::read(decoder, rdata_length).map(Self::SSHFP)
}
RecordType::SVCB => {
trace!("reading SVCB");
rdata::svcb::read(decoder, rdata_length).map(Self::SVCB)
}
RecordType::TLSA => {
trace!("reading TLSA");
rdata::tlsa::read(decoder, rdata_length).map(Self::TLSA)
}
RecordType::TXT => {
trace!("reading TXT");
rdata::txt::read(decoder, rdata_length).map(Self::TXT)
}
#[cfg(feature = "dnssec")]
r if r.is_dnssec() => {
DNSSECRData::read(decoder, record_type, rdata_length).map(Self::DNSSEC)
}
record_type => {
trace!("reading Unknown record: {}", record_type);
rdata::null::read(decoder, rdata_length).map(|rdata| Self::Unknown {
code: record_type.into(),
rdata,
})
}
};
// we should have read rdata_length, but we did not
let read = decoder.index() - start_idx;
rdata_length
.map(|u| u as usize)
.verify_unwrap(|rdata_length| read == *rdata_length)
.map_err(|rdata_length| {
ProtoError::from(ProtoErrorKind::IncorrectRDataLengthRead {
read,
len: rdata_length,
})
})?;
result
}
impl BinEncodable for RData {
/// [RFC 4034](https://tools.ietf.org/html/rfc4034#section-6), DNSSEC Resource Records, March 2005
///
/// ```text
@ -953,8 +818,8 @@ impl RecordData for RData {
/// ```
fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
match *self {
Self::A(address) => rdata::a::emit(encoder, address),
Self::AAAA(ref address) => rdata::aaaa::emit(encoder, address),
Self::A(ref address) => Ipv4Addr::emit(address, encoder),
Self::AAAA(ref address) => Ipv6Addr::emit(address, encoder),
Self::ANAME(ref name) => {
encoder.with_canonical_names(|encoder| rdata::name::emit(encoder, name))
}
@ -965,7 +830,7 @@ impl RecordData for RData {
Self::CNAME(ref name) | RData::NS(ref name) | RData::PTR(ref name) => {
rdata::name::emit(encoder, name)
}
Self::CSYNC(ref csync) => rdata::csync::emit(encoder, csync),
Self::CSYNC(ref csync) => csync.emit(encoder),
Self::HINFO(ref hinfo) => rdata::hinfo::emit(encoder, hinfo),
Self::HTTPS(ref svcb) => rdata::svcb::emit(encoder, svcb),
Self::ZERO => Ok(()),
@ -998,6 +863,142 @@ impl RecordData for RData {
Self::Unknown { ref rdata, .. } => rdata::null::emit(encoder, rdata),
}
}
}
impl<'r> RecordDataDecodable<'r> for RData {
fn read_data(
decoder: &mut BinDecoder<'r>,
record_type: RecordType,
length: Restrict<u16>,
) -> ProtoResult<Self> {
let start_idx = decoder.index();
let result = match record_type {
RecordType::A => {
trace!("reading A");
rdata::a::read(decoder).map(Self::A)
}
RecordType::AAAA => {
trace!("reading AAAA");
rdata::aaaa::read(decoder).map(Self::AAAA)
}
RecordType::ANAME => {
trace!("reading ANAME");
rdata::name::read(decoder).map(Self::ANAME)
}
rt @ RecordType::ANY | rt @ RecordType::AXFR | rt @ RecordType::IXFR => {
return Err(ProtoErrorKind::UnknownRecordTypeValue(rt.into()).into());
}
RecordType::CAA => {
trace!("reading CAA");
rdata::caa::read(decoder, length).map(Self::CAA)
}
RecordType::CNAME => {
trace!("reading CNAME");
rdata::name::read(decoder).map(Self::CNAME)
}
RecordType::CSYNC => {
trace!("reading CSYNC");
CSYNC::read_data(decoder, record_type, length).map(Self::CSYNC)
}
RecordType::HINFO => {
trace!("reading HINFO");
rdata::hinfo::read(decoder).map(Self::HINFO)
}
RecordType::HTTPS => {
trace!("reading HTTPS");
rdata::svcb::read(decoder, length).map(Self::HTTPS)
}
RecordType::ZERO => {
trace!("reading EMPTY");
// we should never get here, since ZERO should be 0 length, and None in the Record.
// this invariant is verified below, and the decoding will fail with an err.
#[allow(deprecated)]
Ok(Self::ZERO)
}
RecordType::MX => {
trace!("reading MX");
rdata::mx::read(decoder).map(Self::MX)
}
RecordType::NAPTR => {
trace!("reading NAPTR");
rdata::naptr::read(decoder).map(Self::NAPTR)
}
RecordType::NULL => {
trace!("reading NULL");
rdata::null::read(decoder, length).map(Self::NULL)
}
RecordType::NS => {
trace!("reading NS");
rdata::name::read(decoder).map(Self::NS)
}
RecordType::OPENPGPKEY => {
trace!("reading OPENPGPKEY");
rdata::openpgpkey::read(decoder, length).map(Self::OPENPGPKEY)
}
RecordType::OPT => {
trace!("reading OPT");
rdata::opt::read(decoder, length).map(Self::OPT)
}
RecordType::PTR => {
trace!("reading PTR");
rdata::name::read(decoder).map(Self::PTR)
}
RecordType::SOA => {
trace!("reading SOA");
SOA::read_data(decoder, record_type, length).map(Self::SOA)
}
RecordType::SRV => {
trace!("reading SRV");
rdata::srv::read(decoder).map(Self::SRV)
}
RecordType::SSHFP => {
trace!("reading SSHFP");
rdata::sshfp::read(decoder, length).map(Self::SSHFP)
}
RecordType::SVCB => {
trace!("reading SVCB");
rdata::svcb::read(decoder, length).map(Self::SVCB)
}
RecordType::TLSA => {
trace!("reading TLSA");
rdata::tlsa::read(decoder, length).map(Self::TLSA)
}
RecordType::TXT => {
trace!("reading TXT");
rdata::txt::read(decoder, length).map(Self::TXT)
}
#[cfg(feature = "dnssec")]
r if r.is_dnssec() => DNSSECRData::read(decoder, record_type, length).map(Self::DNSSEC),
record_type => {
trace!("reading Unknown record: {}", record_type);
rdata::null::read(decoder, length).map(|rdata| Self::Unknown {
code: record_type.into(),
rdata,
})
}
};
// we should have read rdata_length, but we did not
let read = decoder.index() - start_idx;
length
.map(|u| u as usize)
.verify_unwrap(|rdata_length| read == *rdata_length)
.map_err(|rdata_length| {
ProtoError::from(ProtoErrorKind::IncorrectRDataLengthRead {
read,
len: rdata_length,
})
})?;
result
}
}
impl RecordData for RData {
fn try_from_rdata(data: RData) -> Result<Self, RData> {
Ok(data)
}
fn try_borrow(data: &RData) -> Result<&Self, &RData> {
Ok(data)
@ -1255,7 +1256,7 @@ mod tests {
let mut decoder = BinDecoder::new(&binary);
assert_eq!(
RData::read(
RData::read_data(
&mut decoder,
record_type_from_rdata(&expect),
Restrict::new(length)

View File

@ -34,6 +34,8 @@ use crate::rr::RecordSet;
use crate::rr::RecordType;
use crate::serialize::binary::*;
use super::RecordDataDecodable;
#[cfg(feature = "mdns")]
/// From [RFC 6762](https://tools.ietf.org/html/rfc6762#section-10.2)
/// ```text
@ -79,6 +81,7 @@ const MDNS_ENABLE_CACHE_FLUSH: u16 = 1 << 15;
/// ```
#[cfg_attr(feature = "serde-config", derive(Deserialize, Serialize))]
#[derive(Eq, Debug, Clone)]
// TODO: make Record carry a lifetime for more efficient storage options in the future
pub struct Record<R: RecordData = RData> {
name_labels: Name,
rr_type: RecordType,
@ -152,6 +155,7 @@ impl<R: RecordData> Record<R> {
}
/// Attempts to convert the generic `RData` based Record into this one with the interior `R`
#[allow(clippy::result_large_err)]
pub fn try_from(record: Record<RData>) -> Result<Self, Record<RData>> {
let Record {
name_labels,
@ -449,7 +453,7 @@ impl<R: RecordData> BinEncodable for Record<R> {
}
}
impl<'r, R: RecordData> BinDecodable<'r> for Record<R> {
impl<'r, R: RecordData + RecordDataDecodable<'r>> BinDecodable<'r> for Record<R> {
/// parse a resource record line example:
/// WARNING: the record_bytes is 100% consumed and destroyed in this parsing process
fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
@ -527,7 +531,11 @@ impl<'r, R: RecordData> BinDecodable<'r> for Record<R> {
// according to the TYPE and CLASS of the resource record.
// Adding restrict to the rdata length because it's used for many calculations later
// and must be validated before hand
Some(R::read(decoder, record_type, Restrict::new(rd_length))?)
Some(R::read_data(
decoder,
record_type,
Restrict::new(rd_length),
)?)
};
Ok(Self {

View File

@ -83,7 +83,7 @@ impl BinEncodable for i32 {
}
impl<'r> BinDecodable<'r> for i32 {
fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult<Self> {
fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
decoder
.read_i32()
.map(Restrict::unverified)
@ -114,24 +114,57 @@ impl BinEncodable for Vec<u8> {
impl BinEncodable for Ipv4Addr {
fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
crate::rr::rdata::a::emit(encoder, *self)
let segments = self.octets();
encoder.emit(segments[0])?;
encoder.emit(segments[1])?;
encoder.emit(segments[2])?;
encoder.emit(segments[3])?;
Ok(())
}
}
impl<'r> BinDecodable<'r> for Ipv4Addr {
fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult<Self> {
crate::rr::rdata::a::read(decoder)
fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
// TODO: would this be more efficient as a single u32 read?
Ok(Self::new(
decoder.pop()?.unverified(/*valid as any u8*/),
decoder.pop()?.unverified(/*valid as any u8*/),
decoder.pop()?.unverified(/*valid as any u8*/),
decoder.pop()?.unverified(/*valid as any u8*/),
))
}
}
impl BinEncodable for Ipv6Addr {
fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
crate::rr::rdata::aaaa::emit(encoder, self)
let segments = self.segments();
// TODO: this might be more efficient as a single write of the array
encoder.emit_u16(segments[0])?;
encoder.emit_u16(segments[1])?;
encoder.emit_u16(segments[2])?;
encoder.emit_u16(segments[3])?;
encoder.emit_u16(segments[4])?;
encoder.emit_u16(segments[5])?;
encoder.emit_u16(segments[6])?;
encoder.emit_u16(segments[7])?;
Ok(())
}
}
impl<'r> BinDecodable<'r> for Ipv6Addr {
fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult<Self> {
crate::rr::rdata::aaaa::read(decoder)
fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
// TODO: would this be more efficient as two u64 reads?
let a: u16 = decoder.read_u16()?.unverified(/*valid as any u16*/);
let b: u16 = decoder.read_u16()?.unverified(/*valid as any u16*/);
let c: u16 = decoder.read_u16()?.unverified(/*valid as any u16*/);
let d: u16 = decoder.read_u16()?.unverified(/*valid as any u16*/);
let e: u16 = decoder.read_u16()?.unverified(/*valid as any u16*/);
let f: u16 = decoder.read_u16()?.unverified(/*valid as any u16*/);
let g: u16 = decoder.read_u16()?.unverified(/*valid as any u16*/);
let h: u16 = decoder.read_u16()?.unverified(/*valid as any u16*/);
Ok(Self::new(a, b, c, d, e, f, g, h))
}
}

View File

@ -96,6 +96,6 @@ async fn test_ttl_wilcard() {
.next()
.expect("A record not found in authority");
assert_eq!(data.rr_type(), RecordType::A);
assert_eq!(data.record_type(), RecordType::A);
assert_eq!(data.ttl(), 120);
}

View File

@ -23,6 +23,7 @@
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
@ -32,7 +33,7 @@ use console::style;
use tokio::task::JoinSet;
use tokio::time::MissedTickBehavior;
use trust_dns_client::rr::Record;
use trust_dns_client::rr::{Record, RecordData};
use trust_dns_resolver::config::{
NameServerConfig, NameServerConfigGroup, Protocol, ResolverConfig, ResolverOpts,
};
@ -142,7 +143,7 @@ struct Opts {
interval: f32,
}
fn print_record(r: &Record) {
fn print_record<D: RecordData, R: Deref<Target = Record<D>>>(r: &R) {
print!(
"\t{name} {ttl} {class} {ty}",
name = style(r.name()).blue(),
@ -165,7 +166,7 @@ fn print_ok(lookup: Lookup) {
);
for r in lookup.record_iter() {
print_record(r);
print_record(&r);
}
}