reformat all code

This commit is contained in:
Benjamin Fry
2019-10-14 11:47:52 -07:00
parent 735cf45e77
commit 31fbf66316
134 changed files with 2257 additions and 1816 deletions

View File

@@ -47,7 +47,6 @@ fn name_cmp_medium(b: &mut Bencher) {
}); });
} }
#[bench] #[bench]
fn name_cmp_medium_case(b: &mut Bencher) { fn name_cmp_medium_case(b: &mut Bencher) {
let name1 = LowerName::new(&Name::from_str("www.example.com").unwrap()); let name1 = LowerName::new(&Name::from_str("www.example.com").unwrap());

View File

@@ -5,10 +5,10 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::sync::Arc;
use std::time::Duration;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::Context; use std::task::Context;
use std::time::Duration;
use futures::{Future, FutureExt, Poll}; use futures::{Future, FutureExt, Poll};
use proto::error::ProtoError; use proto::error::ProtoError;
@@ -20,7 +20,7 @@ use proto::xfer::{
use rand; use rand;
use crate::error::*; use crate::error::*;
use crate::op::{Message, MessageType, OpCode, Query, update_message}; use crate::op::{update_message, Message, MessageType, OpCode, Query};
use crate::rr::dnssec::Signer; use crate::rr::dnssec::Signer;
use crate::rr::{DNSClass, Name, Record, RecordSet, RecordType}; use crate::rr::{DNSClass, Name, Record, RecordSet, RecordType};
@@ -155,7 +155,9 @@ where
loop { loop {
// we're either awaiting the connection, or we're always returning the exchange's result // we're either awaiting the connection, or we're always returning the exchange's result
let next = match *self { let next = match *self {
InnerClientFuture::DnsExchangeConnect(ref mut connect) => ready!(connect.poll_unpin(cx))?, InnerClientFuture::DnsExchangeConnect(ref mut connect) => {
ready!(connect.poll_unpin(cx))?
}
InnerClientFuture::DnsExchange(ref mut exchange) => return exchange.poll_unpin(cx), InnerClientFuture::DnsExchange(ref mut exchange) => return exchange.poll_unpin(cx),
}; };
@@ -298,15 +300,16 @@ pub trait ClientHandle: 'static + Clone + DnsHandle + Send {
// build the message // build the message
let mut message: Message = Message::new(); let mut message: Message = Message::new();
let id: u16 = rand::random(); let id: u16 = rand::random();
message.set_id(id) message
// 3.3. NOTIFY is similar to QUERY in that it has a request message with .set_id(id)
// the header QR flag "clear" and a response message with QR "set". The // 3.3. NOTIFY is similar to QUERY in that it has a request message with
// response message contains no useful information, but its reception by // the header QR flag "clear" and a response message with QR "set". The
// the master is an indication that the slave has received the NOTIFY // response message contains no useful information, but its reception by
// and that the master can remove the slave from any retry queue for // the master is an indication that the slave has received the NOTIFY
// this NOTIFY event. // and that the master can remove the slave from any retry queue for
.set_message_type(MessageType::Query) // this NOTIFY event.
.set_op_code(OpCode::Notify); .set_message_type(MessageType::Query)
.set_op_code(OpCode::Notify);
// Extended dns // Extended dns
{ {
@@ -374,7 +377,7 @@ pub trait ClientHandle: 'static + Clone + DnsHandle + Send {
{ {
let rrset = rrset.into(); let rrset = rrset.into();
let message = update_message::create(rrset, zone_origin); let message = update_message::create(rrset, zone_origin);
ClientResponse(self.send(message)) ClientResponse(self.send(message))
} }

View File

@@ -5,11 +5,11 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use futures::Future;
use futures::lock::Mutex; use futures::lock::Mutex;
use futures::Future;
use proto::error::ProtoError; use proto::error::ProtoError;
use proto::xfer::{DnsHandle, DnsRequest, DnsResponse}; use proto::xfer::{DnsHandle, DnsRequest, DnsResponse};
@@ -42,10 +42,11 @@ where
} }
} }
async fn inner_send(request: DnsRequest, async fn inner_send(
active_queries: Arc<Mutex<HashMap<Query, RcFuture<<H as DnsHandle>::Response>>>>, request: DnsRequest,
mut client: H) active_queries: Arc<Mutex<HashMap<Query, RcFuture<<H as DnsHandle>::Response>>>>,
-> Result<DnsResponse, ProtoError> { mut client: H,
) -> Result<DnsResponse, ProtoError> {
// TODO: what if we want to support multiple queries (non-standard)? // TODO: what if we want to support multiple queries (non-standard)?
let query = request.queries().first().expect("no query!").clone(); let query = request.queries().first().expect("no query!").clone();
@@ -55,13 +56,14 @@ where
// TODO: we need to consider TTL on the records here at some point // TODO: we need to consider TTL on the records here at some point
// If the query is running, grab that existing one... // If the query is running, grab that existing one...
if let Some(rc_future) = active_queries.get(&query) { if let Some(rc_future) = active_queries.get(&query) {
return rc_future.clone().await return rc_future.clone().await;
}; };
// Otherwise issue a new query and store in the map // Otherwise issue a new query and store in the map
active_queries.entry(query).or_insert_with(|| { active_queries
rc_future(client.send(request)) .entry(query)
}).await .or_insert_with(|| rc_future(client.send(request)))
.await
} }
} }
@@ -73,18 +75,22 @@ where
fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response { fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
let request = request.into(); let request = request.into();
Box::pin(Self::inner_send(request, Arc::clone(&self.active_queries), self.client.clone())) Box::pin(Self::inner_send(
request,
Arc::clone(&self.active_queries),
self.client.clone(),
))
} }
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::sync::Arc;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use futures::*;
use futures::lock::Mutex; use futures::lock::Mutex;
use futures::*;
use proto::error::ProtoError; use proto::error::ProtoError;
use proto::xfer::{DnsHandle, DnsRequest, DnsResponse}; use proto::xfer::{DnsHandle, DnsRequest, DnsResponse};
@@ -110,7 +116,11 @@ mod test {
let mut i = i.lock().await; let mut i = i.lock().await;
message.set_id(*i); message.set_id(*i);
println!("sending {}: {}", *i, request.into().queries().first().expect("no query!").clone()); println!(
"sending {}: {}",
*i,
request.into().queries().first().expect("no query!").clone()
);
*i += 1; *i += 1;
@@ -125,7 +135,9 @@ mod test {
fn test_memoized() { fn test_memoized() {
use futures::executor::block_on; use futures::executor::block_on;
let mut client = MemoizeClientHandle::new(TestClient { i: Arc::new(Mutex::new(0)) }); let mut client = MemoizeClientHandle::new(TestClient {
i: Arc::new(Mutex::new(0)),
});
let mut test1 = Message::new(); let mut test1 = Message::new();
test1.add_query(Query::new().set_query_type(RecordType::A).clone()); test1.add_query(Query::new().set_query_type(RecordType::A).clone());
@@ -146,5 +158,4 @@ mod test {
let result = block_on(client.send(test2)).ok().unwrap(); let result = block_on(client.send(test2)).ok().unwrap();
assert_eq!(result.id(), 1); assert_eq!(result.id(), 1);
} }
} }

View File

@@ -5,12 +5,12 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::sync::Arc;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::Context; use std::task::Context;
use futures::{future::Fuse, Future, FutureExt, Poll};
use futures::lock::Mutex; use futures::lock::Mutex;
use futures::{future::Fuse, Future, FutureExt, Poll};
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub struct RcFuture<F: Future> pub struct RcFuture<F: Future>
@@ -29,9 +29,7 @@ where
{ {
let future_and_result = Arc::new(Mutex::new((future.fuse(), None))); let future_and_result = Arc::new(Mutex::new((future.fuse(), None)));
RcFuture { RcFuture { future_and_result }
future_and_result,
}
} }
impl<F> Future for RcFuture<F> impl<F> Future for RcFuture<F>
@@ -47,7 +45,7 @@ where
// wait for it to complete. // wait for it to complete.
if let Some(mut future_and_result) = self.future_and_result.try_lock() { if let Some(mut future_and_result) = self.future_and_result.try_lock() {
let (ref mut future, ref mut stored_result) = *future_and_result; let (ref mut future, ref mut stored_result) = *future_and_result;
// if pending it's either done, or it's actually pending // if pending it's either done, or it's actually pending
match future.poll_unpin(cx) { match future.poll_unpin(cx) {
Poll::Pending => (), Poll::Pending => (),
@@ -62,7 +60,7 @@ where
return Poll::Ready(result.clone()); return Poll::Ready(result.clone());
} else { } else {
// the poll on the future should wake this thread // the poll on the future should wake this thread
return Poll::Pending return Poll::Pending;
} }
} else { } else {
// TODO: track wakers in a queue instead... // TODO: track wakers in a queue instead...
@@ -116,4 +114,4 @@ mod tests {
let i = block_on(rc).err().unwrap(); let i = block_on(rc).err().unwrap();
assert_eq!(i, 2); assert_eq!(i, 2);
} }
} }

View File

@@ -26,9 +26,9 @@ use self::not_openssl::SslErrorStack;
use self::not_ring::{KeyRejected, Unspecified}; use self::not_ring::{KeyRejected, Unspecified};
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
use openssl::error::ErrorStack as SslErrorStack; use openssl::error::ErrorStack as SslErrorStack;
use proto::error::{ProtoError, ProtoErrorKind};
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
use ring::error::{KeyRejected, Unspecified}; use ring::error::{KeyRejected, Unspecified};
use proto::error::{ProtoError, ProtoErrorKind};
/// An alias for dnssec results returned by functions of this crate /// An alias for dnssec results returned by functions of this crate
pub type Result<T> = ::std::result::Result<T, Error>; pub type Result<T> = ::std::result::Result<T, Error>;

View File

@@ -10,9 +10,7 @@
use std::net::{Ipv4Addr, SocketAddr}; use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc; use std::sync::Arc;
use proto::multicast::{ use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType, MDNS_IPV4, MDNS_IPV6};
MdnsClientConnect, MdnsClientStream, MdnsQueryType, MDNS_IPV4, MDNS_IPV6,
};
use proto::xfer::{DnsMultiplexer, DnsMultiplexerConnect, DnsRequestSender}; use proto::xfer::{DnsMultiplexer, DnsMultiplexerConnect, DnsRequestSender};
use crate::client::ClientConnection; use crate::client::ClientConnection;

View File

@@ -11,4 +11,4 @@ mod mdns_client_connection;
use proto::multicast; use proto::multicast;
pub use self::mdns_client_connection::MdnsClientConnection; pub use self::mdns_client_connection::MdnsClientConnection;
pub use self::multicast::{MdnsClientStream, MdnsStream, MdnsQueryType, MDNS_IPV4, MDNS_IPV6}; pub use self::multicast::{MdnsClientStream, MdnsQueryType, MdnsStream, MDNS_IPV4, MDNS_IPV6};

View File

@@ -13,6 +13,7 @@ pub mod update_message;
pub use self::lower_query::LowerQuery; pub use self::lower_query::LowerQuery;
pub use self::update_message::UpdateMessage; pub use self::update_message::UpdateMessage;
pub use proto::op::{Edns, Header, Message, MessageFinalizer, MessageType, OpCode, Query, pub use proto::op::{
ResponseCode}; Edns, Header, Message, MessageFinalizer, MessageType, OpCode, Query, ResponseCode,
};
pub use proto::xfer::DnsResponse; pub use proto::xfer::DnsResponse;

View File

@@ -9,10 +9,10 @@
use std::fmt::Debug; use std::fmt::Debug;
use crate::op::{Message, Query, MessageType, OpCode};
use crate::rr::{Record, RecordSet, DNSClass, Name, RecordType, RData};
use crate::rr::rdata::NULL;
use crate::client::client_future::MAX_PAYLOAD_LEN; use crate::client::client_future::MAX_PAYLOAD_LEN;
use crate::op::{Message, MessageType, OpCode, Query};
use crate::rr::rdata::NULL;
use crate::rr::{DNSClass, Name, RData, Record, RecordSet, RecordType};
/// To reduce errors in using the Message struct as an Update, this will do the call throughs /// To reduce errors in using the Message struct as an Update, this will do the call throughs
/// to properly do that. /// to properly do that.

View File

@@ -5,7 +5,9 @@ use openssl::rsa::Rsa;
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
use openssl::symm::Cipher; use openssl::symm::Cipher;
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
use ring::signature::{EcdsaKeyPair, Ed25519KeyPair, ECDSA_P256_SHA256_FIXED_SIGNING, ECDSA_P384_SHA384_FIXED_SIGNING}; use ring::signature::{
EcdsaKeyPair, Ed25519KeyPair, ECDSA_P256_SHA256_FIXED_SIGNING, ECDSA_P384_SHA384_FIXED_SIGNING,
};
use crate::error::*; use crate::error::*;
use crate::rr::dnssec::{Algorithm, KeyPair, Private}; use crate::rr::dnssec::{Algorithm, KeyPair, Private};
@@ -35,9 +37,7 @@ impl KeyFormat {
let password = password.as_bytes(); let password = password.as_bytes();
match algorithm { match algorithm {
Algorithm::Unknown(v) => { Algorithm::Unknown(v) => Err(format!("unknown algorithm: {}", v).into()),
Err(format!("unknown algorithm: {}", v).into())
}
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
e @ Algorithm::RSASHA1 | e @ Algorithm::RSASHA1NSEC3SHA1 => { e @ Algorithm::RSASHA1 | e @ Algorithm::RSASHA1NSEC3SHA1 => {
Err(format!("unsupported Algorithm (insecure): {:?}", e).into()) Err(format!("unsupported Algorithm (insecure): {:?}", e).into())
@@ -59,7 +59,8 @@ impl KeyFormat {
"unsupported key format with RSA (DER or PEM only): \ "unsupported key format with RSA (DER or PEM only): \
{:?}", {:?}",
e e
).into()) )
.into())
} }
}; };
@@ -77,10 +78,10 @@ impl KeyFormat {
} }
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
KeyFormat::Pem => { KeyFormat::Pem => {
let key = EcKey::private_key_from_pem_passphrase(bytes, password) let key =
.map_err(|e| { EcKey::private_key_from_pem_passphrase(bytes, password).map_err(|e| {
format!("could not decode EC from PEM, bad password?: {}", e) format!("could not decode EC from PEM, bad password?: {}", e)
})?; })?;
Ok(KeyPair::from_ec_key(key) Ok(KeyPair::from_ec_key(key)
.map_err(|e| format!("could not tranlate RSA to KeyPair: {}", e))?) .map_err(|e| format!("could not tranlate RSA to KeyPair: {}", e))?)
@@ -96,11 +97,8 @@ impl KeyFormat {
Ok(KeyPair::from_ecdsa(key)) Ok(KeyPair::from_ecdsa(key))
} }
e => Err(format!( e => Err(format!("unsupported key format with EC: {:?}", e).into()),
"unsupported key format with EC: {:?}", },
e
).into()),
}
Algorithm::ED25519 => match self { Algorithm::ED25519 => match self {
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
KeyFormat::Pkcs8 => { KeyFormat::Pkcs8 => {
@@ -111,13 +109,15 @@ impl KeyFormat {
e => Err(format!( e => Err(format!(
"unsupported key format with ED25519 (only Pkcs8 supported): {:?}", "unsupported key format with ED25519 (only Pkcs8 supported): {:?}",
e e
).into()), )
.into()),
}, },
#[cfg(not(all(feature = "openssl", feature = "ring")))] #[cfg(not(all(feature = "openssl", feature = "ring")))]
e => Err(format!( e => Err(format!(
"unsupported Algorithm, enable openssl or ring feature: {:?}", "unsupported Algorithm, enable openssl or ring feature: {:?}",
e e
).into()), )
.into()),
} }
} }
@@ -138,9 +138,7 @@ impl KeyFormat {
// generate the key // generate the key
#[allow(unused)] #[allow(unused)]
let key_pair: KeyPair<Private> = match algorithm { let key_pair: KeyPair<Private> = match algorithm {
Algorithm::Unknown(v) => { Algorithm::Unknown(v) => return Err(format!("unknown algorithm: {}", v).into()),
return Err(format!("unknown algorithm: {}", v).into())
}
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
e @ Algorithm::RSASHA1 | e @ Algorithm::RSASHA1NSEC3SHA1 => { e @ Algorithm::RSASHA1 | e @ Algorithm::RSASHA1NSEC3SHA1 => {
return Err(format!("unsupported Algorithm (insecure): {:?}", e).into()) return Err(format!("unsupported Algorithm (insecure): {:?}", e).into())
@@ -153,7 +151,7 @@ impl KeyFormat {
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
KeyFormat::Pkcs8 => return KeyPair::generate_pkcs8(algorithm), KeyFormat::Pkcs8 => return KeyPair::generate_pkcs8(algorithm),
e => return Err(format!("unsupported key format with EC: {:?}", e).into()), e => return Err(format!("unsupported key format with EC: {:?}", e).into()),
} },
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
Algorithm::ED25519 => return KeyPair::generate_pkcs8(algorithm), Algorithm::ED25519 => return KeyPair::generate_pkcs8(algorithm),
#[cfg(not(all(feature = "openssl", feature = "ring")))] #[cfg(not(all(feature = "openssl", feature = "ring")))]
@@ -161,7 +159,8 @@ impl KeyFormat {
return Err(format!( return Err(format!(
"unsupported Algorithm, enable openssl or ring feature: {:?}", "unsupported Algorithm, enable openssl or ring feature: {:?}",
e e
).into()) )
.into())
} }
}; };
@@ -195,7 +194,8 @@ impl KeyFormat {
"unsupported key format with RSA or EC (DER or PEM \ "unsupported key format with RSA or EC (DER or PEM \
only): {:?}", only): {:?}",
e e
).into()), )
.into()),
} }
} }
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
@@ -205,7 +205,8 @@ impl KeyFormat {
#[cfg(not(any(feature = "openssl", feature = "ring")))] #[cfg(not(any(feature = "openssl", feature = "ring")))]
_ => Err(format!( _ => Err(format!(
"unsupported Algorithm, enable openssl feature (encode not supported with ring)" "unsupported Algorithm, enable openssl feature (encode not supported with ring)"
).into()), )
.into()),
} }
} }
@@ -252,7 +253,8 @@ impl KeyFormat {
"unsupported key format with RSA or EC (DER or PEM \ "unsupported key format with RSA or EC (DER or PEM \
only): {:?}", only): {:?}",
e e
).into()), )
.into()),
} }
} }
#[cfg(any(feature = "ring", not(feature = "openssl")))] #[cfg(any(feature = "ring", not(feature = "openssl")))]

View File

@@ -22,9 +22,13 @@ use openssl::rsa::Rsa as OpenSslRsa;
use openssl::sign::Signer; use openssl::sign::Signer;
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
use ring::{rand, use ring::{
signature::{EcdsaKeyPair, Ed25519KeyPair, KeyPair as RingKeyPair, rand,
ECDSA_P256_SHA256_FIXED_SIGNING, ECDSA_P384_SHA384_FIXED_SIGNING}}; signature::{
EcdsaKeyPair, Ed25519KeyPair, KeyPair as RingKeyPair, ECDSA_P256_SHA256_FIXED_SIGNING,
ECDSA_P384_SHA384_FIXED_SIGNING,
},
};
use crate::error::*; use crate::error::*;
#[cfg(any(feature = "openssl", feature = "ring"))] #[cfg(any(feature = "openssl", feature = "ring"))]
@@ -419,11 +423,7 @@ impl<K: HasPrivate> KeyPair<K> {
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
KeyPair::ECDSA(ref ec_key) => { KeyPair::ECDSA(ref ec_key) => {
let rng = rand::SystemRandom::new(); let rng = rand::SystemRandom::new();
Ok(ec_key Ok(ec_key.sign(&rng, tbs.as_ref()).unwrap().as_ref().to_vec())
.sign(&rng, tbs.as_ref())
.unwrap()
.as_ref()
.to_vec())
} }
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
KeyPair::ED25519(ref ed_key) => Ok(ed_key.sign(tbs.as_ref()).as_ref().to_vec()), KeyPair::ED25519(ref ed_key) => Ok(ed_key.sign(tbs.as_ref()).as_ref().to_vec()),
@@ -441,9 +441,7 @@ impl KeyPair<Private> {
/// RSA keys are hardcoded to 2048bits at the moment. Other keys have predefined sizes. /// RSA keys are hardcoded to 2048bits at the moment. Other keys have predefined sizes.
pub fn generate(algorithm: Algorithm) -> DnsSecResult<Self> { pub fn generate(algorithm: Algorithm) -> DnsSecResult<Self> {
match algorithm { match algorithm {
Algorithm::Unknown(_) => { Algorithm::Unknown(_) => Err(DnsSecErrorKind::Message("unknown algorithm").into()),
Err(DnsSecErrorKind::Message("unknown algorithm").into())
}
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
Algorithm::RSASHA1 Algorithm::RSASHA1
| Algorithm::RSASHA1NSEC3SHA1 | Algorithm::RSASHA1NSEC3SHA1
@@ -478,9 +476,7 @@ impl KeyPair<Private> {
#[cfg(feature = "ring")] #[cfg(feature = "ring")]
pub fn generate_pkcs8(algorithm: Algorithm) -> DnsSecResult<Vec<u8>> { pub fn generate_pkcs8(algorithm: Algorithm) -> DnsSecResult<Vec<u8>> {
match algorithm { match algorithm {
Algorithm::Unknown(_) => { Algorithm::Unknown(_) => Err(DnsSecErrorKind::Message("unknown algorithm").into()),
Err(DnsSecErrorKind::Message("unknown algorithm").into())
}
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
Algorithm::RSASHA1 Algorithm::RSASHA1
| Algorithm::RSASHA1NSEC3SHA1 | Algorithm::RSASHA1NSEC3SHA1

View File

@@ -23,21 +23,21 @@ mod signer;
use proto::rr::dnssec; use proto::rr::dnssec;
pub use self::dnssec::tbs;
pub use self::dnssec::Algorithm; pub use self::dnssec::Algorithm;
pub use self::dnssec::DigestType; pub use self::dnssec::DigestType;
#[cfg(any(feature = "openssl", feature = "ring"))]
pub use self::key_format::KeyFormat;
pub use self::keypair::KeyPair;
pub use self::dnssec::Nsec3HashAlgorithm; pub use self::dnssec::Nsec3HashAlgorithm;
pub use self::dnssec::PublicKey; pub use self::dnssec::PublicKey;
pub use self::dnssec::PublicKeyBuf; pub use self::dnssec::PublicKeyBuf;
pub use self::dnssec::PublicKeyEnum; pub use self::dnssec::PublicKeyEnum;
pub use self::signer::Signer;
pub use self::dnssec::SupportedAlgorithms; pub use self::dnssec::SupportedAlgorithms;
pub use self::dnssec::TrustAnchor; pub use self::dnssec::TrustAnchor;
pub use self::dnssec::tbs;
pub use self::dnssec::TBS;
pub use self::dnssec::Verifier; pub use self::dnssec::Verifier;
pub use self::dnssec::TBS;
#[cfg(any(feature = "openssl", feature = "ring"))]
pub use self::key_format::KeyFormat;
pub use self::keypair::KeyPair;
pub use self::signer::Signer;
pub use crate::error::DnsSecError; pub use crate::error::DnsSecError;
pub use crate::error::DnsSecErrorKind; pub use crate::error::DnsSecErrorKind;
@@ -63,15 +63,15 @@ mod faux_key_type {
/// A key that contains private key material /// A key that contains private key material
pub trait HasPrivate {} pub trait HasPrivate {}
impl <K: HasPrivate> HasPublic for K {} impl<K: HasPrivate> HasPublic for K {}
/// Faux implementation of the Openssl Public key types /// Faux implementation of the Openssl Public key types
pub enum Public{} pub enum Public {}
impl HasPublic for Public {} impl HasPublic for Public {}
/// Faux implementation of the Openssl Public key types /// Faux implementation of the Openssl Public key types
pub enum Private{} pub enum Private {}
impl HasPrivate for Private {} impl HasPrivate for Private {}
} }

View File

@@ -28,12 +28,12 @@ pub use proto::rr::record_data;
pub use proto::rr::record_type; pub use proto::rr::record_type;
pub use proto::rr::resource; pub use proto::rr::resource;
pub use self::rr::domain::{IntoName, Name, Label};
pub use self::dns_class::DNSClass; pub use self::dns_class::DNSClass;
pub use self::lower_name::LowerName; pub use self::lower_name::LowerName;
pub use self::record_data::RData; pub use self::record_data::RData;
pub use self::record_type::RecordType; pub use self::record_type::RecordType;
pub use self::resource::Record; pub use self::resource::Record;
pub use self::rr::domain::{IntoName, Label, Name};
#[allow(deprecated)] #[allow(deprecated)]
pub use self::rr::IntoRecordSet; pub use self::rr::IntoRecordSet;
pub use self::rr::RecordSet; pub use self::rr::RecordSet;
@@ -41,6 +41,6 @@ pub use self::rr_key::RrKey;
/// All record data structures and related serialization methods /// All record data structures and related serialization methods
pub mod rdata { pub mod rdata {
pub use proto::rr::rdata::*;
pub use proto::rr::dnssec::rdata::*; pub use proto::rr::dnssec::rdata::*;
pub use proto::rr::rdata::*;
} }

View File

@@ -1,7 +1,7 @@
//! Reserved Zone and related information //! Reserved Zone and related information
use proto::rr::domain::{Label, Name};
pub use proto::rr::domain::usage::*; pub use proto::rr::domain::usage::*;
use proto::rr::domain::{Label, Name};
use proto::serialize::binary::BinEncodable; use proto::serialize::binary::BinEncodable;
use radix_trie::{Trie, TrieKey}; use radix_trie::{Trie, TrieKey};
@@ -12,11 +12,11 @@ use radix_trie::{Trie, TrieKey};
// //
// ```text // ```text
// 6.1. Domain Name Reservation Considerations for Private Addresses // 6.1. Domain Name Reservation Considerations for Private Addresses
// //
// The private-address [RFC1918] reverse-mapping domains listed below, // The private-address [RFC1918] reverse-mapping domains listed below,
// and any names falling within those domains, are Special-Use Domain // and any names falling within those domains, are Special-Use Domain
// Names: // Names:
// //
// 10.in-addr.arpa. 21.172.in-addr.arpa. 26.172.in-addr.arpa. // 10.in-addr.arpa. 21.172.in-addr.arpa. 26.172.in-addr.arpa.
// 16.172.in-addr.arpa. 22.172.in-addr.arpa. 27.172.in-addr.arpa. // 16.172.in-addr.arpa. 22.172.in-addr.arpa. 27.172.in-addr.arpa.
// 17.172.in-addr.arpa. 30.172.in-addr.arpa. 28.172.in-addr.arpa. // 17.172.in-addr.arpa. 30.172.in-addr.arpa. 28.172.in-addr.arpa.
@@ -29,7 +29,7 @@ lazy_static! {
pub static ref IN_ADDR_ARPA_10: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("10").unwrap().append_domain(&*IN_ADDR_ARPA)); pub static ref IN_ADDR_ARPA_10: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("10").unwrap().append_domain(&*IN_ADDR_ARPA));
static ref IN_ADDR_ARPA_172: Name = Name::from_ascii("172").unwrap().append_domain(&*IN_ADDR_ARPA); static ref IN_ADDR_ARPA_172: Name = Name::from_ascii("172").unwrap().append_domain(&*IN_ADDR_ARPA);
/// 16.172.in-addr.arpa. usage /// 16.172.in-addr.arpa. usage
pub static ref IN_ADDR_ARPA_172_16: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("16").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_16: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("16").unwrap().append_domain(&*IN_ADDR_ARPA_172));
/// 17.172.in-addr.arpa. usage /// 17.172.in-addr.arpa. usage
@@ -42,9 +42,9 @@ lazy_static! {
pub static ref IN_ADDR_ARPA_172_20: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("20").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_20: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("20").unwrap().append_domain(&*IN_ADDR_ARPA_172));
/// 21.172.in-addr.arpa. usage /// 21.172.in-addr.arpa. usage
pub static ref IN_ADDR_ARPA_172_21: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("21").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_21: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("21").unwrap().append_domain(&*IN_ADDR_ARPA_172));
/// 22.172.in-addr.arpa. usage /// 22.172.in-addr.arpa. usage
pub static ref IN_ADDR_ARPA_172_22: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("22").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_22: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("22").unwrap().append_domain(&*IN_ADDR_ARPA_172));
/// 23.172.in-addr.arpa. usage /// 23.172.in-addr.arpa. usage
pub static ref IN_ADDR_ARPA_172_23: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("23").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_23: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("23").unwrap().append_domain(&*IN_ADDR_ARPA_172));
/// 24.172.in-addr.arpa. usage /// 24.172.in-addr.arpa. usage
pub static ref IN_ADDR_ARPA_172_24: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("24").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_24: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("24").unwrap().append_domain(&*IN_ADDR_ARPA_172));
@@ -54,7 +54,7 @@ lazy_static! {
pub static ref IN_ADDR_ARPA_172_26: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("26").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_26: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("26").unwrap().append_domain(&*IN_ADDR_ARPA_172));
/// 27.172.in-addr.arpa. usage /// 27.172.in-addr.arpa. usage
pub static ref IN_ADDR_ARPA_172_27: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("27").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_27: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("27").unwrap().append_domain(&*IN_ADDR_ARPA_172));
/// 28.172.in-addr.arpa. usage /// 28.172.in-addr.arpa. usage
pub static ref IN_ADDR_ARPA_172_28: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("28").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_28: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("28").unwrap().append_domain(&*IN_ADDR_ARPA_172));
/// 29.172.in-addr.arpa. usage /// 29.172.in-addr.arpa. usage
pub static ref IN_ADDR_ARPA_172_29: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("29").unwrap().append_domain(&*IN_ADDR_ARPA_172)); pub static ref IN_ADDR_ARPA_172_29: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("29").unwrap().append_domain(&*IN_ADDR_ARPA_172));
@@ -67,13 +67,13 @@ lazy_static! {
pub static ref IN_ADDR_ARPA_192_168: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("168.192").unwrap().append_domain(&*IN_ADDR_ARPA)); pub static ref IN_ADDR_ARPA_192_168: ZoneUsage = ZoneUsage::reverse(Name::from_ascii("168.192").unwrap().append_domain(&*IN_ADDR_ARPA));
} }
// example., example.com., example.net., and example.org. // example., example.com., example.net., and example.org.
// //
// [Special-Use Domain Names](https://tools.ietf.org/html/rfc6761), RFC 6761 February, 2013 // [Special-Use Domain Names](https://tools.ietf.org/html/rfc6761), RFC 6761 February, 2013
// //
// ```text // ```text
// 6.5. Domain Name Reservation Considerations for Example Domains // 6.5. Domain Name Reservation Considerations for Example Domains
// //
// The domains "example.", "example.com.", "example.net.", // The domains "example.", "example.com.", "example.net.",
// "example.org.", and any names falling within those domains, are // "example.org.", and any names falling within those domains, are
// special in the following ways: // special in the following ways:
@@ -83,7 +83,7 @@ lazy_static! {
static ref NET: Label = Label::from_ascii("net").unwrap(); static ref NET: Label = Label::from_ascii("net").unwrap();
static ref ORG: Label = Label::from_ascii("org").unwrap(); static ref ORG: Label = Label::from_ascii("org").unwrap();
static ref EXAMPLE_L: Label = Label::from_ascii("example").unwrap(); static ref EXAMPLE_L: Label = Label::from_ascii("example").unwrap();
/// example. usage /// example. usage
pub static ref EXAMPLE: ZoneUsage = ZoneUsage::example(Name::from_labels(vec![EXAMPLE_L.clone()]).unwrap()); pub static ref EXAMPLE: ZoneUsage = ZoneUsage::example(Name::from_labels(vec![EXAMPLE_L.clone()]).unwrap());
/// example.com. usage /// example.com. usage
@@ -100,7 +100,7 @@ lazy_static! {
// //
// ```text // ```text
// 6.2. Domain Name Reservation Considerations for "test." // 6.2. Domain Name Reservation Considerations for "test."
// //
// The domain "test.", and any names falling within ".test.", are // The domain "test.", and any names falling within ".test.", are
// special in the following ways: // special in the following ways:
// ``` // ```
@@ -122,9 +122,9 @@ impl TrieKey for TrieName {
/// Returns this name in byte form, reversed for searching from zone to local label /// Returns this name in byte form, reversed for searching from zone to local label
/// ///
/// # Panics /// # Panics
/// ///
/// This will panic on bad names /// This will panic on bad names
fn encode_bytes(&self) -> Vec<u8> { fn encode_bytes(&self) -> Vec<u8> {
let mut bytes = self.0.to_bytes().expect("bad name for trie"); let mut bytes = self.0.to_bytes().expect("bad name for trie");
bytes.reverse(); bytes.reverse();
bytes bytes
@@ -144,9 +144,9 @@ impl<'n> TrieKey for TrieNameRef<'n> {
/// Returns this name in byte form, reversed for searching from zone to local label /// Returns this name in byte form, reversed for searching from zone to local label
/// ///
/// # Panics /// # Panics
/// ///
/// This will panic on bad names /// This will panic on bad names
fn encode_bytes(&self) -> Vec<u8> { fn encode_bytes(&self) -> Vec<u8> {
let mut bytes = self.0.to_bytes().expect("bad name for trie"); let mut bytes = self.0.to_bytes().expect("bad name for trie");
bytes.reverse(); bytes.reverse();
bytes bytes
@@ -162,39 +162,85 @@ impl UsageTrie {
let mut trie: Trie<TrieName, &'static ZoneUsage> = Trie::new(); let mut trie: Trie<TrieName, &'static ZoneUsage> = Trie::new();
assert!(trie.insert(DEFAULT.clone().into(), &DEFAULT).is_none()); assert!(trie.insert(DEFAULT.clone().into(), &DEFAULT).is_none());
assert!(trie.insert(IN_ADDR_ARPA_10.clone().into(), &IN_ADDR_ARPA_10).is_none()); assert!(trie
assert!(trie.insert(IN_ADDR_ARPA_172_16.clone().into(), &IN_ADDR_ARPA_172_16).is_none()); .insert(IN_ADDR_ARPA_10.clone().into(), &IN_ADDR_ARPA_10)
assert!(trie.insert(IN_ADDR_ARPA_172_17.clone().into(), &IN_ADDR_ARPA_172_17).is_none()); .is_none());
assert!(trie.insert(IN_ADDR_ARPA_172_18.clone().into(), &IN_ADDR_ARPA_172_18).is_none()); assert!(trie
assert!(trie.insert(IN_ADDR_ARPA_172_19.clone().into(), &IN_ADDR_ARPA_172_19).is_none()); .insert(IN_ADDR_ARPA_172_16.clone().into(), &IN_ADDR_ARPA_172_16)
assert!(trie.insert(IN_ADDR_ARPA_172_20.clone().into(), &IN_ADDR_ARPA_172_20).is_none()); .is_none());
assert!(trie.insert(IN_ADDR_ARPA_172_21.clone().into(), &IN_ADDR_ARPA_172_21).is_none()); assert!(trie
assert!(trie.insert(IN_ADDR_ARPA_172_22.clone().into(), &IN_ADDR_ARPA_172_22).is_none()); .insert(IN_ADDR_ARPA_172_17.clone().into(), &IN_ADDR_ARPA_172_17)
assert!(trie.insert(IN_ADDR_ARPA_172_23.clone().into(), &IN_ADDR_ARPA_172_23).is_none()); .is_none());
assert!(trie.insert(IN_ADDR_ARPA_172_24.clone().into(), &IN_ADDR_ARPA_172_24).is_none()); assert!(trie
assert!(trie.insert(IN_ADDR_ARPA_172_25.clone().into(), &IN_ADDR_ARPA_172_25).is_none()); .insert(IN_ADDR_ARPA_172_18.clone().into(), &IN_ADDR_ARPA_172_18)
assert!(trie.insert(IN_ADDR_ARPA_172_26.clone().into(), &IN_ADDR_ARPA_172_26).is_none()); .is_none());
assert!(trie.insert(IN_ADDR_ARPA_172_27.clone().into(), &IN_ADDR_ARPA_172_27).is_none()); assert!(trie
assert!(trie.insert(IN_ADDR_ARPA_172_28.clone().into(), &IN_ADDR_ARPA_172_28).is_none()); .insert(IN_ADDR_ARPA_172_19.clone().into(), &IN_ADDR_ARPA_172_19)
assert!(trie.insert(IN_ADDR_ARPA_172_29.clone().into(), &IN_ADDR_ARPA_172_29).is_none()); .is_none());
assert!(trie.insert(IN_ADDR_ARPA_172_30.clone().into(), &IN_ADDR_ARPA_172_30).is_none()); assert!(trie
assert!(trie.insert(IN_ADDR_ARPA_172_31.clone().into(), &IN_ADDR_ARPA_172_31).is_none()); .insert(IN_ADDR_ARPA_172_20.clone().into(), &IN_ADDR_ARPA_172_20)
assert!(trie.insert(IN_ADDR_ARPA_192_168.clone().into(), &IN_ADDR_ARPA_192_168).is_none()); .is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_21.clone().into(), &IN_ADDR_ARPA_172_21)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_22.clone().into(), &IN_ADDR_ARPA_172_22)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_23.clone().into(), &IN_ADDR_ARPA_172_23)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_24.clone().into(), &IN_ADDR_ARPA_172_24)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_25.clone().into(), &IN_ADDR_ARPA_172_25)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_26.clone().into(), &IN_ADDR_ARPA_172_26)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_27.clone().into(), &IN_ADDR_ARPA_172_27)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_28.clone().into(), &IN_ADDR_ARPA_172_28)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_29.clone().into(), &IN_ADDR_ARPA_172_29)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_30.clone().into(), &IN_ADDR_ARPA_172_30)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_172_31.clone().into(), &IN_ADDR_ARPA_172_31)
.is_none());
assert!(trie
.insert(IN_ADDR_ARPA_192_168.clone().into(), &IN_ADDR_ARPA_192_168)
.is_none());
assert!(trie.insert(TEST.clone().into(), &TEST).is_none()); assert!(trie.insert(TEST.clone().into(), &TEST).is_none());
assert!(trie.insert(LOCALHOST.clone().into(), &LOCALHOST).is_none()); assert!(trie.insert(LOCALHOST.clone().into(), &LOCALHOST).is_none());
assert!(trie.insert(IN_ADDR_ARPA_127.clone().into(), &IN_ADDR_ARPA_127).is_none()); assert!(trie
assert!(trie.insert(IP6_ARPA_1.clone().into(), &IP6_ARPA_1).is_none()); .insert(IN_ADDR_ARPA_127.clone().into(), &IN_ADDR_ARPA_127)
.is_none());
assert!(trie
.insert(IP6_ARPA_1.clone().into(), &IP6_ARPA_1)
.is_none());
assert!(trie.insert(INVALID.clone().into(), &INVALID).is_none()); assert!(trie.insert(INVALID.clone().into(), &INVALID).is_none());
assert!(trie.insert(EXAMPLE.clone().into(), &EXAMPLE).is_none()); assert!(trie.insert(EXAMPLE.clone().into(), &EXAMPLE).is_none());
assert!(trie.insert(EXAMPLE_COM.clone().into(), &EXAMPLE_COM).is_none()); assert!(trie
assert!(trie.insert(EXAMPLE_NET.clone().into(), &EXAMPLE_NET).is_none()); .insert(EXAMPLE_COM.clone().into(), &EXAMPLE_COM)
assert!(trie.insert(EXAMPLE_ORG.clone().into(), &EXAMPLE_ORG).is_none()); .is_none());
assert!(trie
.insert(EXAMPLE_NET.clone().into(), &EXAMPLE_NET)
.is_none());
assert!(trie
.insert(EXAMPLE_ORG.clone().into(), &EXAMPLE_ORG)
.is_none());
UsageTrie(trie) UsageTrie(trie)
} }
@@ -204,11 +250,13 @@ impl UsageTrie {
/// ///
/// Matches the closest zone encapsulating `name`, at a minimum the default root zone usage will be returned /// Matches the closest zone encapsulating `name`, at a minimum the default root zone usage will be returned
pub fn get(&self, name: &Name) -> &'static ZoneUsage { pub fn get(&self, name: &Name) -> &'static ZoneUsage {
self.0.get_ancestor_value(&TrieName::from(name.clone())).expect("DEFAULT root ZoneUsage should have been returned") self.0
.get_ancestor_value(&TrieName::from(name.clone()))
.expect("DEFAULT root ZoneUsage should have been returned")
} }
} }
lazy_static!{ lazy_static! {
/// All default usage mappings /// All default usage mappings
pub static ref USAGE: UsageTrie = UsageTrie::default(); pub static ref USAGE: UsageTrie = UsageTrie::default();
} }
@@ -229,33 +277,107 @@ mod tests {
#[test] #[test]
fn test_local_networks() { fn test_local_networks() {
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(9,0,0,1))).name(), DEFAULT.name()); assert_eq!(
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(10,0,0,1))).name(), IN_ADDR_ARPA_10.name()); USAGE.get(&Name::from(Ipv4Addr::new(9, 0, 0, 1))).name(),
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(11,0,0,1))).name(), DEFAULT.name()); DEFAULT.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(10, 0, 0, 1))).name(),
IN_ADDR_ARPA_10.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(11, 0, 0, 1))).name(),
DEFAULT.name()
);
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,16,0,0))).name(), IN_ADDR_ARPA_172_16.name()); assert_eq!(
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,17,0,0))).name(), IN_ADDR_ARPA_172_17.name()); USAGE.get(&Name::from(Ipv4Addr::new(172, 16, 0, 0))).name(),
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,18,0,0))).name(), IN_ADDR_ARPA_172_18.name()); IN_ADDR_ARPA_172_16.name()
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,19,0,0))).name(), IN_ADDR_ARPA_172_19.name()); );
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,20,0,0))).name(), IN_ADDR_ARPA_172_20.name()); assert_eq!(
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,21,0,0))).name(), IN_ADDR_ARPA_172_21.name()); USAGE.get(&Name::from(Ipv4Addr::new(172, 17, 0, 0))).name(),
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,22,0,0))).name(), IN_ADDR_ARPA_172_22.name()); IN_ADDR_ARPA_172_17.name()
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,23,0,0))).name(), IN_ADDR_ARPA_172_23.name()); );
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,24,0,0))).name(), IN_ADDR_ARPA_172_24.name()); assert_eq!(
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,25,0,0))).name(), IN_ADDR_ARPA_172_25.name()); USAGE.get(&Name::from(Ipv4Addr::new(172, 18, 0, 0))).name(),
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,26,0,0))).name(), IN_ADDR_ARPA_172_26.name()); IN_ADDR_ARPA_172_18.name()
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,27,0,0))).name(), IN_ADDR_ARPA_172_27.name()); );
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,28,0,0))).name(), IN_ADDR_ARPA_172_28.name()); assert_eq!(
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,29,0,0))).name(), IN_ADDR_ARPA_172_29.name()); USAGE.get(&Name::from(Ipv4Addr::new(172, 19, 0, 0))).name(),
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,30,0,0))).name(), IN_ADDR_ARPA_172_30.name()); IN_ADDR_ARPA_172_19.name()
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,31,0,0))).name(), IN_ADDR_ARPA_172_31.name()); );
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 20, 0, 0))).name(),
IN_ADDR_ARPA_172_20.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 21, 0, 0))).name(),
IN_ADDR_ARPA_172_21.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 22, 0, 0))).name(),
IN_ADDR_ARPA_172_22.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 23, 0, 0))).name(),
IN_ADDR_ARPA_172_23.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 24, 0, 0))).name(),
IN_ADDR_ARPA_172_24.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 25, 0, 0))).name(),
IN_ADDR_ARPA_172_25.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 26, 0, 0))).name(),
IN_ADDR_ARPA_172_26.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 27, 0, 0))).name(),
IN_ADDR_ARPA_172_27.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 28, 0, 0))).name(),
IN_ADDR_ARPA_172_28.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 29, 0, 0))).name(),
IN_ADDR_ARPA_172_29.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 30, 0, 0))).name(),
IN_ADDR_ARPA_172_30.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(172, 31, 0, 0))).name(),
IN_ADDR_ARPA_172_31.name()
);
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,15,0,0))).name(), DEFAULT.name()); assert_eq!(
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(172,32,0,0))).name(), DEFAULT.name()); USAGE.get(&Name::from(Ipv4Addr::new(172, 15, 0, 0))).name(),
DEFAULT.name()
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(192,167,255,255))).name(), DEFAULT.name()); );
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(192,168,2,3))).name(), IN_ADDR_ARPA_192_168.name()); assert_eq!(
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(192,169,0,0))).name(), DEFAULT.name()); USAGE.get(&Name::from(Ipv4Addr::new(172, 32, 0, 0))).name(),
DEFAULT.name()
);
assert_eq!(
USAGE
.get(&Name::from(Ipv4Addr::new(192, 167, 255, 255)))
.name(),
DEFAULT.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(192, 168, 2, 3))).name(),
IN_ADDR_ARPA_192_168.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(192, 169, 0, 0))).name(),
DEFAULT.name()
);
} }
#[test] #[test]
@@ -297,11 +419,25 @@ mod tests {
let usage = USAGE.get(&name); let usage = USAGE.get(&name);
assert_eq!(usage.name(), LOCALHOST.name()); assert_eq!(usage.name(), LOCALHOST.name());
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(127,0,0,1))).name(), IN_ADDR_ARPA_127.name()); assert_eq!(
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(127,0,0,2))).name(), IN_ADDR_ARPA_127.name()); USAGE.get(&Name::from(Ipv4Addr::new(127, 0, 0, 1))).name(),
assert_eq!(USAGE.get(&Name::from(Ipv4Addr::new(127,255,0,0))).name(), IN_ADDR_ARPA_127.name()); IN_ADDR_ARPA_127.name()
assert_eq!(USAGE.get(&Name::from(Ipv6Addr::new(0,0,0,0,0,0,0,1))).name(), IP6_ARPA_1.name()); );
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(127, 0, 0, 2))).name(),
IN_ADDR_ARPA_127.name()
);
assert_eq!(
USAGE.get(&Name::from(Ipv4Addr::new(127, 255, 0, 0))).name(),
IN_ADDR_ARPA_127.name()
);
assert_eq!(
USAGE
.get(&Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)))
.name(),
IP6_ARPA_1.name()
);
} }
#[test] #[test]
@@ -329,4 +465,4 @@ mod tests {
let usage = USAGE.get(&name); let usage = USAGE.get(&name);
assert_eq!(usage.name(), TEST.name()); assert_eq!(usage.name(), TEST.name());
} }
} }

View File

@@ -104,11 +104,13 @@ impl<'a> Lexer<'a> {
} }
State::Comment { is_list } => { State::Comment { is_list } => {
match ch { match ch {
Some('\r') | Some('\n') => if is_list { Some('\r') | Some('\n') => {
self.state = State::List; if is_list {
} else { self.state = State::List;
self.state = State::EOL; } else {
}, // out of the comment self.state = State::EOL;
}
} // out of the comment
Some(_) => { Some(_) => {
self.txt.next(); self.txt.next();
} // advance the token by default and maintain state } // advance the token by default and maintain state
@@ -163,7 +165,8 @@ impl<'a> Lexer<'a> {
} else { } else {
return Err(LexerErrorKind::UnrecognizedDollar( return Err(LexerErrorKind::UnrecognizedDollar(
char_data.take().unwrap_or_else(|| "".into()), char_data.take().unwrap_or_else(|| "".into()),
).into()); )
.into());
} }
} }
} }
@@ -180,7 +183,8 @@ impl<'a> Lexer<'a> {
.take() .take()
.ok_or_else(|| { .ok_or_else(|| {
LexerErrorKind::IllegalState("char_data_vec is None").into() LexerErrorKind::IllegalState("char_data_vec is None").into()
}).map(|v| Some(Token::List(v))); })
.map(|v| Some(Token::List(v)));
} }
Some(ch) if ch.is_whitespace() => { Some(ch) if ch.is_whitespace() => {
self.txt.next(); self.txt.next();
@@ -197,30 +201,33 @@ impl<'a> Lexer<'a> {
Some(')') if !is_list => { Some(')') if !is_list => {
return Err(LexerErrorKind::IllegalCharacter(ch.unwrap_or(')')).into()) return Err(LexerErrorKind::IllegalCharacter(ch.unwrap_or(')')).into())
} }
Some(ch) if ch.is_whitespace() || ch == ')' || ch == ';' => if is_list { Some(ch) if ch.is_whitespace() || ch == ')' || ch == ';' => {
char_data_vec if is_list {
.as_mut() char_data_vec
.ok_or_else(|| { .as_mut()
LexerError::from(LexerErrorKind::IllegalState( .ok_or_else(|| {
"char_data_vec is None", LexerError::from(LexerErrorKind::IllegalState(
)) "char_data_vec is None",
}).and_then(|v| { ))
let char_data = char_data.take().ok_or_else(|| { })
LexerErrorKind::IllegalState("char_data is None") .and_then(|v| {
})?; let char_data = char_data.take().ok_or_else(|| {
LexerErrorKind::IllegalState("char_data is None")
})?;
v.push(char_data); v.push(char_data);
Ok(()) Ok(())
})?; })?;
self.state = State::List; self.state = State::List;
} else { } else {
self.state = State::RestOfLine; self.state = State::RestOfLine;
let result = char_data.take().ok_or_else(|| { let result = char_data.take().ok_or_else(|| {
LexerErrorKind::IllegalState("char_data is None").into() LexerErrorKind::IllegalState("char_data is None").into()
}); });
let opt = result.map(|s| Some(Token::CharData(s))); let opt = result.map(|s| Some(Token::CharData(s)));
return opt; return opt;
}, }
}
// TODO: this next one can be removed, but will keep unescaping for quoted strings // TODO: this next one can be removed, but will keep unescaping for quoted strings
//Some('\\') => { try!(Self::push_to_str(&mut char_data, try!(self.escape_seq()))); }, //Some('\\') => { try!(Self::push_to_str(&mut char_data, try!(self.escape_seq()))); },
Some(ch) if !ch.is_control() && !ch.is_whitespace() => { Some(ch) if !ch.is_control() && !ch.is_whitespace() => {
@@ -234,7 +241,8 @@ impl<'a> Lexer<'a> {
.take() .take()
.ok_or_else(|| { .ok_or_else(|| {
LexerErrorKind::IllegalState("char_data is None").into() LexerErrorKind::IllegalState("char_data is None").into()
}).map(|s| Some(Token::CharData(s))); })
.map(|s| Some(Token::CharData(s)));
} }
} }
} }

View File

@@ -16,8 +16,8 @@
//! Text serialization types //! Text serialization types
mod master_lex;
mod master; mod master;
mod master_lex;
mod parse_rdata; mod parse_rdata;
mod rdata_parsers; mod rdata_parsers;

View File

@@ -25,9 +25,7 @@ use crate::error::*;
pub fn parse<'i, I: Iterator<Item = &'i str>>(mut tokens: I) -> ParseResult<Ipv4Addr> { pub fn parse<'i, I: Iterator<Item = &'i str>>(mut tokens: I) -> ParseResult<Ipv4Addr> {
let address: Ipv4Addr = tokens let address: Ipv4Addr = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("ipv4 address".to_string())))
ParseError::from(ParseErrorKind::MissingToken("ipv4 address".to_string()))
})
.and_then(|s| Ipv4Addr::from_str(s).map_err(Into::into))?; .and_then(|s| Ipv4Addr::from_str(s).map_err(Into::into))?;
Ok(address) Ok(address)
} }

View File

@@ -21,14 +21,11 @@ use std::str::FromStr;
use crate::error::*; use crate::error::*;
/// Parse the RData from a set of Tokens /// Parse the RData from a set of Tokens
pub fn parse<'i, I: Iterator<Item = &'i str>>(mut tokens: I) -> ParseResult<Ipv6Addr> { pub fn parse<'i, I: Iterator<Item = &'i str>>(mut tokens: I) -> ParseResult<Ipv6Addr> {
let address: Ipv6Addr = tokens let address: Ipv6Addr = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("ipv6 address".to_string())))
ParseError::from(ParseErrorKind::MissingToken("ipv6 address".to_string()))
})
.and_then(|s| Ipv6Addr::from_str(s).map_err(Into::into))?; .and_then(|s| Ipv6Addr::from_str(s).map_err(Into::into))?;
Ok(address) Ok(address)
} }

View File

@@ -44,15 +44,15 @@ use crate::rr::rdata::CAA;
/// specified in [RFC1035], Section 5.1. /// specified in [RFC1035], Section 5.1.
/// ``` /// ```
pub fn parse<'i, I: Iterator<Item = &'i str>>(mut tokens: I) -> ParseResult<CAA> { pub fn parse<'i, I: Iterator<Item = &'i str>>(mut tokens: I) -> ParseResult<CAA> {
let flags_str: &str = tokens.next().ok_or_else(|| { let flags_str: &str = tokens
ParseError::from(ParseErrorKind::Message("caa flags not present")) .next()
})?; .ok_or_else(|| ParseError::from(ParseErrorKind::Message("caa flags not present")))?;
let tag_str: &str = tokens.next().ok_or_else(|| { let tag_str: &str = tokens
ParseError::from(ParseErrorKind::Message("caa tag not present")) .next()
})?; .ok_or_else(|| ParseError::from(ParseErrorKind::Message("caa tag not present")))?;
let value_str: &str = tokens.next().ok_or_else(|| { let value_str: &str = tokens
ParseError::from(ParseErrorKind::Message("caa value not present")) .next()
})?; .ok_or_else(|| ParseError::from(ParseErrorKind::Message("caa value not present")))?;
// parse the flags // parse the flags
let issuer_critical = { let issuer_critical = {

View File

@@ -27,15 +27,11 @@ pub fn parse<'i, I: Iterator<Item = &'i str>>(
) -> ParseResult<MX> { ) -> ParseResult<MX> {
let preference: u16 = tokens let preference: u16 = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("preference".to_string())))
ParseError::from(ParseErrorKind::MissingToken("preference".to_string()))
})
.and_then(|s| s.parse().map_err(Into::into))?; .and_then(|s| s.parse().map_err(Into::into))?;
let exchange: Name = tokens let exchange: Name = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseErrorKind::MissingToken("exchange".to_string()).into())
ParseErrorKind::MissingToken("exchange".to_string()).into()
})
.and_then(|s| Name::parse(s, origin).map_err(ParseError::from))?; .and_then(|s| Name::parse(s, origin).map_err(ParseError::from))?;
Ok(MX::new(preference, exchange)) Ok(MX::new(preference, exchange))

View File

@@ -38,46 +38,30 @@ pub fn parse<'i, I: Iterator<Item = &'i str>>(
let serial: u32 = tokens let serial: u32 = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("serial".to_string())))
ParseError::from(ParseErrorKind::MissingToken("serial".to_string()))
})
.and_then(|s| u32::from_str(s).map_err(Into::into))?; .and_then(|s| u32::from_str(s).map_err(Into::into))?;
let refresh: i32 = tokens let refresh: i32 = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("refresh".to_string())))
ParseError::from(ParseErrorKind::MissingToken("refresh".to_string()))
})
.and_then(|s| i32::from_str(s).map_err(Into::into))?; .and_then(|s| i32::from_str(s).map_err(Into::into))?;
let retry: i32 = tokens let retry: i32 = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("retry".to_string())))
ParseError::from(ParseErrorKind::MissingToken("retry".to_string()))
})
.and_then(|s| i32::from_str(s).map_err(Into::into))?; .and_then(|s| i32::from_str(s).map_err(Into::into))?;
let expire: i32 = tokens let expire: i32 = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("expire".to_string())))
ParseError::from(ParseErrorKind::MissingToken("expire".to_string()))
})
.and_then(|s| i32::from_str(s).map_err(Into::into))?; .and_then(|s| i32::from_str(s).map_err(Into::into))?;
let minimum: u32 = tokens let minimum: u32 = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("minimum".to_string())))
ParseError::from(ParseErrorKind::MissingToken("minimum".to_string()))
})
.and_then(|s| u32::from_str(s).map_err(Into::into))?; .and_then(|s| u32::from_str(s).map_err(Into::into))?;
Ok(SOA::new( Ok(SOA::new(
mname, mname, rname, serial, refresh, retry, expire, minimum,
rname,
serial,
refresh,
retry,
expire,
minimum,
)) ))
} }

View File

@@ -28,30 +28,22 @@ pub fn parse<'i, I: Iterator<Item = &'i str>>(
) -> ParseResult<SRV> { ) -> ParseResult<SRV> {
let priority: u16 = tokens let priority: u16 = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("priority".to_string())))
ParseError::from(ParseErrorKind::MissingToken("priority".to_string()))
})
.and_then(|s| u16::from_str(s).map_err(Into::into))?; .and_then(|s| u16::from_str(s).map_err(Into::into))?;
let weight: u16 = tokens let weight: u16 = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("weight".to_string())))
ParseError::from(ParseErrorKind::MissingToken("weight".to_string()))
})
.and_then(|s| u16::from_str(s).map_err(Into::into))?; .and_then(|s| u16::from_str(s).map_err(Into::into))?;
let port: u16 = tokens let port: u16 = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("port".to_string())))
ParseError::from(ParseErrorKind::MissingToken("port".to_string()))
})
.and_then(|s| u16::from_str(s).map_err(Into::into))?; .and_then(|s| u16::from_str(s).map_err(Into::into))?;
let target: Name = tokens let target: Name = tokens
.next() .next()
.ok_or_else(|| { .ok_or_else(|| ParseError::from(ParseErrorKind::MissingToken("target".to_string())))
ParseError::from(ParseErrorKind::MissingToken("target".to_string()))
})
.and_then(|s| Name::parse(s, origin).map_err(ParseError::from))?; .and_then(|s| Name::parse(s, origin).map_err(ParseError::from))?;
Ok(SRV::new(priority, weight, port, target)) Ok(SRV::new(priority, weight, port, target))

View File

@@ -8,7 +8,6 @@
//! SSHFP records for SSH public key fingerprints //! SSHFP records for SSH public key fingerprints
use data_encoding::{Encoding, Specification}; use data_encoding::{Encoding, Specification};
use crate::error::*; use crate::error::*;
use crate::rr::rdata::SSHFP; use crate::rr::rdata::SSHFP;
@@ -132,8 +131,7 @@ fn test_parsing() {
ECDSA, ECDSA,
SHA1, SHA1,
&[ &[
198, 70, 7, 162, 140, 83, 0, 254, 193, 24, 11, 110, 65, 123, 146, 41, 67, 207, 252, 198, 70, 7, 162, 140, 83, 0, 254, 193, 24, 11, 110, 65, 123, 146, 41, 67, 207, 252, 221,
221,
], ],
); );
test_parsing( test_parsing(

View File

@@ -6,13 +6,12 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
//! tlsa records for storing TLS authentication records //! tlsa records for storing TLS authentication records
use data_encoding::{Encoding, Specification}; use data_encoding::{Encoding, Specification};
use crate::error::*; use crate::error::*;
use crate::rr::rdata::TLSA;
use crate::rr::rdata::tlsa::CertUsage; use crate::rr::rdata::tlsa::CertUsage;
use crate::rr::rdata::TLSA;
// TODO: dedup with sshfp // TODO: dedup with sshfp
lazy_static! { lazy_static! {
@@ -56,16 +55,18 @@ fn to_u8(data: &str) -> ParseResult<u8> {
pub fn parse<'i, I: Iterator<Item = &'i str>>(tokens: I) -> ParseResult<TLSA> { pub fn parse<'i, I: Iterator<Item = &'i str>>(tokens: I) -> ParseResult<TLSA> {
let mut iter = tokens; let mut iter = tokens;
let token: &str = iter.next().ok_or_else(|| { let token: &str = iter
ParseError::from(ParseErrorKind::Message("TLSA usage field missing")) .next()
})?; .ok_or_else(|| ParseError::from(ParseErrorKind::Message("TLSA usage field missing")))?;
let usage = CertUsage::from(to_u8(token)?); let usage = CertUsage::from(to_u8(token)?);
let token = iter.next() let token = iter
.next()
.ok_or_else(|| ParseErrorKind::Message("TLSA selector field missing"))?; .ok_or_else(|| ParseErrorKind::Message("TLSA selector field missing"))?;
let selector = to_u8(token)?.into(); let selector = to_u8(token)?.into();
let token = iter.next() let token = iter
.next()
.ok_or_else(|| ParseErrorKind::Message("TLSA matching field missing"))?; .ok_or_else(|| ParseErrorKind::Message("TLSA matching field missing"))?;
let matching = to_u8(token)?.into(); let matching = to_u8(token)?.into();
@@ -90,29 +91,29 @@ mod tests {
#[test] #[test]
fn test_parsing() { fn test_parsing() {
assert!( assert!(parse(
parse( vec![
vec![ "0",
"0", "0",
"0", "1",
"1", "d2abde240d7cd3ee6b4b28c54df034b9",
"d2abde240d7cd3ee6b4b28c54df034b9", "7983a1d16e8a410e4561cb106618e971",
"7983a1d16e8a410e4561cb106618e971", ]
].into_iter() .into_iter()
).is_ok() )
); .is_ok());
assert!( assert!(parse(
parse( vec![
vec![ "1",
"1", "1",
"1", "2",
"2", "92003ba34942dc74152e2f2c408d29ec",
"92003ba34942dc74152e2f2c408d29ec", "a5a520e7f2e06bb944f4dca346baf63c",
"a5a520e7f2e06bb944f4dca346baf63c", "1b177615d466f6c4b71c216a50292bd5",
"1b177615d466f6c4b71c216a50292bd5", "8c9ebdd2f74e38fe51ffd48c43326cbc",
"8c9ebdd2f74e38fe51ffd48c43326cbc", ]
].into_iter() .into_iter()
).is_ok() )
); .is_ok());
} }
} }

View File

@@ -19,6 +19,6 @@
mod tcp_client_connection; mod tcp_client_connection;
use proto::tcp; use proto::tcp;
pub use self::tcp_client_connection::TcpClientConnection;
pub use self::tcp::TcpClientStream; pub use self::tcp::TcpClientStream;
pub use self::tcp::TcpStream; pub use self::tcp::TcpStream;
pub use self::tcp_client_connection::TcpClientConnection;

View File

@@ -11,9 +11,9 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio_net::tcp::TcpStream;
use proto::tcp::{TcpClientConnect, TcpClientStream}; use proto::tcp::{TcpClientConnect, TcpClientStream};
use proto::xfer::{DnsMultiplexer, DnsMultiplexerConnect, DnsRequestSender}; use proto::xfer::{DnsMultiplexer, DnsMultiplexerConnect, DnsRequestSender};
use tokio_net::tcp::TcpStream;
use crate::client::ClientConnection; use crate::client::ClientConnection;
use crate::error::*; use crate::error::*;
@@ -62,7 +62,8 @@ impl TcpClientConnection {
impl ClientConnection for TcpClientConnection { impl ClientConnection for TcpClientConnection {
type Sender = DnsMultiplexer<TcpClientStream<TcpStream>, Signer>; type Sender = DnsMultiplexer<TcpClientStream<TcpStream>, Signer>;
type Response = <Self::Sender as DnsRequestSender>::DnsResponseFuture; type Response = <Self::Sender as DnsRequestSender>::DnsResponseFuture;
type SenderFuture = DnsMultiplexerConnect<TcpClientConnect<TcpStream>, TcpClientStream<TcpStream>, Signer>; type SenderFuture =
DnsMultiplexerConnect<TcpClientConnect<TcpStream>, TcpClientStream<TcpStream>, Signer>;
fn new_stream(&self, signer: Option<Arc<Signer>>) -> Self::SenderFuture { fn new_stream(&self, signer: Option<Arc<Signer>>) -> Self::SenderFuture {
let (tcp_client_stream, handle) = let (tcp_client_stream, handle) =

View File

@@ -19,6 +19,6 @@
mod udp_client_connection; mod udp_client_connection;
use proto::udp; use proto::udp;
pub use self::udp_client_connection::UdpClientConnection;
pub use self::udp::UdpClientStream; pub use self::udp::UdpClientStream;
pub use self::udp::UdpStream; pub use self::udp::UdpStream;
pub use self::udp_client_connection::UdpClientConnection;

View File

@@ -5,24 +5,24 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::ops::DerefMut;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::io;
use std::mem; use std::mem;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::ops::DerefMut;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::Context; use std::task::Context;
use std::io;
use bytes::Bytes; use bytes::Bytes;
use futures::{future, Future, FutureExt, Poll, Stream, TryFutureExt}; use futures::{future, Future, FutureExt, Poll, Stream, TryFutureExt};
use h2::client::{Connection, SendRequest};
use h2; use h2;
use h2::client::{Connection, SendRequest};
use http::{self, header}; use http::{self, header};
use rustls::{Certificate, ClientConfig}; use rustls::{Certificate, ClientConfig};
use tokio_executor; use tokio_executor;
use tokio_net::tcp::TcpStream as TokioTcpStream;
use tokio_rustls::{client::TlsStream as TokioTlsClientStream, Connect, TlsConnector}; use tokio_rustls::{client::TlsStream as TokioTlsClientStream, Connect, TlsConnector};
use tokio_net::tcp::{TcpStream as TokioTcpStream};
use typed_headers::{ContentLength, HeaderMapExt}; use typed_headers::{ContentLength, HeaderMapExt};
use webpki::DNSNameRef; use webpki::DNSNameRef;
@@ -53,11 +53,12 @@ impl Display for HttpsClientStream {
} }
impl HttpsClientStream { impl HttpsClientStream {
async fn inner_send(h2: SendRequest<Bytes>, async fn inner_send(
h2: SendRequest<Bytes>,
message: SerialMessage, message: SerialMessage,
name_server_name: Arc<String>, name_server_name: Arc<String>,
name_server: SocketAddr) -> Result<DnsResponse, ProtoError> { name_server: SocketAddr,
) -> Result<DnsResponse, ProtoError> {
let mut h2 = match h2.ready().await { let mut h2 = match h2.ready().await {
Ok(h2) => h2, Ok(h2) => h2,
Err(err) => { Err(err) => {
@@ -71,23 +72,23 @@ impl HttpsClientStream {
let bytes = Bytes::from(message.bytes()); let bytes = Bytes::from(message.bytes());
let request = crate::request::new(&name_server_name, bytes.len()); let request = crate::request::new(&name_server_name, bytes.len());
let request = request.map_err(|err| ProtoError::from(format!("bad http request: {}", err)))?; let request =
request.map_err(|err| ProtoError::from(format!("bad http request: {}", err)))?;
debug!("request: {:#?}", request); debug!("request: {:#?}", request);
// Send the request // Send the request
let (response_future, mut send_stream) = let (response_future, mut send_stream) = h2
h2.send_request(request, false).map_err(|err| { .send_request(request, false)
ProtoError::from(format!("h2 send_request error: {}", err)) .map_err(|err| ProtoError::from(format!("h2 send_request error: {}", err)))?;
})?;
send_stream send_stream
.send_data(bytes, true) .send_data(bytes, true)
.map_err(|e| ProtoError::from(format!("h2 send_data error: {}", e)))?; .map_err(|e| ProtoError::from(format!("h2 send_data error: {}", e)))?;
let mut response_stream = response_future.await.map_err(|err| ProtoError::from( let mut response_stream = response_future
format!("received a stream error: {}", err) .await
))?; .map_err(|err| ProtoError::from(format!("received a stream error: {}", err)))?;
debug!("got response: {:#?}", response_stream); debug!("got response: {:#?}", response_stream);
@@ -101,10 +102,12 @@ impl HttpsClientStream {
// TODO: what is a good max here? // TODO: what is a good max here?
// max(512) says make sure it is at least 512 bytes, and min 4096 says it is at most 4k // max(512) says make sure it is at least 512 bytes, and min 4096 says it is at most 4k
// just a little protection from malicious actors. // just a little protection from malicious actors.
let mut response_bytes = Bytes::with_capacity(content_length.unwrap_or(512).max(512).min(4096)); let mut response_bytes =
Bytes::with_capacity(content_length.unwrap_or(512).max(512).min(4096));
while let Some(partial_bytes) = response_stream.body_mut().data().await { while let Some(partial_bytes) = response_stream.body_mut().data().await {
let partial_bytes = partial_bytes.map_err(|e| ProtoError::from(format!("bad http request: {}", e)))?; let partial_bytes =
partial_bytes.map_err(|e| ProtoError::from(format!("bad http request: {}", e)))?;
debug!("got bytes: {}", partial_bytes.len()); debug!("got bytes: {}", partial_bytes.len());
response_bytes.extend(partial_bytes); response_bytes.extend(partial_bytes);
@@ -136,7 +139,8 @@ impl HttpsClientStream {
// TODO: make explicit error type // TODO: make explicit error type
return Err(ProtoError::from(format!( return Err(ProtoError::from(format!(
"http unsuccessful code: {}, message: {}", "http unsuccessful code: {}, message: {}",
response_stream.status(), error_string response_stream.status(),
error_string
))); )));
} else { } else {
// verify content type // verify content type
@@ -148,12 +152,10 @@ impl HttpsClientStream {
.map(|h| { .map(|h| {
h.to_str().map_err(|err| { h.to_str().map_err(|err| {
// TODO: make explicit error type // TODO: make explicit error type
ProtoError::from(format!( ProtoError::from(format!("ContentType header not a string: {}", err))
"ContentType header not a string: {}",
err
))
}) })
}).unwrap_or(Ok(crate::MIME_APPLICATION_DNS))?; })
.unwrap_or(Ok(crate::MIME_APPLICATION_DNS))?;
if content_type != crate::MIME_APPLICATION_DNS { if content_type != crate::MIME_APPLICATION_DNS {
return Err(ProtoError::from(format!( return Err(ProtoError::from(format!(
@@ -174,7 +176,7 @@ impl HttpsClientStream {
impl DnsRequestSender for HttpsClientStream { impl DnsRequestSender for HttpsClientStream {
type DnsResponseFuture = HttpsClientResponse; type DnsResponseFuture = HttpsClientResponse;
/// This indicates that the HTTP message was successfully sent, and we now have the response.RecvStream /// This indicates that the HTTP message was successfully sent, and we now have the response.RecvStream
/// ///
/// If the request fails, this will return the error, and it should be assumed that the Stream portion of /// If the request fails, this will return the error, and it should be assumed that the Stream portion of
/// this will have no date. /// this will have no date.
@@ -231,13 +233,16 @@ impl DnsRequestSender for HttpsClientStream {
let bytes = match message.to_vec() { let bytes = match message.to_vec() {
Ok(bytes) => bytes, Ok(bytes) => bytes,
Err(err) => { Err(err) => return HttpsClientResponse(Box::pin(future::err(err.into()))),
return HttpsClientResponse(Box::pin(future::err(err.into())))
}
}; };
let message = SerialMessage::new(bytes, self.name_server); let message = SerialMessage::new(bytes, self.name_server);
HttpsClientResponse(Box::pin(Self::inner_send(self.h2.clone(), message, Arc::clone(&self.name_server_name), self.name_server))) HttpsClientResponse(Box::pin(Self::inner_send(
self.h2.clone(),
message,
Arc::clone(&self.name_server_name),
self.name_server,
)))
// HttpsSerialResponse(HttpsSerialResponseInner::StartSend { // HttpsSerialResponse(HttpsSerialResponseInner::StartSend {
// h2: self.h2.clone(), // h2: self.h2.clone(),
@@ -272,7 +277,10 @@ impl Stream for HttpsClientStream {
match self.h2.poll_ready(cx) { match self.h2.poll_ready(cx) {
Poll::Ready(Ok(r)) => Poll::Ready(Some(Ok(r))), Poll::Ready(Ok(r)) => Poll::Ready(Some(Ok(r))),
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!("h2 stream errored: {}", e))))), Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
"h2 stream errored: {}",
e
))))),
} }
} }
} }
@@ -369,7 +377,19 @@ enum HttpsClientConnectState {
name_server: SocketAddr, name_server: SocketAddr,
}, },
H2Handshake { H2Handshake {
handshake: Pin<Box<dyn Future<Output = Result<(SendRequest<Bytes>, Connection<TokioTlsClientStream<TokioTcpStream>, Bytes>), h2::Error>> + Send>>, handshake: Pin<
Box<
dyn Future<
Output = Result<
(
SendRequest<Bytes>,
Connection<TokioTlsClientStream<TokioTcpStream>, Bytes>,
),
h2::Error,
>,
> + Send,
>,
>,
name_server_name: Arc<String>, name_server_name: Arc<String>,
name_server: SocketAddr, name_server: SocketAddr,
}, },
@@ -383,7 +403,10 @@ impl Future for HttpsClientConnectState {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
loop { loop {
let next = match *self { let next = match *self {
HttpsClientConnectState::ConnectTcp { name_server, ref mut tls } => { HttpsClientConnectState::ConnectTcp {
name_server,
ref mut tls,
} => {
debug!("tcp connecting to: {}", name_server); debug!("tcp connecting to: {}", name_server);
let connect = Box::pin(TokioTcpStream::connect(name_server)); let connect = Box::pin(TokioTcpStream::connect(name_server));
HttpsClientConnectState::TcpConnecting { HttpsClientConnectState::TcpConnecting {
@@ -443,16 +466,16 @@ impl Future for HttpsClientConnectState {
name_server, name_server,
ref mut handshake, ref mut handshake,
} => { } => {
let (send_request, connection) = ready!( let (send_request, connection) = ready!(handshake
handshake .poll_unpin(cx)
.poll_unpin(cx) .map_err(|e| ProtoError::from(format!("h2 handshake error: {}", e))))?;
.map_err(|e| ProtoError::from(format!("h2 handshake error: {}", e)))
)?;
// TODO: hand this back for others to run rather than spawning here? // TODO: hand this back for others to run rather than spawning here?
debug!("h2 connection established to: {}", name_server); debug!("h2 connection established to: {}", name_server);
tokio_executor::spawn( tokio_executor::spawn(
connection.map_err(|e| warn!("h2 connection failed: {}", e)).map(|_: Result<(),()>| ()), connection
.map_err(|e| warn!("h2 connection failed: {}", e))
.map(|_: Result<(), ()>| ()),
); );
HttpsClientConnectState::Connected(Some(HttpsClientStream { HttpsClientConnectState::Connected(Some(HttpsClientStream {
@@ -476,27 +499,29 @@ impl Future for HttpsClientConnectState {
} }
/// A future that resolves to /// A future that resolves to
pub struct HttpsClientResponse(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>); pub struct HttpsClientResponse(
Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
);
// impl HttpsClientResponse { // impl HttpsClientResponse {
// /// creates a new future for the request // /// creates a new future for the request
// /// // ///
// /// # Arguments // /// # Arguments
// /// // ///
// /// * `request` - Serialized message being sent // /// * `request` - Serialized message being sent
// /// * `message_id` - Id of the message that was encoded in the serial message // /// * `message_id` - Id of the message that was encoded in the serial message
// fn new(request: SerialMessage, message_id: u16, timeout: Duration) -> Self { // fn new(request: SerialMessage, message_id: u16, timeout: Duration) -> Self {
// UdpResponse(Box::pin(Timeout::new( // UdpResponse(Box::pin(Timeout::new(
// SingleUseUdpSocket::send_serial_message::<S>(request, message_id), // SingleUseUdpSocket::send_serial_message::<S>(request, message_id),
// timeout, // timeout,
// ))) // )))
// } // }
// /// ad already completed future // /// ad already completed future
// fn complete<F: Future<Output = Result<DnsResponse, ProtoError>> + Send + 'static>(f: F) -> Self { // fn complete<F: Future<Output = Result<DnsResponse, ProtoError>> + Send + 'static>(f: F) -> Self {
// // TODO: this constructure isn't really necessary // // TODO: this constructure isn't really necessary
// UdpResponse(Box::pin(Timeout::new(f, Duration::from_secs(5)))) // UdpResponse(Box::pin(Timeout::new(f, Duration::from_secs(5))))
// } // }
// } // }
impl Future for HttpsClientResponse { impl Future for HttpsClientResponse {

View File

@@ -9,8 +9,8 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::Arc;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::Context; use std::task::Context;
use bytes::Bytes; use bytes::Bytes;
@@ -25,7 +25,10 @@ use crate::HttpsError;
/// ///
/// To allow downstream clients to do something interesting with the lifetime of the bytes, this doesn't /// To allow downstream clients to do something interesting with the lifetime of the bytes, this doesn't
/// perform a conversion to a Message, only collects all the bytes. /// perform a conversion to a Message, only collects all the bytes.
pub async fn message_from<R>(this_server_name: Arc<String>, request: Request<R>) -> Result<Bytes, HttpsError> pub async fn message_from<R>(
this_server_name: Arc<String>,
request: Request<R>,
) -> Result<Bytes, HttpsError>
where where
R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin, R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
{ {
@@ -57,11 +60,13 @@ where
} }
/// Deserialize the message from a POST message /// Deserialize the message from a POST message
pub(crate) async fn message_from_post<R>(mut request_stream: R, length: Option<usize>) -> Result<Bytes, HttpsError> pub(crate) async fn message_from_post<R>(
mut request_stream: R,
length: Option<usize>,
) -> Result<Bytes, HttpsError>
where where
R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin, R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
{ {
let mut bytes = Bytes::with_capacity(length.unwrap_or(0).min(512).max(4096)); let mut bytes = Bytes::with_capacity(length.unwrap_or(0).min(512).max(4096));
loop { loop {
@@ -166,7 +171,6 @@ mod tests {
// FIXME: generic stream impl is issue... // FIXME: generic stream impl is issue...
let from_post = message_from(Arc::new("ns.example.com".to_string()), request); let from_post = message_from(Arc::new("ns.example.com".to_string()), request);
let bytes = match block_on(from_post) { let bytes = match block_on(from_post) {
Ok(bytes) => bytes, Ok(bytes) => bytes,

View File

@@ -19,8 +19,8 @@ extern crate log;
extern crate failure; extern crate failure;
extern crate rustls; extern crate rustls;
extern crate tokio_executor; extern crate tokio_executor;
extern crate tokio_rustls;
extern crate tokio_net; extern crate tokio_net;
extern crate tokio_rustls;
extern crate trust_dns_proto; extern crate trust_dns_proto;
extern crate trust_dns_rustls; extern crate trust_dns_rustls;
extern crate typed_headers; extern crate typed_headers;
@@ -45,6 +45,6 @@ pub use self::error::{Error as HttpsError, Result as HttpsResult};
//pub use self::https_client_connection::{HttpsClientConnection, HttpsClientConnectionBuilder}; //pub use self::https_client_connection::{HttpsClientConnection, HttpsClientConnectionBuilder};
pub use self::https_client_stream::{ pub use self::https_client_stream::{
HttpsClientConnect, HttpsClientStream, HttpsClientStreamBuilder, HttpsClientResponse, HttpsClientConnect, HttpsClientResponse, HttpsClientStream, HttpsClientStreamBuilder,
}; };
//pub use self::https_stream::{HttpsStream, HttpsStreamBuilder}; //pub use self::https_stream::{HttpsStream, HttpsStreamBuilder};

View File

@@ -83,7 +83,12 @@ pub fn verify<T>(name_server: &str, request: &Request<T>) -> HttpsResult<()> {
// validate path // validate path
if uri.path() != crate::DNS_QUERY_PATH { if uri.path() != crate::DNS_QUERY_PATH {
return Err(format!("bad path: {}, expected: {}", uri.path(), crate::DNS_QUERY_PATH).into()); return Err(format!(
"bad path: {}, expected: {}",
uri.path(),
crate::DNS_QUERY_PATH
)
.into());
} }
// we only accept HTTPS // we only accept HTTPS
@@ -114,10 +119,9 @@ pub fn verify<T>(name_server: &str, request: &Request<T>) -> HttpsResult<()> {
let accept = accept.ok_or_else(|| "Accept is unspecified")?; let accept = accept.ok_or_else(|| "Accept is unspecified")?;
// TODO: switch to mime::APPLICATION_DNS when that stabilizes // TODO: switch to mime::APPLICATION_DNS when that stabilizes
if !accept if !accept.iter().any(|q| {
.iter() (q.item.type_() == crate::MIME_APPLICATION && q.item.subtype() == crate::MIME_DNS_BINARY)
.any(|q| (q.item.type_() == crate::MIME_APPLICATION && q.item.subtype() == crate::MIME_DNS_BINARY)) }) {
{
return Err("does not accept content type".into()); return Err("does not accept content type".into());
} }

View File

@@ -79,7 +79,8 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
} }
panic!("timeout"); panic!("timeout");
}).unwrap(); })
.unwrap();
let server_path = env::var("TDNS_SERVER_SRC_ROOT").unwrap_or_else(|_| "../server".to_owned()); let server_path = env::var("TDNS_SERVER_SRC_ROOT").unwrap_or_else(|_| "../server".to_owned());
println!("using server src path: {}", server_path); println!("using server src path: {}", server_path);
@@ -166,7 +167,8 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
// println!("wrote bytes iter: {}", i); // println!("wrote bytes iter: {}", i);
std::thread::yield_now(); std::thread::yield_now();
} }
}).unwrap(); })
.unwrap();
// let the server go first // let the server go first
std::thread::yield_now(); std::thread::yield_now();
@@ -199,11 +201,13 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
sender sender
.unbounded_send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr)) .unbounded_send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr))
.expect("send failed"); .expect("send failed");
let (buffer, stream_tmp) = io_loop let (buffer, stream_tmp) = io_loop.block_on(stream.into_future());
.block_on(stream.into_future());
stream = stream_tmp; stream = stream_tmp;
let message = buffer.expect("no buffer received"); let message = buffer.expect("no buffer received");
assert_eq!(message.expect("message destructure failed").bytes(), TEST_BYTES); assert_eq!(
message.expect("message destructure failed").bytes(),
TEST_BYTES
);
} }
succeeded.store(true, std::sync::atomic::Ordering::Relaxed); succeeded.store(true, std::sync::atomic::Ordering::Relaxed);

View File

@@ -139,12 +139,13 @@ impl TlsStreamBuilder {
let ca_chain = self.ca_chain.clone(); let ca_chain = self.ca_chain.clone();
let identity = self.identity; let identity = self.identity;
let tcp_stream: Result<TokioTcpStream, _> = TokioTcpStream::connect(&name_server)/*.map_err(|e| { let tcp_stream: Result<TokioTcpStream, _> = TokioTcpStream::connect(&name_server) /*.map_err(|e| {
io::Error::new( io::Error::new(
io::ErrorKind::ConnectionRefused, io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e), format!("tls error: {}", e),
) )
})*/.await; })*/
.await;
// TODO: for some reason the above wouldn't accept a ? // TODO: for some reason the above wouldn't accept a ?
let tcp_stream = match tcp_stream { let tcp_stream = match tcp_stream {
@@ -154,7 +155,8 @@ impl TlsStreamBuilder {
// This set of futures collapses the next tcp socket into a stream which can be used for // This set of futures collapses the next tcp socket into a stream which can be used for
// sending and receiving tcp packets. // sending and receiving tcp packets.
let tls_connector = tls_stream::tls_new(ca_chain, identity).map(TokioTlsConnector::from) let tls_connector = tls_stream::tls_new(ca_chain, identity)
.map(TokioTlsConnector::from)
.map_err(|e| { .map_err(|e| {
io::Error::new( io::Error::new(
io::ErrorKind::ConnectionRefused, io::ErrorKind::ConnectionRefused,
@@ -162,13 +164,20 @@ impl TlsStreamBuilder {
) )
})?; })?;
let tls_connected = tls_connector.connect(&dns_name, tcp_stream).map_err(|e| { let tls_connected = tls_connector
io::Error::new( .connect(&dns_name, tcp_stream)
io::ErrorKind::ConnectionRefused, .map_err(|e| {
format!("tls error: {}", e), io::Error::new(
) io::ErrorKind::ConnectionRefused,
}).await?; format!("tls error: {}", e),
)
})
.await?;
Ok(TcpStream::from_stream_with_receiver(tls_connected, name_server, outbound_messages)) Ok(TcpStream::from_stream_with_receiver(
tls_connected,
name_server,
outbound_messages,
))
} }
} }

View File

@@ -18,8 +18,8 @@
extern crate futures; extern crate futures;
extern crate openssl; extern crate openssl;
extern crate tokio_openssl;
extern crate tokio_net; extern crate tokio_net;
extern crate tokio_openssl;
extern crate trust_dns_proto; extern crate trust_dns_proto;
mod tls_client_stream; mod tls_client_stream;

View File

@@ -14,8 +14,8 @@ use futures::{Future, TryFutureExt};
#[cfg(feature = "mtls")] #[cfg(feature = "mtls")]
use openssl::pkcs12::Pkcs12; use openssl::pkcs12::Pkcs12;
use openssl::x509::X509; use openssl::x509::X509;
use tokio_openssl::SslStream as TokioTlsStream;
use tokio_net::tcp::TcpStream as TokioTcpStream; use tokio_net::tcp::TcpStream as TokioTcpStream;
use tokio_openssl::SslStream as TokioTlsStream;
use trust_dns_proto::error::ProtoError; use trust_dns_proto::error::ProtoError;
use trust_dns_proto::tcp::TcpClientStream; use trust_dns_proto::tcp::TcpClientStream;

View File

@@ -246,9 +246,11 @@ impl TlsStreamBuilder {
// This set of futures collapses the next tcp socket into a stream which can be used for // This set of futures collapses the next tcp socket into a stream which can be used for
// sending and receiving tcp packets. // sending and receiving tcp packets.
let stream = Box::pin(connect_tls(tls_config, dns_name, name_server).map_ok(move |s| { let stream = Box::pin(
TcpStream::from_stream_with_receiver(s, name_server, outbound_messages) connect_tls(tls_config, dns_name, name_server).map_ok(move |s| {
})); TcpStream::from_stream_with_receiver(s, name_server, outbound_messages)
}),
);
(stream, message_sender) (stream, message_sender)
} }

View File

@@ -82,7 +82,8 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
} }
panic!("timeout"); panic!("timeout");
}).unwrap(); })
.unwrap();
let (root_pkey, root_name, root_cert) = root_ca(); let (root_pkey, root_name, root_cert) = root_ca();
let root_cert_der = root_cert.to_der().unwrap(); let root_cert_der = root_cert.to_der().unwrap();
@@ -187,7 +188,8 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
// println!("wrote bytes iter: {}", i); // println!("wrote bytes iter: {}", i);
std::thread::yield_now(); std::thread::yield_now();
} }
}).unwrap(); })
.unwrap();
// let the server go first // let the server go first
std::thread::yield_now(); std::thread::yield_now();
@@ -219,10 +221,11 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
sender sender
.unbounded_send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr)) .unbounded_send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr))
.expect("send failed"); .expect("send failed");
let (buffer, stream_tmp) = io_loop let (buffer, stream_tmp) = io_loop.block_on(stream.into_future());
.block_on(stream.into_future());
stream = stream_tmp; stream = stream_tmp;
let message = buffer.expect("no buffer received").expect("error receiving bytes"); let message = buffer
.expect("no buffer received")
.expect("error receiving bytes");
assert_eq!(message.bytes(), TEST_BYTES); assert_eq!(message.bytes(), TEST_BYTES);
} }

View File

@@ -1,15 +1,14 @@
#![feature(test)] #![feature(test)]
extern crate trust_dns_proto;
extern crate test; extern crate test;
extern crate trust_dns_proto;
use trust_dns_proto::op::{MessageType, Header, ResponseCode, OpCode, Message}; use trust_dns_proto::op::{Header, Message, MessageType, OpCode, ResponseCode};
use trust_dns_proto::rr::Record; use trust_dns_proto::rr::Record;
use trust_dns_proto::serialize::binary::{BinDecoder, BinEncodable, BinEncoder, BinDecodable}; use trust_dns_proto::serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder};
use test::Bencher; use test::Bencher;
#[bench] #[bench]
fn bench_emit_header(b: &mut Bencher) { fn bench_emit_header(b: &mut Bencher) {
let header = Header::new(); let header = Header::new();
@@ -36,13 +35,10 @@ fn bench_parse_header_no_reservation(b: &mut Bencher) {
}) })
} }
#[bench] #[bench]
fn bench_parse_header(b: &mut Bencher) { fn bench_parse_header(b: &mut Bencher) {
let byte_vec = vec![ let byte_vec = vec![
0x01, 0x10, 0xAA, 0x83, 0x01, 0x10, 0xAA, 0x83, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11,
0x88, 0x77, 0x66, 0x55,
0x44, 0x33, 0x22, 0x11
]; ];
b.iter(|| { b.iter(|| {
let mut decoder = BinDecoder::new(&byte_vec); let mut decoder = BinDecoder::new(&byte_vec);

View File

@@ -24,8 +24,8 @@ use ring::error::Unspecified;
use failure::{Backtrace, Context, Fail}; use failure::{Backtrace, Context, Fail};
use tokio_executor::SpawnError; use tokio_executor::SpawnError;
use tokio_timer::Error as TimerError;
use tokio_timer::timeout::Elapsed; use tokio_timer::timeout::Elapsed;
use tokio_timer::Error as TimerError;
/// An alias for results returned by functions of this crate /// An alias for results returned by functions of this crate
pub type ProtoResult<T> = ::std::result::Result<T, ProtoError>; pub type ProtoResult<T> = ::std::result::Result<T, ProtoError>;

View File

@@ -38,21 +38,21 @@ extern crate socket2;
extern crate tokio; extern crate tokio;
extern crate tokio_executor; extern crate tokio_executor;
extern crate tokio_io; extern crate tokio_io;
extern crate tokio_sync;
#[cfg(feature = "tokio-compat")] #[cfg(feature = "tokio-compat")]
extern crate tokio_net; extern crate tokio_net;
extern crate tokio_sync;
extern crate tokio_timer; extern crate tokio_timer;
extern crate url; extern crate url;
macro_rules! try_ready_stream { macro_rules! try_ready_stream {
($e:expr) => ({ ($e:expr) => {{
match $e { match $e {
Poll::Ready(Some(Ok(t))) => t, Poll::Ready(Some(Ok(t))) => t,
Poll::Ready(None) => return Poll::Ready(None), Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(From::from(e)))), Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(From::from(e)))),
} }
}) }};
} }
pub mod error; pub mod error;

View File

@@ -7,16 +7,16 @@
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::net::{Ipv4Addr, SocketAddr}; use std::net::{Ipv4Addr, SocketAddr};
use std::task::Context;
use std::pin::Pin; use std::pin::Pin;
use std::task::Context;
use futures::{Future, FutureExt, Poll, Stream, TryFutureExt};
use futures::stream::{StreamExt, TryStreamExt}; use futures::stream::{StreamExt, TryStreamExt};
use futures::{Future, FutureExt, Poll, Stream, TryFutureExt};
use crate::error::ProtoError; use crate::error::ProtoError;
use crate::xfer::{DnsClientStream, SerialMessage};
use crate::multicast::mdns_stream::{MDNS_IPV4, MDNS_IPV6}; use crate::multicast::mdns_stream::{MDNS_IPV4, MDNS_IPV6};
use crate::multicast::{MdnsQueryType, MdnsStream}; use crate::multicast::{MdnsQueryType, MdnsStream};
use crate::xfer::{DnsClientStream, SerialMessage};
use crate::{BufDnsStreamHandle, DnsStreamHandle}; use crate::{BufDnsStreamHandle, DnsStreamHandle};
/// A UDP client stream of DNS binary packets /// A UDP client stream of DNS binary packets
@@ -64,8 +64,8 @@ impl MdnsClientStream {
MdnsStream::new(mdns_addr, mdns_query_type, packet_ttl, ipv4_if, ipv6_if); MdnsStream::new(mdns_addr, mdns_query_type, packet_ttl, ipv4_if, ipv6_if);
let stream_future = stream_future let stream_future = stream_future
.map_ok(move |mdns_stream| MdnsClientStream { mdns_stream }) .map_ok(move |mdns_stream| MdnsClientStream { mdns_stream })
.map_err(ProtoError::from); .map_err(ProtoError::from);
let new_future = Box::new(stream_future); let new_future = Box::new(stream_future);
let new_future = MdnsClientConnect(new_future); let new_future = MdnsClientConnect(new_future);
@@ -106,7 +106,9 @@ impl Stream for MdnsClientStream {
} }
/// A future that resolves to an MdnsClientStream /// A future that resolves to an MdnsClientStream
pub struct MdnsClientConnect(Box<dyn Future<Output = Result<MdnsClientStream, ProtoError>> + Send + Unpin>); pub struct MdnsClientConnect(
Box<dyn Future<Output = Result<MdnsClientStream, ProtoError>> + Send + Unpin>,
);
impl Future for MdnsClientConnect { impl Future for MdnsClientConnect {
type Output = Result<MdnsClientStream, ProtoError>; type Output = Result<MdnsClientStream, ProtoError>;

View File

@@ -12,12 +12,12 @@ use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::Context; use std::task::Context;
use futures::ready;
use futures::future;
use futures::stream::{Stream, StreamExt};
use futures::channel::mpsc::unbounded; use futures::channel::mpsc::unbounded;
use futures::{Future, FutureExt, Poll, TryFutureExt}; use futures::future;
use futures::lock::Mutex; use futures::lock::Mutex;
use futures::ready;
use futures::stream::{Stream, StreamExt};
use futures::{Future, FutureExt, Poll, TryFutureExt};
use rand; use rand;
use rand::distributions::{uniform::Uniform, Distribution}; use rand::distributions::{uniform::Uniform, Distribution};
use socket2::{self, Socket}; use socket2::{self, Socket};
@@ -144,21 +144,20 @@ impl MdnsStream {
Box::new( Box::new(
next_socket next_socket
.map(move |socket| { .map(move |socket| match socket {
match socket { Ok(Some(socket)) => Ok(Some(UdpSocket::from_std(socket, &handle)?)),
Ok(Some(socket)) => Ok(Some(UdpSocket::from_std(socket, &handle)?)), Ok(None) => Ok(None),
Ok(None) => Ok(None), Err(err) => Err(err),
Err(err) => Err(err),
}
}) })
.map_ok(move |socket: Option<_>| { .map_ok(move |socket: Option<_>| {
let datagram: Option<_> = let datagram: Option<_> =
socket.map(|socket| UdpStream::from_parts(socket, outbound_messages)); socket.map(|socket| UdpStream::from_parts(socket, outbound_messages));
let multicast: Option<_> = let multicast: Option<_> = multicast_socket.map(|multicast_socket| {
multicast_socket.map(|multicast_socket| { Arc::new(Mutex::new(
Arc::new(Mutex::new(UdpSocket::from_std(multicast_socket, &handle_clone) UdpSocket::from_std(multicast_socket, &handle_clone)
.expect("bad handle?"))) .expect("bad handle?"),
}); ))
});
MdnsStream { MdnsStream {
multicast_addr, multicast_addr,
@@ -166,7 +165,7 @@ impl MdnsStream {
multicast, multicast,
rcving_mcast: None, rcving_mcast: None,
} }
}) }),
) )
}; };
@@ -284,7 +283,6 @@ impl Stream for MdnsStream {
} }
} }
loop { loop {
let msg = if let Some(ref mut receiving) = self.rcving_mcast { let msg = if let Some(ref mut receiving) = self.rcving_mcast {
// TODO: should we drop this packet if it's not from the same src as dest? // TODO: should we drop this packet if it's not from the same src as dest?
@@ -294,7 +292,7 @@ impl Stream for MdnsStream {
} else { } else {
None None
}; };
self.rcving_mcast = None; self.rcving_mcast = None;
if let Some(msg) = msg { if let Some(msg) = msg {
@@ -309,7 +307,7 @@ impl Stream for MdnsStream {
let mut buf = [0u8; 2048]; let mut buf = [0u8; 2048];
let mut socket = socket.lock().await; let mut socket = socket.lock().await;
let (len, src) = socket.recv_from(&mut buf).await?; let (len, src) = socket.recv_from(&mut buf).await?;
Ok(SerialMessage::new( Ok(SerialMessage::new(
buf.iter().take(len).cloned().collect(), buf.iter().take(len).cloned().collect(),
src, src,
@@ -492,10 +490,11 @@ pub mod tests {
.name("test_one_shot_mdns:server".to_string()) .name("test_one_shot_mdns:server".to_string())
.spawn(move || { .spawn(move || {
let mut server_loop = Runtime::new().unwrap(); let mut server_loop = Runtime::new().unwrap();
let mut timeout = let mut timeout = future::lazy(|_| {
future::lazy(|_| { tokio_timer::delay(Instant::now() + Duration::from_millis(100))
tokio_timer::delay(Instant::now() + Duration::from_millis(100)) })
}).flatten().boxed(); .flatten()
.boxed();
// TTLs are 0 so that multicast test packets never leave the test host... // TTLs are 0 so that multicast test packets never leave the test host...
// FIXME: this is hardcoded to index 5 for ipv6, which isn't going to be correct in most cases... // FIXME: this is hardcoded to index 5 for ipv6, which isn't going to be correct in most cases...
@@ -518,16 +517,21 @@ pub mod tests {
return; return;
} }
// wait for some bytes... // wait for some bytes...
match server_loop match server_loop.block_on(
.block_on(future::lazy(|_| future::select(server_stream, timeout)).flatten()) future::lazy(|_| future::select(server_stream, timeout)).flatten(),
{ ) {
Either::Left((buffer_and_addr_stream_tmp, timeout_tmp)) => { Either::Left((buffer_and_addr_stream_tmp, timeout_tmp)) => {
let (buffer_and_addr, stream_tmp): (Option<Result<SerialMessage, io::Error>>, MdnsStream) = buffer_and_addr_stream_tmp; let (buffer_and_addr, stream_tmp): (
Option<Result<SerialMessage, io::Error>>,
MdnsStream,
) = buffer_and_addr_stream_tmp;
server_stream = stream_tmp.into_future(); server_stream = stream_tmp.into_future();
timeout = timeout_tmp; timeout = timeout_tmp;
let (buffer, addr) = let (buffer, addr) = buffer_and_addr
buffer_and_addr.expect("no msg received").expect("error receiving msg").unwrap(); .expect("no msg received")
.expect("error receiving msg")
.unwrap();
assert_eq!(&buffer, test_bytes); assert_eq!(&buffer, test_bytes);
//println!("server got data! {}", addr); //println!("server got data! {}", addr);
@@ -541,13 +545,16 @@ pub mod tests {
server_stream = buffer_and_addr_stream_tmp; server_stream = buffer_and_addr_stream_tmp;
timeout = future::lazy(|_| { timeout = future::lazy(|_| {
tokio_timer::delay(Instant::now() + Duration::from_millis(100)) tokio_timer::delay(Instant::now() + Duration::from_millis(100))
}).flatten().boxed(); })
.flatten()
.boxed();
} }
} }
// let the server turn for a bit... send the message // let the server turn for a bit... send the message
server_loop server_loop.block_on(tokio_timer::delay(
.block_on(tokio_timer::delay(Instant::now() + Duration::from_millis(100))); Instant::now() + Duration::from_millis(100),
));
} }
}) })
.unwrap(); .unwrap();
@@ -560,9 +567,9 @@ pub mod tests {
MdnsStream::new(mdns_addr, MdnsQueryType::OneShot, Some(1), None, Some(5)); MdnsStream::new(mdns_addr, MdnsQueryType::OneShot, Some(1), None, Some(5));
let mut stream = io_loop.block_on(stream).ok().unwrap().into_future(); let mut stream = io_loop.block_on(stream).ok().unwrap().into_future();
let mut timeout = let mut timeout =
future::lazy(|_| { future::lazy(|_| tokio_timer::delay(Instant::now() + Duration::from_millis(100)))
tokio_timer::delay(Instant::now() + Duration::from_millis(100)) .flatten()
}).flatten().boxed(); .boxed();
let mut successes = 0; let mut successes = 0;
for _ in 0..send_recv_times { for _ in 0..send_recv_times {
@@ -580,7 +587,10 @@ pub mod tests {
stream = stream_tmp.into_future(); stream = stream_tmp.into_future();
timeout = timeout_tmp; timeout = timeout_tmp;
let (buffer, _addr) = buffer_and_addr.expect("no msg received").expect("error receiving msg").unwrap(); let (buffer, _addr) = buffer_and_addr
.expect("no msg received")
.expect("error receiving msg")
.unwrap();
println!("client got data!"); println!("client got data!");
assert_eq!(&buffer, test_bytes); assert_eq!(&buffer, test_bytes);
@@ -590,7 +600,9 @@ pub mod tests {
stream = buffer_and_addr_stream_tmp; stream = buffer_and_addr_stream_tmp;
timeout = future::lazy(|_| { timeout = future::lazy(|_| {
tokio_timer::delay(Instant::now() + Duration::from_millis(100)) tokio_timer::delay(Instant::now() + Duration::from_millis(100))
}).flatten().boxed(); })
.flatten()
.boxed();
} }
} }
} }
@@ -636,10 +648,11 @@ pub mod tests {
.name("test_one_shot_mdns:server".to_string()) .name("test_one_shot_mdns:server".to_string())
.spawn(move || { .spawn(move || {
let mut io_loop = Runtime::new().unwrap(); let mut io_loop = Runtime::new().unwrap();
let mut timeout = let mut timeout = future::lazy(|_| {
future::lazy(|_| { tokio_timer::delay(Instant::now() + Duration::from_millis(100))
tokio_timer::delay(Instant::now() + Duration::from_millis(100)) })
}).flatten().boxed(); .flatten()
.boxed();
// TTLs are 0 so that multicast test packets never leave the test host... // TTLs are 0 so that multicast test packets never leave the test host...
// FIXME: this is hardcoded to index 5 for ipv6, which isn't going to be correct in most cases... // FIXME: this is hardcoded to index 5 for ipv6, which isn't going to be correct in most cases...
@@ -654,9 +667,9 @@ pub mod tests {
for _ in 0..=send_recv_times { for _ in 0..=send_recv_times {
// wait for some bytes... // wait for some bytes...
match io_loop match io_loop.block_on(
.block_on(future::lazy(|_| future::select(server_stream, timeout)).flatten()) future::lazy(|_| future::select(server_stream, timeout)).flatten(),
{ ) {
Either::Left((_buffer_and_addr_stream_tmp, _timeout_tmp)) => { Either::Left((_buffer_and_addr_stream_tmp, _timeout_tmp)) => {
// let (buffer_and_addr, stream_tmp) = buffer_and_addr_stream_tmp; // let (buffer_and_addr, stream_tmp) = buffer_and_addr_stream_tmp;
@@ -675,13 +688,16 @@ pub mod tests {
server_stream = buffer_and_addr_stream_tmp; server_stream = buffer_and_addr_stream_tmp;
timeout = future::lazy(|_| { timeout = future::lazy(|_| {
tokio_timer::delay(Instant::now() + Duration::from_millis(100)) tokio_timer::delay(Instant::now() + Duration::from_millis(100))
}).flatten().boxed(); })
.flatten()
.boxed();
} }
} }
// let the server turn for a bit... send the message // let the server turn for a bit... send the message
io_loop io_loop.block_on(tokio_timer::delay(
.block_on(tokio_timer::delay(Instant::now() + Duration::from_millis(100))); Instant::now() + Duration::from_millis(100),
));
} }
}) })
.unwrap(); .unwrap();
@@ -693,9 +709,9 @@ pub mod tests {
MdnsStream::new(mdns_addr, MdnsQueryType::OneShot, Some(1), None, Some(5)); MdnsStream::new(mdns_addr, MdnsQueryType::OneShot, Some(1), None, Some(5));
let mut stream = io_loop.block_on(stream).ok().unwrap().into_future(); let mut stream = io_loop.block_on(stream).ok().unwrap().into_future();
let mut timeout = let mut timeout =
future::lazy(|_| { future::lazy(|_| tokio_timer::delay(Instant::now() + Duration::from_millis(100)))
tokio_timer::delay(Instant::now() + Duration::from_millis(100)) .flatten()
}).flatten().boxed(); .boxed();
for _ in 0..send_recv_times { for _ in 0..send_recv_times {
// test once // test once
@@ -706,7 +722,8 @@ pub mod tests {
println!("client sending data!"); println!("client sending data!");
// TODO: this lazy is probably unnecessary? // TODO: this lazy is probably unnecessary?
let run_result = io_loop.block_on(future::lazy(|_| future::select(stream, timeout)).flatten()); let run_result =
io_loop.block_on(future::lazy(|_| future::select(stream, timeout)).flatten());
if server_got_packet.load(std::sync::atomic::Ordering::Relaxed) { if server_got_packet.load(std::sync::atomic::Ordering::Relaxed) {
return; return;
@@ -722,7 +739,9 @@ pub mod tests {
stream = buffer_and_addr_stream_tmp; stream = buffer_and_addr_stream_tmp;
timeout = future::lazy(|_| { timeout = future::lazy(|_| {
tokio_timer::delay(Instant::now() + Duration::from_millis(100)) tokio_timer::delay(Instant::now() + Duration::from_millis(100))
}).flatten().boxed(); })
.flatten()
.boxed();
} }
} }
} }

View File

@@ -447,10 +447,14 @@ impl<'r> BinDecodable<'r> for Header {
// TODO: We should pass these restrictions on, they can't be trusted, but that would seriously complicate the Header type.. // TODO: We should pass these restrictions on, they can't be trusted, but that would seriously complicate the Header type..
// TODO: perhaps the read methods for BinDecodable should return Restrict? // TODO: perhaps the read methods for BinDecodable should return Restrict?
let query_count = decoder.read_u16()?.unverified(/*this must be verified when reading queries*/); let query_count =
let answer_count = decoder.read_u16()?.unverified(/*this must be evaluated when reading records*/); decoder.read_u16()?.unverified(/*this must be verified when reading queries*/);
let name_server_count = decoder.read_u16()?.unverified(/*this must be evaluated when reading records*/); let answer_count =
let additional_count = decoder.read_u16()?.unverified(/*this must be evaluated when reading records*/); decoder.read_u16()?.unverified(/*this must be evaluated when reading records*/);
let name_server_count =
decoder.read_u16()?.unverified(/*this must be evaluated when reading records*/);
let additional_count =
decoder.read_u16()?.unverified(/*this must be evaluated when reading records*/);
// TODO: question, should this use the builder pattern instead? might be cleaner code, but // TODO: question, should this use the builder pattern instead? might be cleaner code, but
// this guarantees that the Header is fully instantiated with all values... // this guarantees that the Header is fully instantiated with all values...

View File

@@ -95,7 +95,9 @@ impl BinEncodable for DNSClass {
impl<'r> BinDecodable<'r> for DNSClass { impl<'r> BinDecodable<'r> for DNSClass {
fn read(decoder: &mut BinDecoder) -> ProtoResult<Self> { fn read(decoder: &mut BinDecoder) -> ProtoResult<Self> {
Self::from_u16(decoder.read_u16()?.unverified(/*DNSClass is verified as safe in processing this*/)) Self::from_u16(
decoder.read_u16()?.unverified(/*DNSClass is verified as safe in processing this*/),
)
} }
} }

View File

@@ -190,7 +190,7 @@ impl From<Algorithm> for u8 {
Algorithm::ECDSAP256SHA256 => 13, Algorithm::ECDSAP256SHA256 => 13,
Algorithm::ECDSAP384SHA384 => 14, Algorithm::ECDSAP384SHA384 => 14,
Algorithm::ED25519 => 15, Algorithm::ED25519 => 15,
Algorithm::Unknown(v) => v Algorithm::Unknown(v) => v,
} }
} }
} }
@@ -212,10 +212,7 @@ fn test_into() {
Algorithm::ECDSAP384SHA384, Algorithm::ECDSAP384SHA384,
Algorithm::ED25519, Algorithm::ED25519,
] { ] {
assert_eq!( assert_eq!(*algorithm, Algorithm::from_u8(Into::<u8>::into(*algorithm)))
*algorithm,
Algorithm::from_u8(Into::<u8>::into(*algorithm))
)
} }
} }

View File

@@ -5,8 +5,8 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use crate::error::*;
use super::Algorithm; use super::Algorithm;
use crate::error::*;
pub struct ECPublicKey { pub struct ECPublicKey {
buf: [u8; MAX_LEN], buf: [u8; MAX_LEN],

View File

@@ -26,8 +26,8 @@ pub mod rdata;
#[cfg(any(feature = "openssl", feature = "ring"))] #[cfg(any(feature = "openssl", feature = "ring"))]
mod rsa_public_key; mod rsa_public_key;
mod supported_algorithm; mod supported_algorithm;
mod trust_anchor;
pub mod tbs; pub mod tbs;
mod trust_anchor;
mod verifier; mod verifier;
pub use self::algorithm::Algorithm; pub use self::algorithm::Algorithm;

View File

@@ -403,7 +403,8 @@ impl<'k> PublicKey for Rsa<'k> {
n: self.pkey.n(), n: self.pkey.n(),
e: self.pkey.e(), e: self.pkey.e(),
}; };
public_key.verify(alg, message, signature) public_key
.verify(alg, message, signature)
.map_err(Into::into) .map_err(Into::into)
} }
} }
@@ -529,5 +530,4 @@ mod tests {
test_case(&[0x80, 0x00, 0x80], &[0x00, 0x80, 0x00, 0x80]); test_case(&[0x80, 0x00, 0x80], &[0x00, 0x80, 0x00, 0x80]);
test_case(&[0xff, 0x00, 0x80], &[0x00, 0xff, 0x00, 0x80]); test_case(&[0xff, 0x00, 0x80], &[0x00, 0xff, 0x00, 0x80]);
} }
} }

View File

@@ -345,7 +345,8 @@ pub fn read(decoder: &mut BinDecoder, rdata_length: Restrict<u16>) -> ProtoResul
// protocol is defined to only be '3' right now // protocol is defined to only be '3' right now
*protocol == 3 *protocol == 3
}).map_err(|protocol| ProtoError::from(ProtoErrorKind::DnsKeyProtocolNot3(protocol)))?; })
.map_err(|protocol| ProtoError::from(ProtoErrorKind::DnsKeyProtocolNot3(protocol)))?;
let algorithm: Algorithm = Algorithm::read(decoder)?; let algorithm: Algorithm = Algorithm::read(decoder)?;
@@ -412,16 +413,18 @@ mod tests {
println!("bytes: {:?}", bytes); println!("bytes: {:?}", bytes);
let mut decoder: BinDecoder = BinDecoder::new(bytes); let mut decoder: BinDecoder = BinDecoder::new(bytes);
let restrict = Restrict::new(bytes.len() as u16); let read_rdata = read(&mut decoder, Restrict::new(bytes.len() as u16));
let read_rdata = read(&mut decoder, restrict).expect("Decoding error");
assert_eq!(rdata, read_rdata);
assert!( assert!(
rdata read_rdata.is_ok(),
.to_digest( format!("error decoding: {:?}", read_rdata.unwrap_err())
&Name::parse("www.example.com.", None).unwrap(),
DigestType::SHA256
).is_ok()
); );
assert_eq!(rdata, read_rdata.unwrap());
assert!(rdata
.to_digest(
&Name::parse("www.example.com.", None).unwrap(),
DigestType::SHA256
)
.is_ok());
} }
#[test] #[test]

View File

@@ -785,7 +785,8 @@ pub fn read(decoder: &mut BinDecoder, rdata_length: Restrict<u16>) -> ProtoResul
// Bits 4-5 are reserved and must be zero. // Bits 4-5 are reserved and must be zero.
// Bits 8-11 are reserved and must be zero. // Bits 8-11 are reserved and must be zero.
flags & 0b0010_1100_1111_0000 == 0 flags & 0b0010_1100_1111_0000 == 0
}).map_err(|_| ProtoError::from("flag 2, 4-5, and 8-11 are reserved, must be zero"))?; })
.map_err(|_| ProtoError::from("flag 2, 4-5, and 8-11 are reserved, must be zero"))?;
let key_trust = KeyTrust::from(flags); let key_trust = KeyTrust::from(flags);
let extended_flags: bool = flags & 0b0001_0000_0000_0000 != 0; let extended_flags: bool = flags & 0b0001_0000_0000_0000 != 0;

View File

@@ -262,7 +262,8 @@ pub fn read(decoder: &mut BinDecoder, rdata_length: Restrict<u16>) -> ProtoResul
let salt_len = salt_len let salt_len = salt_len
.verify_unwrap(|salt_len| { .verify_unwrap(|salt_len| {
*salt_len <= salt_len_max.unverified(/*safe in comparison usage*/) *salt_len <= salt_len_max.unverified(/*safe in comparison usage*/)
}).map_err(|_| ProtoError::from("salt_len exceeds buffer length"))?; })
.map_err(|_| ProtoError::from("salt_len exceeds buffer length"))?;
let salt: Vec<u8> = let salt: Vec<u8> =
decoder.read_vec(salt_len)?.unverified(/*salt is any valid array of bytes*/); decoder.read_vec(salt_len)?.unverified(/*salt is any valid array of bytes*/);
@@ -275,7 +276,8 @@ pub fn read(decoder: &mut BinDecoder, rdata_length: Restrict<u16>) -> ProtoResul
let hash_len = hash_len let hash_len = hash_len
.verify_unwrap(|hash_len| { .verify_unwrap(|hash_len| {
*hash_len <= hash_len_max.unverified(/*safe in comparison usage*/) *hash_len <= hash_len_max.unverified(/*safe in comparison usage*/)
}).map_err(|_| ProtoError::from("hash_len exceeds buffer length"))?; })
.map_err(|_| ProtoError::from("hash_len exceeds buffer length"))?;
let next_hashed_owner_name: Vec<u8> = let next_hashed_owner_name: Vec<u8> =
decoder.read_vec(hash_len)?.unverified(/*will fail in usage if invalid*/); decoder.read_vec(hash_len)?.unverified(/*will fail in usage if invalid*/);

View File

@@ -87,7 +87,7 @@ impl SupportedAlgorithms {
/// Set the specified algorithm as supported /// Set the specified algorithm as supported
pub fn set(&mut self, algorithm: Algorithm) { pub fn set(&mut self, algorithm: Algorithm) {
if let Some(bit_pos) = Self::pos(algorithm) { if let Some(bit_pos) = Self::pos(algorithm) {
self.bit_map |= bit_pos; self.bit_map |= bit_pos;
} }
} }
@@ -96,7 +96,7 @@ impl SupportedAlgorithms {
pub fn has(self, algorithm: Algorithm) -> bool { pub fn has(self, algorithm: Algorithm) -> bool {
if let Some(bit_pos) = Self::pos(algorithm) { if let Some(bit_pos) = Self::pos(algorithm) {
(bit_pos & self.bit_map) == bit_pos (bit_pos & self.bit_map) == bit_pos
} else { } else {
false false
} }
} }
@@ -140,7 +140,7 @@ impl<'a> From<&'a [u8]> for SupportedAlgorithms {
let mut supported = SupportedAlgorithms::new(); let mut supported = SupportedAlgorithms::new();
for a in values.iter().map(|i| Algorithm::from_u8(*i)) { for a in values.iter().map(|i| Algorithm::from_u8(*i)) {
match a { match a {
Algorithm::Unknown(v) => warn!("unrecognized algorithm: {}", v), Algorithm::Unknown(v) => warn!("unrecognized algorithm: {}", v),
a => supported.set(a), a => supported.set(a),
} }

View File

@@ -31,19 +31,18 @@ pub fn message_tbs<M: BinEncodable>(message: &M, pre_sig0: &SIG) -> ProtoResult<
{ {
let mut encoder: BinEncoder = BinEncoder::with_mode(&mut buf, EncodeMode::Normal); let mut encoder: BinEncoder = BinEncoder::with_mode(&mut buf, EncodeMode::Normal);
assert!( assert!(sig::emit_pre_sig(
sig::emit_pre_sig( &mut encoder,
&mut encoder, pre_sig0.type_covered(),
pre_sig0.type_covered(), pre_sig0.algorithm(),
pre_sig0.algorithm(), pre_sig0.num_labels(),
pre_sig0.num_labels(), pre_sig0.original_ttl(),
pre_sig0.original_ttl(), pre_sig0.sig_expiration(),
pre_sig0.sig_expiration(), pre_sig0.sig_inception(),
pre_sig0.sig_inception(), pre_sig0.key_tag(),
pre_sig0.key_tag(), pre_sig0.signer_name(),
pre_sig0.signer_name(), )
).is_ok() .is_ok());
);
// need a separate encoder here, as the encoding references absolute positions // need a separate encoder here, as the encoding references absolute positions
// inside the buffer. If the buffer already contains the sig0 RDATA, offsets // inside the buffer. If the buffer already contains the sig0 RDATA, offsets
// are wrong and the signature won't match. // are wrong and the signature won't match.
@@ -124,30 +123,28 @@ pub fn rrset_tbs(
// RRSIG_RDATA is the wire format of the RRSIG RDATA fields // RRSIG_RDATA is the wire format of the RRSIG RDATA fields
// with the Signature field excluded and the Signer's Name // with the Signature field excluded and the Signer's Name
// in canonical form. // in canonical form.
assert!( assert!(sig::emit_pre_sig(
sig::emit_pre_sig( &mut encoder,
&mut encoder, type_covered,
type_covered, algorithm,
algorithm, name.num_labels(),
name.num_labels(), original_ttl,
original_ttl, sig_expiration,
sig_expiration, sig_inception,
sig_inception, key_tag,
key_tag, signer_name,
signer_name, )
).is_ok() .is_ok());
);
// construct the rrset signing data // construct the rrset signing data
for record in rrset { for record in rrset {
// RR(i) = name | type | class | OrigTTL | RDATA length | RDATA // RR(i) = name | type | class | OrigTTL | RDATA length | RDATA
// //
// name is calculated according to the function in the RFC 4035 // name is calculated according to the function in the RFC 4035
assert!( assert!(name
name.to_lowercase() .to_lowercase()
.emit_as_canonical(&mut encoder, true) .emit_as_canonical(&mut encoder, true)
.is_ok() .is_ok());
);
// //
// type is the RRset type and all RRs in the class // type is the RRset type and all RRs in the class
assert!(type_covered.emit(&mut encoder).is_ok()); assert!(type_covered.emit(&mut encoder).is_ok());

View File

@@ -20,9 +20,9 @@ use std::str::FromStr;
use crate::error::*; use crate::error::*;
use crate::rr::domain::label::{CaseInsensitive, CaseSensitive, IntoLabel, Label, LabelCmp}; use crate::rr::domain::label::{CaseInsensitive, CaseSensitive, IntoLabel, Label, LabelCmp};
use crate::rr::domain::usage::LOCALHOST as LOCALHOST_usage; use crate::rr::domain::usage::LOCALHOST as LOCALHOST_usage;
use crate::serialize::binary::*;
#[cfg(feature = "serde-config")] #[cfg(feature = "serde-config")]
use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use crate::serialize::binary::*;
/// Them should be through references. As a workaround the Strings are all Rc as well as the array /// Them should be through references. As a workaround the Strings are all Rc as well as the array
#[derive(Clone, Default, Debug, Eq)] #[derive(Clone, Default, Debug, Eq)]

View File

@@ -13,8 +13,7 @@ use std::ops::Deref;
use crate::rr::domain::Name; use crate::rr::domain::Name;
lazy_static! {
lazy_static!{
/// Default Name usage, everything is normal... /// Default Name usage, everything is normal...
pub static ref DEFAULT: ZoneUsage = ZoneUsage::default(); pub static ref DEFAULT: ZoneUsage = ZoneUsage::default();
} }
@@ -68,7 +67,7 @@ lazy_static! {
pub static ref LOCAL: ZoneUsage = ZoneUsage::local(Name::from_ascii("local.").unwrap()); pub static ref LOCAL: ZoneUsage = ZoneUsage::local(Name::from_ascii("local.").unwrap());
// RFC 6762 Multicast DNS February 2013 // RFC 6762 Multicast DNS February 2013
// Any DNS query for a name ending with "254.169.in-addr.arpa." MUST // Any DNS query for a name ending with "254.169.in-addr.arpa." MUST
// be sent to the mDNS IPv4 link-local multicast address 224.0.0.251 // be sent to the mDNS IPv4 link-local multicast address 224.0.0.251
// or the mDNS IPv6 multicast address FF02::FB. Since names under // or the mDNS IPv6 multicast address FF02::FB. Since names under
@@ -111,9 +110,6 @@ lazy_static! {
pub static ref INVALID: ZoneUsage = ZoneUsage::invalid(Name::from_ascii("invalid.").unwrap()); pub static ref INVALID: ZoneUsage = ZoneUsage::invalid(Name::from_ascii("invalid.").unwrap());
} }
/// Users: /// Users:
/// ///
/// Are human users expected to recognize these names as special and /// Are human users expected to recognize these names as special and
@@ -126,7 +122,7 @@ pub enum UserUsage {
/// be aware that these names are likely to yield different results /// be aware that these names are likely to yield different results
/// on different networks. /// on different networks.
Normal, Normal,
/// Users are free to use localhost names as they would any other /// Users are free to use localhost names as they would any other
/// domain names. Users may assume that IPv4 and IPv6 address /// domain names. Users may assume that IPv4 and IPv6 address
/// queries for localhost names will always resolve to the respective /// queries for localhost names will always resolve to the respective
@@ -134,7 +130,7 @@ pub enum UserUsage {
Loopback, Loopback,
/// Multi-cast link-local usage /// Multi-cast link-local usage
LinkLocal, LinkLocal,
/// Users are free to use "invalid" names as they would any other /// Users are free to use "invalid" names as they would any other
/// domain names. Users MAY assume that queries for "invalid" names /// domain names. Users MAY assume that queries for "invalid" names
@@ -208,7 +204,7 @@ pub enum ResolverUsage {
Loopback, Loopback,
/// Link local, generally for mDNS /// Link local, generally for mDNS
/// ///
/// Any DNS query for a name ending with ".local." MUST be sent to the /// Any DNS query for a name ending with ".local." MUST be sent to the
/// mDNS IPv4 link-local multicast address 224.0.0.251 (or its IPv6 /// mDNS IPv4 link-local multicast address 224.0.0.251 (or its IPv6
/// equivalent FF02::FB). The design rationale for using a fixed /// equivalent FF02::FB). The design rationale for using a fixed
@@ -235,7 +231,7 @@ pub enum ResolverUsage {
/// their implementations recognize these names as special and treat /// their implementations recognize these names as special and treat
/// them differently? If so, how? /// them differently? If so, how?
#[derive(Clone, Copy, PartialEq, Eq)] #[derive(Clone, Copy, PartialEq, Eq)]
pub enum CacheUsage{ pub enum CacheUsage {
/// Caching DNS servers SHOULD recognize these names as special and /// Caching DNS servers SHOULD recognize these names as special and
/// SHOULD NOT, by default, attempt to look up NS records for them, /// SHOULD NOT, by default, attempt to look up NS records for them,
/// or otherwise query authoritative DNS servers in an attempt to /// or otherwise query authoritative DNS servers in an attempt to
@@ -279,7 +275,7 @@ pub enum CacheUsage{
/// make their implementations recognize these names as special and /// make their implementations recognize these names as special and
/// treat them differently? If so, how? /// treat them differently? If so, how?
#[derive(Clone, Copy, PartialEq, Eq)] #[derive(Clone, Copy, PartialEq, Eq)]
pub enum AuthUsage{ pub enum AuthUsage {
/// Authoritative DNS servers SHOULD recognize these names as special /// Authoritative DNS servers SHOULD recognize these names as special
/// and SHOULD, by default, generate immediate negative responses for /// and SHOULD, by default, generate immediate negative responses for
/// all such queries, unless explicitly configured by the /// all such queries, unless explicitly configured by the
@@ -348,7 +344,7 @@ pub enum OpUsage {
/// name is reserved for use in documentation and cannot be /// name is reserved for use in documentation and cannot be
/// registered!) /// registered!)
#[derive(Clone, Copy, PartialEq, Eq)] #[derive(Clone, Copy, PartialEq, Eq)]
pub enum RegistryUsage{ pub enum RegistryUsage {
/// Stanard checks apply /// Stanard checks apply
Normal, Normal,
@@ -396,51 +392,132 @@ pub struct ZoneUsage {
impl ZoneUsage { impl ZoneUsage {
/// Constructs a new ZoneUsage with the associated values /// Constructs a new ZoneUsage with the associated values
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new(name: Name, user: UserUsage, app: AppUsage, resolver: ResolverUsage, cache: CacheUsage, auth: AuthUsage, op: OpUsage, registry: RegistryUsage) -> Self { pub fn new(
ZoneUsage {name, user, app, resolver, cache, auth, op, registry} name: Name,
user: UserUsage,
app: AppUsage,
resolver: ResolverUsage,
cache: CacheUsage,
auth: AuthUsage,
op: OpUsage,
registry: RegistryUsage,
) -> Self {
ZoneUsage {
name,
user,
app,
resolver,
cache,
auth,
op,
registry,
}
} }
/// Constructs a new Default, with all no restrictions /// Constructs a new Default, with all no restrictions
pub fn default() -> Self { pub fn default() -> Self {
Self::new(Name::root(), UserUsage::Normal, AppUsage::Normal, ResolverUsage::Normal, CacheUsage::Normal, AuthUsage::Normal, OpUsage::Normal, RegistryUsage::Normal) Self::new(
Name::root(),
UserUsage::Normal,
AppUsage::Normal,
ResolverUsage::Normal,
CacheUsage::Normal,
AuthUsage::Normal,
OpUsage::Normal,
RegistryUsage::Normal,
)
} }
/// Restrictions for reverse zones /// Restrictions for reverse zones
pub fn reverse(name: Name) -> Self { pub fn reverse(name: Name) -> Self {
Self::new(name, UserUsage::Normal, AppUsage::Normal, ResolverUsage::Normal, CacheUsage::NonRecursive, AuthUsage::Local, OpUsage::Normal, RegistryUsage::Reserved) Self::new(
name,
UserUsage::Normal,
AppUsage::Normal,
ResolverUsage::Normal,
CacheUsage::NonRecursive,
AuthUsage::Local,
OpUsage::Normal,
RegistryUsage::Reserved,
)
} }
/// Restrictions for the .test. zone /// Restrictions for the .test. zone
pub fn test(name: Name) -> Self { pub fn test(name: Name) -> Self {
Self::new(name, UserUsage::Normal, AppUsage::Normal, ResolverUsage::Normal, CacheUsage::NonRecursive, AuthUsage::Local, OpUsage::Normal, RegistryUsage::Reserved) Self::new(
name,
UserUsage::Normal,
AppUsage::Normal,
ResolverUsage::Normal,
CacheUsage::NonRecursive,
AuthUsage::Local,
OpUsage::Normal,
RegistryUsage::Reserved,
)
} }
/// Restrictions for the .localhost. zone /// Restrictions for the .localhost. zone
pub fn localhost(name: Name) -> Self { pub fn localhost(name: Name) -> Self {
Self::new(name, UserUsage::Loopback, AppUsage::Loopback, ResolverUsage::Loopback, CacheUsage::Loopback, AuthUsage::Loopback, OpUsage::Loopback, RegistryUsage::Reserved) Self::new(
name,
UserUsage::Loopback,
AppUsage::Loopback,
ResolverUsage::Loopback,
CacheUsage::Loopback,
AuthUsage::Loopback,
OpUsage::Loopback,
RegistryUsage::Reserved,
)
} }
/// Restrictions for the .local. zone /// Restrictions for the .local. zone
pub fn local(name: Name) -> Self { pub fn local(name: Name) -> Self {
Self::new(name, UserUsage::LinkLocal, AppUsage::LinkLocal, ResolverUsage::LinkLocal, CacheUsage::Normal, AuthUsage::Local, OpUsage::Normal, RegistryUsage::Reserved) Self::new(
name,
UserUsage::LinkLocal,
AppUsage::LinkLocal,
ResolverUsage::LinkLocal,
CacheUsage::Normal,
AuthUsage::Local,
OpUsage::Normal,
RegistryUsage::Reserved,
)
} }
/// Restrictions for the .invalid. zone /// Restrictions for the .invalid. zone
pub fn invalid(name: Name) -> Self { pub fn invalid(name: Name) -> Self {
Self::new(name, UserUsage::NxDomain, AppUsage::NxDomain, ResolverUsage::NxDomain, CacheUsage::NxDomain, AuthUsage::NxDomain, OpUsage::NxDomain, RegistryUsage::Reserved) Self::new(
name,
UserUsage::NxDomain,
AppUsage::NxDomain,
ResolverUsage::NxDomain,
CacheUsage::NxDomain,
AuthUsage::NxDomain,
OpUsage::NxDomain,
RegistryUsage::Reserved,
)
} }
/// Restrictions for the .example. zone /// Restrictions for the .example. zone
pub fn example(name: Name) -> Self { pub fn example(name: Name) -> Self {
Self::new(name, UserUsage::Normal, AppUsage::Normal, ResolverUsage::Normal, CacheUsage::Normal, AuthUsage::Normal, OpUsage::Normal, RegistryUsage::Reserved) Self::new(
name,
UserUsage::Normal,
AppUsage::Normal,
ResolverUsage::Normal,
CacheUsage::Normal,
AuthUsage::Normal,
OpUsage::Normal,
RegistryUsage::Reserved,
)
} }
/// A reference to this zone name /// A reference to this zone name
pub fn name(&self) -> &Name { pub fn name(&self) -> &Name {
&self.name &self.name
} }
/// Returns the UserUsage of this zone /// Returns the UserUsage of this zone
pub fn user(&self) -> UserUsage { pub fn user(&self) -> UserUsage {
self.user self.user

View File

@@ -34,8 +34,8 @@
use std::net::Ipv6Addr; use std::net::Ipv6Addr;
use crate::serialize::binary::*;
use crate::error::*; use crate::error::*;
use crate::serialize::binary::*;
/// Read the RData from the given Decoder /// Read the RData from the given Decoder
#[allow(clippy::many_single_char_names)] #[allow(clippy::many_single_char_names)]

View File

@@ -39,9 +39,9 @@
//! the description of name server logic in [RFC-1034] for details. //! the description of name server logic in [RFC-1034] for details.
//! ``` //! ```
use crate::serialize::binary::*;
use crate::error::*; use crate::error::*;
use crate::rr::domain::Name; use crate::rr::domain::Name;
use crate::serialize::binary::*;
/// Read the RData from the given Decoder /// Read the RData from the given Decoder
pub fn read(decoder: &mut BinDecoder) -> ProtoResult<Name> { pub fn read(decoder: &mut BinDecoder) -> ProtoResult<Name> {

View File

@@ -434,7 +434,8 @@ mod tests {
println!("bytes: {:?}", bytes); println!("bytes: {:?}", bytes);
let mut decoder: BinDecoder = BinDecoder::new(bytes); let mut decoder: BinDecoder = BinDecoder::new(bytes);
let read_rdata = read(&mut decoder, Restrict::new(bytes.len() as u16)).expect("failed to read back"); let read_rdata =
read(&mut decoder, Restrict::new(bytes.len() as u16)).expect("failed to read back");
assert_eq!(rdata, read_rdata); assert_eq!(rdata, read_rdata);
} }

View File

@@ -78,9 +78,11 @@ pub fn read(decoder: &mut BinDecoder, rdata_length: Restrict<u16>) -> ProtoResul
let mut strings = Vec::with_capacity(1); let mut strings = Vec::with_capacity(1);
// no unsafe usage of rdata length after this point // no unsafe usage of rdata length after this point
let rdata_length = rdata_length.map(|u| u as usize).unverified(/*used as a higher bound, safely*/); let rdata_length =
rdata_length.map(|u| u as usize).unverified(/*used as a higher bound, safely*/);
while data_len - decoder.len() < rdata_length { while data_len - decoder.len() < rdata_length {
let string = decoder.read_character_data()?.unverified(/*any data should be validate in TXT usage*/); let string =
decoder.read_character_data()?.unverified(/*any data should be validate in TXT usage*/);
strings.push(string.to_vec().into_boxed_slice()); strings.push(string.to_vec().into_boxed_slice());
} }
Ok(TXT { Ok(TXT {

View File

@@ -192,7 +192,9 @@ impl<'a> BinDecoder<'a> {
/// ///
/// Return the u16 from the buffer /// Return the u16 from the buffer
pub fn read_u16(&mut self) -> ProtoResult<Restrict<u16>> { pub fn read_u16(&mut self) -> ProtoResult<Restrict<u16>> {
Ok(self.read_slice(2)?.map(|s| u16::from_be_bytes([s[0], s[1]]))) Ok(self
.read_slice(2)?
.map(|s| u16::from_be_bytes([s[0], s[1]])))
} }
/// Reads the next four bytes into i32. /// Reads the next four bytes into i32.
@@ -204,7 +206,9 @@ impl<'a> BinDecoder<'a> {
/// ///
/// Return the i32 from the buffer /// Return the i32 from the buffer
pub fn read_i32(&mut self) -> ProtoResult<Restrict<i32>> { pub fn read_i32(&mut self) -> ProtoResult<Restrict<i32>> {
Ok(self.read_slice(4)?.map(|s| i32::from_be_bytes([s[0], s[1], s[2], s[3]]))) Ok(self
.read_slice(4)?
.map(|s| i32::from_be_bytes([s[0], s[1], s[2], s[3]])))
} }
/// Reads the next four bytes into u32. /// Reads the next four bytes into u32.
@@ -216,7 +220,9 @@ impl<'a> BinDecoder<'a> {
/// ///
/// Return the u32 from the buffer /// Return the u32 from the buffer
pub fn read_u32(&mut self) -> ProtoResult<Restrict<u32>> { pub fn read_u32(&mut self) -> ProtoResult<Restrict<u32>> {
Ok(self.read_slice(4)?.map(|s| u32::from_be_bytes([s[0], s[1], s[2], s[3]]))) Ok(self
.read_slice(4)?
.map(|s| u32::from_be_bytes([s[0], s[1], s[2], s[3]])))
} }
} }

View File

@@ -164,8 +164,8 @@ impl RestrictedMath for Restrict<u16> {
} }
} }
impl<R, A> RestrictedMath for Result<R, A> impl<R, A> RestrictedMath for Result<R, A>
where where
R: RestrictedMath, R: RestrictedMath,
A: 'static + Sized + Copy, A: 'static + Sized + Copy,
{ {
@@ -178,7 +178,7 @@ where
Err(_) => Err(arg), Err(_) => Err(arg),
} }
} }
fn checked_sub(&self, arg: Self::Arg) -> Result<Restrict<Self::Value>, Self::Arg> { fn checked_sub(&self, arg: Self::Arg) -> Result<Restrict<Self::Value>, Self::Arg> {
match *self { match *self {
Ok(ref r) => r.checked_sub(arg), Ok(ref r) => r.checked_sub(arg),
@@ -200,22 +200,40 @@ mod tests {
#[test] #[test]
fn test_checked_add() { fn test_checked_add() {
assert_eq!(Restrict(1_usize).checked_add(2_usize).unwrap().unverified(), 3_usize); assert_eq!(
assert_eq!(Restrict(1_u16).checked_add(2_u16).unwrap().unverified(), 3_u16); Restrict(1_usize).checked_add(2_usize).unwrap().unverified(),
3_usize
);
assert_eq!(
Restrict(1_u16).checked_add(2_u16).unwrap().unverified(),
3_u16
);
assert_eq!(Restrict(1_u8).checked_add(2_u8).unwrap().unverified(), 3_u8); assert_eq!(Restrict(1_u8).checked_add(2_u8).unwrap().unverified(), 3_u8);
} }
#[test] #[test]
fn test_checked_sub() { fn test_checked_sub() {
assert_eq!(Restrict(2_usize).checked_sub(1_usize).unwrap().unverified(), 1_usize); assert_eq!(
assert_eq!(Restrict(2_u16).checked_sub(1_u16).unwrap().unverified(), 1_u16); Restrict(2_usize).checked_sub(1_usize).unwrap().unverified(),
1_usize
);
assert_eq!(
Restrict(2_u16).checked_sub(1_u16).unwrap().unverified(),
1_u16
);
assert_eq!(Restrict(2_u8).checked_sub(1_u8).unwrap().unverified(), 1_u8); assert_eq!(Restrict(2_u8).checked_sub(1_u8).unwrap().unverified(), 1_u8);
} }
#[test] #[test]
fn test_checked_mul() { fn test_checked_mul() {
assert_eq!(Restrict(1_usize).checked_mul(2_usize).unwrap().unverified(), 2_usize); assert_eq!(
assert_eq!(Restrict(1_u16).checked_mul(2_u16).unwrap().unverified(), 2_u16); Restrict(1_usize).checked_mul(2_usize).unwrap().unverified(),
2_usize
);
assert_eq!(
Restrict(1_u16).checked_mul(2_u16).unwrap().unverified(),
2_u16
);
assert_eq!(Restrict(1_u8).checked_mul(2_u8).unwrap().unverified(), 2_u8); assert_eq!(Restrict(1_u8).checked_mul(2_u8).unwrap().unverified(), 2_u8);
} }
} }

View File

@@ -8,12 +8,12 @@
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::io; use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration;
use std::pin::Pin; use std::pin::Pin;
use std::task::Context; use std::task::Context;
use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use futures::{Future, Poll, Stream, TryFutureExt, StreamExt}; use futures::{Future, Poll, Stream, StreamExt, TryFutureExt};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use crate::error::ProtoError; use crate::error::ProtoError;
@@ -40,7 +40,10 @@ impl<S: Connect + 'static + Send> TcpClientStream<S> {
#[allow(clippy::new_ret_no_self)] #[allow(clippy::new_ret_no_self)]
pub fn new( pub fn new(
name_server: SocketAddr, name_server: SocketAddr,
) -> (TcpClientConnect<S::Transport>, Box<dyn DnsStreamHandle + Send>) { ) -> (
TcpClientConnect<S::Transport>,
Box<dyn DnsStreamHandle + Send>,
) {
Self::with_timeout(name_server, Duration::from_secs(5)) Self::with_timeout(name_server, Duration::from_secs(5))
} }
@@ -53,7 +56,10 @@ impl<S: Connect + 'static + Send> TcpClientStream<S> {
pub fn with_timeout( pub fn with_timeout(
name_server: SocketAddr, name_server: SocketAddr,
timeout: Duration, timeout: Duration,
) -> (TcpClientConnect<S::Transport>, Box<dyn DnsStreamHandle + Send>) { ) -> (
TcpClientConnect<S::Transport>,
Box<dyn DnsStreamHandle + Send>,
) {
let (stream_future, sender) = TcpStream::<S>::with_timeout(name_server, timeout); let (stream_future, sender) = TcpStream::<S>::with_timeout(name_server, timeout);
let new_future = Box::pin( let new_future = Box::pin(
@@ -106,14 +112,16 @@ impl<S: AsyncRead + AsyncWrite + Send + Unpin> Stream for TcpClientStream<S> {
// TODO: create unboxed future for the TCP Stream // TODO: create unboxed future for the TCP Stream
/// A future that resolves to an TcpClientStream /// A future that resolves to an TcpClientStream
pub struct TcpClientConnect<S>(Pin<Box<dyn Future<Output = Result<TcpClientStream<S>, ProtoError>> + Send + 'static>>); pub struct TcpClientConnect<S>(
Pin<Box<dyn Future<Output = Result<TcpClientStream<S>, ProtoError>> + Send + 'static>>,
);
impl<S> Future for TcpClientConnect<S> { impl<S> Future for TcpClientConnect<S> {
type Output = Result<TcpClientStream<S>, ProtoError>; type Output = Result<TcpClientStream<S>, ProtoError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.0.as_mut().poll(cx) self.0.as_mut().poll(cx)
} }
} }
#[cfg(feature = "tokio-compat")] #[cfg(feature = "tokio-compat")]
@@ -123,7 +131,7 @@ use tokio_net::tcp::TcpStream as TokioTcpStream;
#[async_trait] #[async_trait]
impl Connect for TokioTcpStream { impl Connect for TokioTcpStream {
type Transport = TokioTcpStream; type Transport = TokioTcpStream;
async fn connect(addr: &SocketAddr) -> io::Result<Self::Transport> { async fn connect(addr: &SocketAddr) -> io::Result<Self::Transport> {
TokioTcpStream::connect(addr).await TokioTcpStream::connect(addr).await
} }
@@ -240,10 +248,11 @@ fn tcp_client_stream_test(server_addr: IpAddr) {
sender sender
.send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr)) .send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr))
.expect("send failed"); .expect("send failed");
let (buffer, stream_tmp) = io_loop let (buffer, stream_tmp) = io_loop.block_on(stream.into_future());
.block_on(stream.into_future());
stream = stream_tmp; stream = stream_tmp;
let buffer = buffer.expect("no buffer received").expect("error recieving buffer"); let buffer = buffer
.expect("no buffer received")
.expect("error recieving buffer");
assert_eq!(buffer.bytes(), TEST_BYTES); assert_eq!(buffer.bytes(), TEST_BYTES);
} }

View File

@@ -306,7 +306,7 @@ impl<S: tokio_io::AsyncRead + tokio_io::AsyncWrite + Unpin> Stream for TcpStream
Poll::Ready(Some(message)) => { Poll::Ready(Some(message)) => {
// if there is no peer, this connection should die... // if there is no peer, this connection should die...
let (buffer, dst) = message.unwrap(); let (buffer, dst) = message.unwrap();
// This is an error if the destination is not our peer (this is TCP after all) // This is an error if the destination is not our peer (this is TCP after all)
// This will kill the connection... // This will kill the connection...
if peer != dst { if peer != dst {

View File

@@ -9,10 +9,10 @@ use std::borrow::Borrow;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::Context; use std::task::Context;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use futures::{Future, Poll, Stream}; use futures::{Future, Poll, Stream};
use tokio_timer::timeout::{Elapsed, Timeout}; use tokio_timer::timeout::{Elapsed, Timeout};
@@ -183,7 +183,9 @@ impl<S: Send, MF: MessageFinalizer> Stream for UdpClientStream<S, MF> {
} }
/// A future that resolves to /// A future that resolves to
pub struct UdpResponse(Pin<Box<dyn Future<Output = Result<Result<DnsResponse, ProtoError>, Elapsed>> + Send>>); pub struct UdpResponse(
Pin<Box<dyn Future<Output = Result<Result<DnsResponse, ProtoError>, Elapsed>> + Send>>,
);
impl UdpResponse { impl UdpResponse {
/// creates a new future for the request /// creates a new future for the request
@@ -192,7 +194,11 @@ impl UdpResponse {
/// ///
/// * `request` - Serialized message being sent /// * `request` - Serialized message being sent
/// * `message_id` - Id of the message that was encoded in the serial message /// * `message_id` - Id of the message that was encoded in the serial message
fn new<S: UdpSocket + Send + Unpin + 'static>(request: SerialMessage, message_id: u16, timeout: Duration) -> Self { fn new<S: UdpSocket + Send + Unpin + 'static>(
request: SerialMessage,
message_id: u16,
timeout: Duration,
) -> Self {
UdpResponse(Box::pin(Timeout::new( UdpResponse(Box::pin(Timeout::new(
SingleUseUdpSocket::send_serial_message::<S>(request, message_id), SingleUseUdpSocket::send_serial_message::<S>(request, message_id),
timeout, timeout,
@@ -200,7 +206,9 @@ impl UdpResponse {
} }
/// ad already completed future /// ad already completed future
fn complete<F: Future<Output = Result<DnsResponse, ProtoError>> + Send + 'static>(f: F) -> Self { fn complete<F: Future<Output = Result<DnsResponse, ProtoError>> + Send + 'static>(
f: F,
) -> Self {
// TODO: this constructure isn't really necessary // TODO: this constructure isn't really necessary
UdpResponse(Box::pin(Timeout::new(f, Duration::from_secs(5)))) UdpResponse(Box::pin(Timeout::new(f, Duration::from_secs(5))))
} }
@@ -210,7 +218,11 @@ impl Future for UdpResponse {
type Output = Result<DnsResponse, ProtoError>; type Output = Result<DnsResponse, ProtoError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.0.as_mut().poll(cx).map_err(ProtoError::from).map(|r| r.and_then(|r| r)) self.0
.as_mut()
.poll(cx)
.map_err(ProtoError::from)
.map(|r| r.and_then(|r| r))
} }
} }
@@ -247,7 +259,10 @@ impl<S: Send + Unpin, MF: MessageFinalizer> Future for UdpClientConnect<S, MF> {
struct SingleUseUdpSocket; struct SingleUseUdpSocket;
impl SingleUseUdpSocket { impl SingleUseUdpSocket {
async fn send_serial_message<S: UdpSocket + Send>(msg: SerialMessage, msg_id: u16) -> Result<DnsResponse, ProtoError> { async fn send_serial_message<S: UdpSocket + Send>(
msg: SerialMessage,
msg_id: u16,
) -> Result<DnsResponse, ProtoError> {
let name_server = msg.addr(); let name_server = msg.addr();
let mut socket: S = NextRandomUdpSocket::new(&name_server).await?; let mut socket: S = NextRandomUdpSocket::new(&name_server).await?;
let bytes = msg.bytes(); let bytes = msg.bytes();
@@ -255,7 +270,11 @@ impl SingleUseUdpSocket {
let len_sent: usize = socket.send_to(bytes, addr).await?; let len_sent: usize = socket.send_to(bytes, addr).await?;
if bytes.len() != len_sent { if bytes.len() != len_sent {
return Err(ProtoError::from(format!("Not all bytes of message sent, {} of {}", len_sent, bytes.len()))) return Err(ProtoError::from(format!(
"Not all bytes of message sent, {} of {}",
len_sent,
bytes.len()
)));
} }
// TODO: limit the max number of attempted messages? this relies on a timeout to die... // TODO: limit the max number of attempted messages? this relies on a timeout to die...
@@ -286,7 +305,7 @@ impl SingleUseUdpSocket {
Ok(message) => { Ok(message) => {
if msg_id == message.id() { if msg_id == message.id() {
debug!("received message id: {}", message.id()); debug!("received message id: {}", message.id());
return Ok(DnsResponse::from(message)) return Ok(DnsResponse::from(message));
} else { } else {
// on wrong id, attempted poison? // on wrong id, attempted poison?
warn!( warn!(
@@ -321,144 +340,144 @@ impl SingleUseUdpSocket {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
use std::net::Ipv6Addr; use std::net::Ipv6Addr;
use std::net::{IpAddr, Ipv4Addr}; use std::net::{IpAddr, Ipv4Addr};
use tokio_net::udp; use tokio_net::udp;
use crate::op::Message; use super::*;
use super::*; use crate::op::Message;
#[test]
#[test] fn test_udp_client_stream_ipv4() {
fn test_udp_client_stream_ipv4() { udp_client_stream_test(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
udp_client_stream_test(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
}
#[test]
#[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
fn test_udp_client_stream_ipv6() {
udp_client_stream_test(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)))
}
fn udp_client_stream_test(server_addr: IpAddr) {
use crate::op::Query;
use crate::rr::rdata::NULL;
use crate::rr::{Name, RData, Record, RecordType};
use std::str::FromStr;
use tokio::runtime::current_thread::Runtime;
// use env_logger;
// env_logger::try_init().ok();
let succeeded = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
let succeeded_clone = succeeded.clone();
std::thread::Builder::new()
.name("thread_killer".to_string())
.spawn(move || {
let succeeded = succeeded_clone.clone();
for _ in 0..15 {
std::thread::sleep(std::time::Duration::from_secs(1));
if succeeded.load(std::sync::atomic::Ordering::Relaxed) {
return;
}
}
panic!("timeout");
})
.unwrap();
let server = std::net::UdpSocket::bind(SocketAddr::new(server_addr, 0)).unwrap();
server
.set_read_timeout(Some(std::time::Duration::from_secs(5)))
.unwrap(); // should receive something within 5 seconds...
server
.set_write_timeout(Some(std::time::Duration::from_secs(5)))
.unwrap(); // should receive something within 5 seconds...
let server_addr = server.local_addr().unwrap();
let mut query = Message::new();
let test_name = Name::from_str("dead.beef").unwrap();
query.add_query(Query::query(test_name.clone(), RecordType::NULL));
let test_bytes: &'static [u8; 8] = b"DEADBEEF";
let send_recv_times = 4;
let test_name_server = test_name.clone();
// an in and out server
let server_handle = std::thread::Builder::new()
.name("test_udp_client_stream_ipv4:server".to_string())
.spawn(move || {
let mut buffer = [0_u8; 512];
for i in 0..send_recv_times {
// wait for some bytes...
debug!("server receiving request {}", i);
let (len, addr) = server.recv_from(&mut buffer).expect("receive failed");
debug!("server received request {} from: {}", i, addr);
let request = Message::from_vec(&buffer[0..len]).expect("failed parse of request");
assert_eq!(*request.queries()[0].name(), test_name_server.clone());
assert_eq!(request.queries()[0].query_type(), RecordType::NULL);
let mut message = Message::new();
message.set_id(request.id());
message.add_queries(request.queries().to_vec());
message.add_answer(Record::from_rdata(
test_name_server.clone(),
0,
RData::NULL(NULL::with(test_bytes.to_vec())),
));
// bounce them right back...
let bytes = message.to_vec().unwrap();
debug!("server sending response {} to: {}", i, addr);
assert_eq!(
server.send_to(&bytes, addr).expect("send failed"),
bytes.len()
);
debug!("server sent response {}", i);
std::thread::yield_now();
}
})
.unwrap();
// setup the client, which is going to run on the testing thread...
let mut io_loop = Runtime::new().unwrap();
// the tests should run within 5 seconds... right?
// TODO: add timeout here, so that test never hangs...
// let timeout = Timeout::new(Duration::from_secs(5));
let stream = UdpClientStream::with_timeout(server_addr, Duration::from_millis(500));
let mut stream: UdpClientStream<udp::UdpSocket> = io_loop.block_on(stream).ok().unwrap();
let mut worked_once = false;
for i in 0..send_recv_times {
// test once
let response_future =
stream.send_message(DnsRequest::new(query.clone(), Default::default()));
println!("client sending request {}", i);
let response = match io_loop.block_on(response_future) {
Ok(response) => response,
Err(err) => {
println!("failed to get message: {}", err);
continue;
}
};
println!("client got response {}", i);
let response = Message::from(response);
if let RData::NULL(null) = response.answers()[0].rdata() {
assert_eq!(null.anything().expect("no bytes in NULL"), test_bytes);
} else {
panic!("not a NULL response");
}
worked_once = true;
} }
succeeded.store(true, std::sync::atomic::Ordering::Relaxed); #[test]
server_handle.join().expect("server thread failed"); #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
fn test_udp_client_stream_ipv6() {
udp_client_stream_test(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)))
}
assert!(worked_once); fn udp_client_stream_test(server_addr: IpAddr) {
use crate::op::Query;
use crate::rr::rdata::NULL;
use crate::rr::{Name, RData, Record, RecordType};
use std::str::FromStr;
use tokio::runtime::current_thread::Runtime;
// use env_logger;
// env_logger::try_init().ok();
let succeeded = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
let succeeded_clone = succeeded.clone();
std::thread::Builder::new()
.name("thread_killer".to_string())
.spawn(move || {
let succeeded = succeeded_clone.clone();
for _ in 0..15 {
std::thread::sleep(std::time::Duration::from_secs(1));
if succeeded.load(std::sync::atomic::Ordering::Relaxed) {
return;
}
}
panic!("timeout");
})
.unwrap();
let server = std::net::UdpSocket::bind(SocketAddr::new(server_addr, 0)).unwrap();
server
.set_read_timeout(Some(std::time::Duration::from_secs(5)))
.unwrap(); // should receive something within 5 seconds...
server
.set_write_timeout(Some(std::time::Duration::from_secs(5)))
.unwrap(); // should receive something within 5 seconds...
let server_addr = server.local_addr().unwrap();
let mut query = Message::new();
let test_name = Name::from_str("dead.beef").unwrap();
query.add_query(Query::query(test_name.clone(), RecordType::NULL));
let test_bytes: &'static [u8; 8] = b"DEADBEEF";
let send_recv_times = 4;
let test_name_server = test_name.clone();
// an in and out server
let server_handle = std::thread::Builder::new()
.name("test_udp_client_stream_ipv4:server".to_string())
.spawn(move || {
let mut buffer = [0_u8; 512];
for i in 0..send_recv_times {
// wait for some bytes...
debug!("server receiving request {}", i);
let (len, addr) = server.recv_from(&mut buffer).expect("receive failed");
debug!("server received request {} from: {}", i, addr);
let request =
Message::from_vec(&buffer[0..len]).expect("failed parse of request");
assert_eq!(*request.queries()[0].name(), test_name_server.clone());
assert_eq!(request.queries()[0].query_type(), RecordType::NULL);
let mut message = Message::new();
message.set_id(request.id());
message.add_queries(request.queries().to_vec());
message.add_answer(Record::from_rdata(
test_name_server.clone(),
0,
RData::NULL(NULL::with(test_bytes.to_vec())),
));
// bounce them right back...
let bytes = message.to_vec().unwrap();
debug!("server sending response {} to: {}", i, addr);
assert_eq!(
server.send_to(&bytes, addr).expect("send failed"),
bytes.len()
);
debug!("server sent response {}", i);
std::thread::yield_now();
}
})
.unwrap();
// setup the client, which is going to run on the testing thread...
let mut io_loop = Runtime::new().unwrap();
// the tests should run within 5 seconds... right?
// TODO: add timeout here, so that test never hangs...
// let timeout = Timeout::new(Duration::from_secs(5));
let stream = UdpClientStream::with_timeout(server_addr, Duration::from_millis(500));
let mut stream: UdpClientStream<udp::UdpSocket> = io_loop.block_on(stream).ok().unwrap();
let mut worked_once = false;
for i in 0..send_recv_times {
// test once
let response_future =
stream.send_message(DnsRequest::new(query.clone(), Default::default()));
println!("client sending request {}", i);
let response = match io_loop.block_on(response_future) {
Ok(response) => response,
Err(err) => {
println!("failed to get message: {}", err);
continue;
}
};
println!("client got response {}", i);
let response = Message::from(response);
if let RData::NULL(null) = response.answers()[0].rdata() {
assert_eq!(null.anything().expect("no bytes in NULL"), test_bytes);
} else {
panic!("not a NULL response");
}
worked_once = true;
}
succeeded.store(true, std::sync::atomic::Ordering::Relaxed);
server_handle.join().expect("server thread failed");
assert!(worked_once);
}
} }
}

View File

@@ -127,12 +127,20 @@ impl<S: UdpSocket + Send + 'static> UdpStream<S> {
} }
impl<S: Send> UdpStream<S> { impl<S: Send> UdpStream<S> {
fn pollable_split(&mut self) -> ( fn pollable_split(
&mut Arc<Mutex<S>>, &mut self,
) -> (
&mut Arc<Mutex<S>>,
&mut Option<Pin<Box<dyn Future<Output = io::Result<usize>> + Send>>>, &mut Option<Pin<Box<dyn Future<Output = io::Result<usize>> + Send>>>,
&mut Peekable<Fuse<UnboundedReceiver<SerialMessage>>>, &mut Peekable<Fuse<UnboundedReceiver<SerialMessage>>>,
&mut Option<Pin<Box<dyn Future<Output = io::Result<SerialMessage>> + Send>>>) { &mut Option<Pin<Box<dyn Future<Output = io::Result<SerialMessage>> + Send>>>,
(&mut self.socket, &mut self.sending, &mut self.outbound_messages, &mut self.receiving) ) {
(
&mut self.socket,
&mut self.sending,
&mut self.outbound_messages,
&mut self.receiving,
)
} }
} }
@@ -154,8 +162,7 @@ impl<S: UdpSocket + Send + 'static> Stream for UdpStream<S> {
*sending = None; *sending = None;
// first try to send // first try to send
match outbound_messages.as_mut().poll_next(cx) match outbound_messages.as_mut().poll_next(cx) {
{
Poll::Ready(Some(message)) => { Poll::Ready(Some(message)) => {
let socket = Arc::clone(socket); let socket = Arc::clone(socket);
let sending_fut = async { let sending_fut = async {
@@ -188,7 +195,7 @@ impl<S: UdpSocket + Send + 'static> Stream for UdpStream<S> {
} else { } else {
None None
}; };
*receiving = None; *receiving = None;
if let Some(msg) = msg { if let Some(msg) = msg {
@@ -202,7 +209,7 @@ impl<S: UdpSocket + Send + 'static> Stream for UdpStream<S> {
let mut buf = [0u8; 2048]; let mut buf = [0u8; 2048];
let mut socket = socket.lock().await; let mut socket = socket.lock().await;
let (len, src) = socket.recv_from(&mut buf).await?; let (len, src) = socket.recv_from(&mut buf).await?;
Ok(SerialMessage::new( Ok(SerialMessage::new(
buf.iter().take(len).cloned().collect(), buf.iter().take(len).cloned().collect(),
src, src,
@@ -260,7 +267,9 @@ impl<S: UdpSocket> Future for NextRandomUdpSocket<S> {
debug!("created socket successfully"); debug!("created socket successfully");
return Poll::Ready(Ok(socket)); return Poll::Ready(Ok(socket));
} }
Poll::Ready(Err(err)) => debug!("unable to bind port, attempt: {}: {}", attempt, err), Poll::Ready(Err(err)) => {
debug!("unable to bind port, attempt: {}: {}", attempt, err)
}
Poll::Pending => debug!("unable to bind port, attempt: {}", attempt), Poll::Pending => debug!("unable to bind port, attempt: {}", attempt),
} }
} }
@@ -310,7 +319,7 @@ impl UdpSocket for udp::UdpSocket {
async fn bind(addr: &SocketAddr) -> io::Result<Self> { async fn bind(addr: &SocketAddr) -> io::Result<Self> {
udp::UdpSocket::bind(addr).await udp::UdpSocket::bind(addr).await
} }
async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from(buf).await self.recv_from(buf).await
} }
@@ -385,9 +394,11 @@ fn udp_stream_test(server_addr: IpAddr) {
std::net::SocketAddr::V6(_) => "[::1]:0", std::net::SocketAddr::V6(_) => "[::1]:0",
}; };
let socket = let socket = io_loop
io_loop.block_on(udp::UdpSocket::bind(&client_addr.to_socket_addrs().unwrap().next().unwrap())) .block_on(udp::UdpSocket::bind(
.expect("could not create socket"); // some random address... &client_addr.to_socket_addrs().unwrap().next().unwrap(),
))
.expect("could not create socket"); // some random address...
let (mut stream, sender) = UdpStream::<udp::UdpSocket>::with_bound(socket); let (mut stream, sender) = UdpStream::<udp::UdpSocket>::with_bound(socket);
//let mut stream: UdpStream = io_loop.block_on(stream).ok().unwrap(); //let mut stream: UdpStream = io_loop.block_on(stream).ok().unwrap();
@@ -398,7 +409,9 @@ fn udp_stream_test(server_addr: IpAddr) {
.unwrap(); .unwrap();
let (buffer_and_addr, stream_tmp) = io_loop.block_on(stream.into_future()); let (buffer_and_addr, stream_tmp) = io_loop.block_on(stream.into_future());
stream = stream_tmp; stream = stream_tmp;
let message = buffer_and_addr.expect("no buffer received").expect("error receiving buffer"); let message = buffer_and_addr
.expect("no buffer received")
.expect("error receiving buffer");
assert_eq!(message.bytes(), test_bytes); assert_eq!(message.bytes(), test_bytes);
assert_eq!(message.addr(), server_addr); assert_eq!(message.addr(), server_addr);
} }

View File

@@ -10,12 +10,14 @@
use std::pin::Pin; use std::pin::Pin;
use std::task::Context; use std::task::Context;
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use futures::stream::{Peekable, Stream, StreamExt}; use futures::stream::{Peekable, Stream, StreamExt};
use futures::{Future, FutureExt, Poll}; use futures::{Future, FutureExt, Poll};
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use crate::error::*; use crate::error::*;
use crate::xfer::{DnsRequest, DnsRequestSender, DnsRequestStreamHandle, DnsResponse, OneshotDnsRequest}; use crate::xfer::{
DnsRequest, DnsRequestSender, DnsRequestStreamHandle, DnsResponse, OneshotDnsRequest,
};
/// This is a generic Exchange implemented over multiplexed DNS connection providers. /// This is a generic Exchange implemented over multiplexed DNS connection providers.
/// ///
@@ -76,7 +78,12 @@ where
) )
} }
fn pollable_split(&mut self) -> (&mut S, &mut Peekable<UnboundedReceiver<OneshotDnsRequest<R>>>) { fn pollable_split(
&mut self,
) -> (
&mut S,
&mut Peekable<UnboundedReceiver<OneshotDnsRequest<R>>>,
) {
(&mut self.io_stream, &mut self.outbound_messages) (&mut self.io_stream, &mut self.outbound_messages)
} }
} }
@@ -120,8 +127,7 @@ where
} }
// then see if there is more to send // then see if there is more to send
match outbound_messages.as_mut().poll_next(cx) match outbound_messages.as_mut().poll_next(cx) {
{
// already handled above, here to make sure the poll() pops the next message // already handled above, here to make sure the poll() pops the next message
Poll::Ready(Some(dns_request)) => { Poll::Ready(Some(dns_request)) => {
// if there is no peer, this connection should die... // if there is no peer, this connection should die...
@@ -134,7 +140,7 @@ where
Err(_) => { Err(_) => {
warn!("failed to associate send_message response to the sender"); warn!("failed to associate send_message response to the sender");
return Poll::Ready(Err( return Poll::Ready(Err(
"failed to associate send_message response to the sender".into() "failed to associate send_message response to the sender".into(),
)); ));
} }
} }

View File

@@ -8,9 +8,9 @@
//! `DnsHandle` types perform conversions of the raw DNS messages before sending the messages on the specified streams. //! `DnsHandle` types perform conversions of the raw DNS messages before sending the messages on the specified streams.
use std::pin::Pin; use std::pin::Pin;
use futures::future::{Future, FutureExt, TryFutureExt};
use futures::channel::mpsc::UnboundedSender; use futures::channel::mpsc::UnboundedSender;
use futures::channel::oneshot; use futures::channel::oneshot;
use futures::future::{Future, FutureExt, TryFutureExt};
use rand; use rand;
use crate::error::*; use crate::error::*;
@@ -69,10 +69,7 @@ impl BasicDnsHandle {
impl DnsHandle for BasicDnsHandle { impl DnsHandle for BasicDnsHandle {
type Response = Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin>>; type Response = Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin>>;
fn send<R: Into<DnsRequest>>( fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
&mut self,
request: R,
) -> Self::Response {
let request = request.into(); let request = request.into();
let (complete, receiver) = oneshot::channel(); let (complete, receiver) = oneshot::channel();
let message_sender: &mut _ = &mut self.message_sender; let message_sender: &mut _ = &mut self.message_sender;
@@ -91,10 +88,11 @@ impl DnsHandle for BasicDnsHandle {
}; };
// convert the oneshot into a Box of a Future message and error. // convert the oneshot into a Box of a Future message and error.
Box::pin(receiver Box::pin(
.map_err(|c| ProtoError::from(ProtoErrorKind::Canceled(c))) receiver
.map(|r| r.and_then(|r| r))) .map_err(|c| ProtoError::from(ProtoErrorKind::Canceled(c)))
.map(|r| r.and_then(|r| r)),
)
} }
} }

View File

@@ -16,8 +16,8 @@ use std::sync::Arc;
use std::task::Context; use std::task::Context;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use futures::stream::{Stream, StreamExt};
use futures::channel::oneshot; use futures::channel::oneshot;
use futures::stream::{Stream, StreamExt};
use futures::{ready, Future, FutureExt, Poll}; use futures::{ready, Future, FutureExt, Poll};
use rand; use rand;
use rand::distributions::{Distribution, Standard}; use rand::distributions::{Distribution, Standard};
@@ -318,7 +318,8 @@ where
Poll::Pending => { Poll::Pending => {
return DnsMultiplexerSerialResponseInner::Err(Some(ProtoError::from( return DnsMultiplexerSerialResponseInner::Err(Some(ProtoError::from(
"id space exhausted, consider filing an issue", "id space exhausted, consider filing an issue",
))).into() )))
.into()
} }
}; };
@@ -505,10 +506,10 @@ impl Future for DnsMultiplexerSerialResponseInner {
// The inner type of the completion might have been an error // The inner type of the completion might have been an error
// we need to unwrap that, and translate to be the Future's error // we need to unwrap that, and translate to be the Future's error
DnsMultiplexerSerialResponseInner::Completion(ref mut complete) => { DnsMultiplexerSerialResponseInner::Completion(ref mut complete) => {
complete.poll_unpin(cx).map(|r| r complete.poll_unpin(cx).map(|r| {
.map_err(|_| ProtoError::from("the completion was canceled")) r.map_err(|_| ProtoError::from("the completion was canceled"))
.and_then(|r| r) .and_then(|r| r)
) })
} }
DnsMultiplexerSerialResponseInner::Err(ref mut err) => { DnsMultiplexerSerialResponseInner::Err(ref mut err) => {
Poll::Ready(Err(err.take().expect("cannot poll after complete"))) Poll::Ready(Err(err.take().expect("cannot poll after complete")))

View File

@@ -9,9 +9,9 @@ use std::net::SocketAddr;
use std::pin::Pin; use std::pin::Pin;
use std::task::Context; use std::task::Context;
use futures::{ready, Future, Poll, Stream}; use futures::channel::mpsc::{TrySendError, UnboundedSender};
use futures::channel::mpsc::{UnboundedSender, TrySendError};
use futures::channel::oneshot::{self, Receiver, Sender}; use futures::channel::oneshot::{self, Receiver, Sender};
use futures::{ready, Future, Poll, Stream};
use crate::error::*; use crate::error::*;
use crate::op::Message; use crate::op::Message;
@@ -156,7 +156,10 @@ pub trait DnsRequestSender:
Stream<Item = Result<(), ProtoError>> + 'static + Display + Send + Unpin Stream<Item = Result<(), ProtoError>> + 'static + Display + Send + Unpin
{ {
/// A future that resolves to a response serial message /// A future that resolves to a response serial message
type DnsResponseFuture: Future<Output = Result<DnsResponse, ProtoError>> + 'static + Send + Unpin; type DnsResponseFuture: Future<Output = Result<DnsResponse, ProtoError>>
+ 'static
+ Send
+ Unpin;
/// Send a message, and return a future of the response /// Send a message, and return a future of the response
/// ///
@@ -212,7 +215,9 @@ macro_rules! try_oneshot {
match $expr { match $expr {
Result::Ok(val) => val, Result::Ok(val) => val,
Result::Err(err) => return OneshotDnsResponseReceiver::Err(Some(ProtoError::from(err))), Result::Err(err) => {
return OneshotDnsResponseReceiver::Err(Some(ProtoError::from(err)))
}
} }
}}; }};
($expr:expr,) => { ($expr:expr,) => {
@@ -318,7 +323,7 @@ where
} }
OneshotDnsResponseReceiver::Received(ref mut future) => { OneshotDnsResponseReceiver::Received(ref mut future) => {
let future = Pin::new(future); let future = Pin::new(future);
return future.poll(cx) return future.poll(cx);
} }
OneshotDnsResponseReceiver::Err(ref mut err) => { OneshotDnsResponseReceiver::Err(ref mut err) => {
return Poll::Ready(Err(err return Poll::Ready(Err(err

View File

@@ -97,9 +97,9 @@ impl<H: DnsHandle + Unpin> Future for RetrySendFuture<H> {
mod test { mod test {
use super::*; use super::*;
use crate::error::*; use crate::error::*;
use futures::future::*;
use futures::executor::block_on;
use crate::op::*; use crate::op::*;
use futures::executor::block_on;
use futures::future::*;
use std::cell::Cell; use std::cell::Cell;
use DnsHandle; use DnsHandle;

View File

@@ -14,8 +14,8 @@ use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::Context; use std::task::Context;
use futures::{Future, FutureExt, Poll, TryFutureExt};
use futures::future::{self, SelectAll}; use futures::future::{self, SelectAll};
use futures::{Future, FutureExt, Poll, TryFutureExt};
use crate::error::*; use crate::error::*;
use crate::op::{OpCode, Query}; use crate::op::{OpCode, Query};
@@ -25,8 +25,8 @@ use crate::rr::dnssec::Verifier;
use crate::rr::dnssec::{Algorithm, SupportedAlgorithms, TrustAnchor}; use crate::rr::dnssec::{Algorithm, SupportedAlgorithms, TrustAnchor};
use crate::rr::rdata::opt::EdnsOption; use crate::rr::rdata::opt::EdnsOption;
use crate::rr::{DNSClass, Name, RData, Record, RecordType}; use crate::rr::{DNSClass, Name, RData, Record, RecordType};
use crate::xfer::{DnsRequest, DnsRequestOptions, DnsResponse};
use crate::xfer::dns_handle::DnsHandle; use crate::xfer::dns_handle::DnsHandle;
use crate::xfer::{DnsRequest, DnsRequestOptions, DnsResponse};
#[derive(Debug)] #[derive(Debug)]
struct Rrset { struct Rrset {
@@ -112,7 +112,9 @@ impl<H: DnsHandle + Unpin> DnsHandle for SecureDnsHandle<H> {
// backstop, this might need to be configurable at some point // backstop, this might need to be configurable at some point
if self.request_depth > 20 { if self.request_depth > 20 {
return Box::pin(future::err(ProtoError::from("exceeded max validation depth"))); return Box::pin(future::err(ProtoError::from(
"exceeded max validation depth",
)));
} }
// dnssec only matters on queries. // dnssec only matters on queries.
@@ -157,8 +159,8 @@ impl<H: DnsHandle + Unpin> DnsHandle for SecureDnsHandle<H> {
.first() .first()
.map_or(DNSClass::IN, Query::query_class); .map_or(DNSClass::IN, Query::query_class);
return return Box::pin(
Box::pin(self.handle self.handle
.send(request) .send(request)
.and_then(move |message_response| { .and_then(move |message_response| {
// group the record sets by name and type // group the record sets by name and type
@@ -201,7 +203,8 @@ impl<H: DnsHandle + Unpin> DnsHandle for SecureDnsHandle<H> {
} }
future::ok(verified_message) future::ok(verified_message)
})); }),
);
} }
Box::pin(self.handle.send(request)) Box::pin(self.handle.send(request))
@@ -253,7 +256,8 @@ fn verify_rrsets<H: DnsHandle + Unpin>(
return future::err(ProtoError::from(ProtoErrorKind::Message( return future::err(ProtoError::from(ProtoErrorKind::Message(
"no results to verify", "no results to verify",
))).boxed(); )))
.boxed();
} }
// collect all the rrsets to verify // collect all the rrsets to verify
@@ -314,7 +318,8 @@ fn verify_rrsets<H: DnsHandle + Unpin>(
message_result: Some(message_result), message_result: Some(message_result),
rrsets: rrsets_to_verify, rrsets: rrsets_to_verify,
verified_rrsets: HashSet::new(), verified_rrsets: HashSet::new(),
}.boxed() }
.boxed()
} }
fn is_dnssec(rr: &Record, dnssec_type: DNSSECRecordType) -> bool { fn is_dnssec(rr: &Record, dnssec_type: DNSSECRecordType) -> bool {
@@ -326,7 +331,9 @@ impl Future for VerifyRrsetsFuture {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if self.message_result.is_none() { if self.message_result.is_none() {
return Poll::Ready(Err(ProtoError::from(ProtoErrorKind::Message("message is none")))); return Poll::Ready(Err(ProtoError::from(ProtoErrorKind::Message(
"message is none",
))));
} }
// loop through all the rrset evaluations, filter all the rrsets in the Message // loop through all the rrset evaluations, filter all the rrsets in the Message
@@ -358,10 +365,14 @@ impl Future for VerifyRrsetsFuture {
if !remaining.is_empty() { if !remaining.is_empty() {
// continue the evaluation // continue the evaluation
drop(mem::replace(&mut self.as_mut().rrsets, future::select_all(remaining))); drop(mem::replace(
&mut self.as_mut().rrsets,
future::select_all(remaining),
));
} else { } else {
// validated not none above... // validated not none above...
let mut message_result = mem::replace(&mut self.as_mut().message_result, None).unwrap(); let mut message_result =
mem::replace(&mut self.as_mut().message_result, None).unwrap();
// take all the rrsets from the Message, filter down each set to the validated rrsets // take all the rrsets from the Message, filter down each set to the validated rrsets
// TODO: does the section in the message matter here? // TODO: does the section in the message matter here?
@@ -442,8 +453,7 @@ where
match rrset.record_type { match rrset.record_type {
RecordType::DNSSEC(DNSSECRecordType::DNSKEY) => verify_dnskey_rrset(handle, rrset), RecordType::DNSSEC(DNSSECRecordType::DNSKEY) => verify_dnskey_rrset(handle, rrset),
_ => future::ok(rrset).boxed(), _ => future::ok(rrset).boxed(),
} })
)
.map_err(|e| { .map_err(|e| {
debug!("rrset failed validation: {}", e); debug!("rrset failed validation: {}", e);
e e
@@ -461,7 +471,7 @@ fn verify_dnskey_rrset<H>(
rrset: Rrset, rrset: Rrset,
) -> Pin<Box<dyn Future<Output = Result<Rrset, ProtoError>> + Send>> ) -> Pin<Box<dyn Future<Output = Result<Rrset, ProtoError>> + Send>>
where where
H: DnsHandle + Unpin H: DnsHandle + Unpin,
{ {
debug!( debug!(
"dnskey validation {}, record_type: {:?}", "dnskey validation {}, record_type: {:?}",
@@ -659,40 +669,40 @@ where
// after this function. This function is only responsible for validating the signature // after this function. This function is only responsible for validating the signature
// the DNSKey validation should come after, see verify_rrset(). // the DNSKey validation should come after, see verify_rrset().
return future::ready( return future::ready(
rrsigs rrsigs
.into_iter() .into_iter()
// this filter is technically unnecessary, can probably remove it... // this filter is technically unnecessary, can probably remove it...
.filter(|rrsig| is_dnssec(rrsig, DNSSECRecordType::RRSIG)) .filter(|rrsig| is_dnssec(rrsig, DNSSECRecordType::RRSIG))
.map(|rrsig| { .map(|rrsig| {
if let RData::DNSSEC(DNSSECRData::SIG(sig)) = rrsig.unwrap_rdata() { if let RData::DNSSEC(DNSSECRData::SIG(sig)) = rrsig.unwrap_rdata() {
// setting up the context explicitly. // setting up the context explicitly.
sig sig
} else { } else {
panic!("expected a SIG here"); panic!("expected a SIG here");
} }
}) })
.filter_map(|sig| { .filter_map(|sig| {
let rrset = Arc::clone(&rrset); let rrset = Arc::clone(&rrset);
if rrset.records.iter().any(|r| { if rrset.records.iter().any(|r| {
if let RData::DNSSEC(DNSSECRData::DNSKEY(ref dnskey)) = *r.rdata() { if let RData::DNSSEC(DNSSECRData::DNSKEY(ref dnskey)) = *r.rdata() {
verify_rrset_with_dnskey(dnskey, &sig, &rrset).is_ok() verify_rrset_with_dnskey(dnskey, &sig, &rrset).is_ok()
} else {
panic!("expected a DNSKEY here: {:?}", r.rdata());
}
}) {
Some(rrset)
} else { } else {
None panic!("expected a DNSKEY here: {:?}", r.rdata());
} }
}) }) {
.next() Some(rrset)
.ok_or_else(|| { } else {
ProtoError::from(ProtoErrorKind::Message("self-signed dnskey is invalid")) None
}), }
) })
.map_ok(move |rrset| Arc::try_unwrap(rrset).expect("unable to unwrap Arc")) .next()
.boxed(); .ok_or_else(|| {
ProtoError::from(ProtoErrorKind::Message("self-signed dnskey is invalid"))
}),
)
.map_ok(move |rrset| Arc::try_unwrap(rrset).expect("unable to unwrap Arc"))
.boxed();
} }
// we can validate with any of the rrsigs... // we can validate with any of the rrsigs...
@@ -748,7 +758,8 @@ where
return future::err(ProtoError::from(ProtoErrorKind::RrsigsNotPresent { return future::err(ProtoError::from(ProtoErrorKind::RrsigsNotPresent {
name: rrset.name.clone(), name: rrset.name.clone(),
record_type: rrset.record_type, record_type: rrset.record_type,
})).boxed(); }))
.boxed();
} }
// as long as any of the verifications is good, then the RRSET is valid. // as long as any of the verifications is good, then the RRSET is valid.

View File

@@ -88,7 +88,10 @@ lazy_static! {
/// ///
/// This looks up the `host` (a `&str` or `String` is good), and combines that with the provided port /// This looks up the `host` (a `&str` or `String` is good), and combines that with the provided port
/// this mimics the lookup functions of `std::net`. /// this mimics the lookup functions of `std::net`.
pub fn resolve<N: IntoName + TryParseIp>(host: N, port: u16) -> impl Future<Output = io::Result<Vec<SocketAddr>>> { pub fn resolve<N: IntoName + TryParseIp>(
host: N,
port: u16,
) -> impl Future<Output = io::Result<Vec<SocketAddr>>> {
// Now we use the global resolver to perform a lookup_ip. // Now we use the global resolver to perform a lookup_ip.
let resolve_future = GLOBAL_DNS_RESOLVER.lookup_ip(host).map(move |result| { let resolve_future = GLOBAL_DNS_RESOLVER.lookup_ip(host).map(move |result| {
// map the result into what we want... // map the result into what we want...
@@ -99,7 +102,8 @@ pub fn resolve<N: IntoName + TryParseIp>(host: N, port: u16) -> impl Future<Outp
io::ErrorKind::AddrNotAvailable, io::ErrorKind::AddrNotAvailable,
format!("dns resolution error: {}", err), format!("dns resolution error: {}", err),
) )
}).map(move |lookup_ip| { })
.map(move |lookup_ip| {
// we take all the IPs returned, and then send back the set of IPs // we take all the IPs returned, and then send back the set of IPs
lookup_ip lookup_ip
.iter() .iter()
@@ -129,7 +133,8 @@ fn main() {
}); });
(name, join) (name, join)
}).collect::<Vec<_>>(); })
.collect::<Vec<_>>();
// print the resolved IPs // print the resolved IPs
for (name, join) in threads { for (name, join) in threads {

View File

@@ -49,7 +49,8 @@ fn main() {
// Go through the list of resolution operations and wait for them to complete. // Go through the list of resolution operations and wait for them to complete.
for (name, lookup) in futures.drain(..) { for (name, lookup) in futures.drain(..) {
let ips = runtime.block_on(lookup) let ips = runtime
.block_on(lookup)
.expect("Failed completing lookup future") .expect("Failed completing lookup future")
.iter() .iter()
.collect::<Vec<_>>(); .collect::<Vec<_>>();

View File

@@ -5,17 +5,17 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::sync::Arc;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::Context; use std::task::Context;
use futures::{future, ready, channel::mpsc, Future, FutureExt, Poll, StreamExt};
use futures::lock::Mutex; use futures::lock::Mutex;
use futures::{channel::mpsc, future, ready, Future, FutureExt, Poll, StreamExt};
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use proto::SecureDnsHandle; use proto::SecureDnsHandle;
use proto::{ use proto::{
error::ProtoResult, error::ProtoResult,
rr::{Name, RData, RecordType, Record}, rr::{Name, RData, Record, RecordType},
xfer::{DnsRequestOptions, RetryDnsHandle}, xfer::{DnsRequestOptions, RetryDnsHandle},
}; };
@@ -82,7 +82,8 @@ pub(super) fn task(
hosts, hosts,
request_rx, request_rx,
} }
}).flatten() })
.flatten()
} }
type ClientCache = CachingClient<LookupEither<ConnectionHandle, StandardConnection>>; type ClientCache = CachingClient<LookupEither<ConnectionHandle, StandardConnection>>;

View File

@@ -13,10 +13,11 @@ use std::sync::Arc;
use std::task::Context; use std::task::Context;
use futures::{ use futures::{
self, future, self,
channel::{mpsc, oneshot}, channel::{mpsc, oneshot},
Future, FutureExt, Poll, TryFutureExt, future,
lock::Mutex, lock::Mutex,
Future, FutureExt, Poll, TryFutureExt,
}; };
use proto::error::ProtoResult; use proto::error::ProtoResult;
use proto::rr::domain::TryParseIp; use proto::rr::domain::TryParseIp;
@@ -498,8 +499,7 @@ mod tests {
thread::spawn(move || { thread::spawn(move || {
let mut background_runtime = Runtime::new().unwrap(); let mut background_runtime = Runtime::new().unwrap();
background_runtime background_runtime.block_on(bg);
.block_on(bg);
}); });
let response = io_loop let response = io_loop

View File

@@ -497,7 +497,11 @@ impl Default for LookupIpStrategy {
/// Configuration for the Resolver /// Configuration for the Resolver
#[derive(Debug, Clone, Copy, Eq, PartialEq)] #[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[cfg_attr(feature = "serde-config", derive(Serialize, Deserialize), serde(default))] #[cfg_attr(
feature = "serde-config",
derive(Serialize, Deserialize),
serde(default)
)]
#[allow(dead_code)] // TODO: remove after all params are supported #[allow(dead_code)] // TODO: remove after all params are supported
pub struct ResolverOpts { pub struct ResolverOpts {
/// Sets the number of dots that must appear (unless it's a final dot representing the root) /// Sets the number of dots that must appear (unless it's a final dot representing the root)

View File

@@ -10,8 +10,8 @@
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::HashMap; use std::collections::HashMap;
use std::task::{Context, Poll};
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll};
use futures::Future; use futures::Future;
@@ -21,7 +21,9 @@ use proto::xfer::DnsRequestOptions;
use crate::async_resolver::{AsyncResolver, BackgroundLookup}; use crate::async_resolver::{AsyncResolver, BackgroundLookup};
use crate::error::*; use crate::error::*;
use crate::lookup::{ReverseLookup, ReverseLookupFuture, ReverseLookupIter, TxtLookup, TxtLookupFuture}; use crate::lookup::{
ReverseLookup, ReverseLookupFuture, ReverseLookupIter, TxtLookup, TxtLookupFuture,
};
/// An extension for the Resolver to perform DNS Service Discovery /// An extension for the Resolver to perform DNS Service Discovery
pub trait DnsSdHandle { pub trait DnsSdHandle {
@@ -143,7 +145,6 @@ mod tests {
use crate::config::*; use crate::config::*;
use super::*; use super::*;
#[test] #[test]

View File

@@ -8,9 +8,9 @@
//! Error types for the crate //! Error types for the crate
use failure::{Backtrace, Context, Fail}; use failure::{Backtrace, Context, Fail};
use std::{fmt, io, sync, time::Instant};
use proto::error::{ProtoError, ProtoErrorKind}; use proto::error::{ProtoError, ProtoErrorKind};
use proto::op::Query; use proto::op::Query;
use std::{fmt, io, sync, time::Instant};
/// An alias for results returned by functions of this crate /// An alias for results returned by functions of this crate
pub type ResolveResult<T> = ::std::result::Result<T, ResolveError>; pub type ResolveResult<T> = ::std::result::Result<T, ResolveError>;
@@ -33,7 +33,7 @@ pub enum ResolveErrorKind {
query: Query, query: Query,
/// A deadline after which the `NXDOMAIN` response is no longer /// A deadline after which the `NXDOMAIN` response is no longer
/// valid, and the nameserver should be queried again. /// valid, and the nameserver should be queried again.
valid_until: Option<Instant> valid_until: Option<Instant>,
}, },
// foreign // foreign
@@ -56,7 +56,10 @@ impl Clone for ResolveErrorKind {
match *self { match *self {
Message(msg) => Message(msg), Message(msg) => Message(msg),
Msg(ref msg) => Msg(msg.clone()), Msg(ref msg) => Msg(msg.clone()),
NoRecordsFound { ref query, valid_until } => NoRecordsFound { NoRecordsFound {
ref query,
valid_until,
} => NoRecordsFound {
query: query.clone(), query: query.clone(),
valid_until, valid_until,
}, },

View File

@@ -10,14 +10,23 @@ use futures::Future;
use proto::error::ProtoError; use proto::error::ProtoError;
use proto::xfer::{BufDnsRequestStreamHandle, DnsExchange}; use proto::xfer::{BufDnsRequestStreamHandle, DnsExchange};
use trust_dns_https::{HttpsClientStream, HttpsClientStreamBuilder, HttpsClientResponse}; use trust_dns_https::{HttpsClientResponse, HttpsClientStream, HttpsClientStreamBuilder};
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub(crate) fn new_https_stream( pub(crate) fn new_https_stream(
socket_addr: SocketAddr, socket_addr: SocketAddr,
dns_name: String, dns_name: String,
) -> ( ) -> (
Pin<Box<dyn Future<Output = Result<DnsExchange<HttpsClientStream, HttpsClientResponse>, ProtoError>> + Send>>, Pin<
Box<
dyn Future<
Output = Result<
DnsExchange<HttpsClientStream, HttpsClientResponse>,
ProtoError,
>,
> + Send,
>,
>,
BufDnsRequestStreamHandle<HttpsClientResponse>, BufDnsRequestStreamHandle<HttpsClientResponse>,
) { ) {
// using the mozilla default root store // using the mozilla default root store

View File

@@ -9,12 +9,12 @@
use std::cmp::min; use std::cmp::min;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
use std::pin::Pin;
use std::slice::Iter; use std::slice::Iter;
use std::sync::Arc; use std::sync::Arc;
use std::task::Context;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::vec::IntoIter; use std::vec::IntoIter;
use std::pin::Pin;
use std::task::Context;
use futures::{future, Future, FutureExt, Poll}; use futures::{future, Future, FutureExt, Poll};
@@ -31,7 +31,9 @@ use crate::dns_lru::MAX_TTL;
use crate::error::*; use crate::error::*;
use crate::lookup_ip::LookupIpIter; use crate::lookup_ip::LookupIpIter;
use crate::lookup_state::CachingClient; use crate::lookup_state::CachingClient;
use crate::name_server::{ConnectionHandle, ConnectionProvider, NameServerPool, StandardConnection}; use crate::name_server::{
ConnectionHandle, ConnectionProvider, NameServerPool, StandardConnection,
};
/// Result of a DNS query when querying for any record type supported by the Trust-DNS Proto library. /// Result of a DNS query when querying for any record type supported by the Trust-DNS Proto library.
/// ///
@@ -242,9 +244,9 @@ impl<C: DnsHandle + 'static> LookupFuture<C> {
}); });
let query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> = match name { let query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> = match name {
Ok(name) => { Ok(name) => client_cache
client_cache.lookup(Query::query(name, record_type), options.clone()).boxed() .lookup(Query::query(name, record_type), options.clone())
} .boxed(),
Err(err) => future::err(err).boxed(), Err(err) => future::err(err).boxed(),
}; };
@@ -537,8 +539,8 @@ pub mod tests {
use std::str::FromStr; use std::str::FromStr;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use futures::{future, Future};
use futures::executor::block_on; use futures::executor::block_on;
use futures::{future, Future};
use proto::error::{ProtoErrorKind, ProtoResult}; use proto::error::{ProtoErrorKind, ProtoResult};
use proto::op::Message; use proto::op::Message;
@@ -556,9 +558,7 @@ pub mod tests {
type Response = Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>; type Response = Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>;
fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response { fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
future::ready( future::ready(self.messages.lock().unwrap().pop().unwrap_or_else(empty)).boxed()
self.messages.lock().unwrap().pop().unwrap_or_else(empty),
).boxed()
} }
} }

View File

@@ -10,10 +10,10 @@
//! At it's heart LookupIp uses Lookup for performing all lookups. It is unlike other standard lookups in that there are customizations around A and AAAA resolutions. //! At it's heart LookupIp uses Lookup for performing all lookups. It is unlike other standard lookups in that there are customizations around A and AAAA resolutions.
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc;
use std::time::Instant;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::Context; use std::task::Context;
use std::time::Instant;
use failure::Fail; use failure::Fail;
use futures::{future, future::Either, Future, FutureExt, Poll}; use futures::{future, future::Either, Future, FutureExt, Poll};
@@ -216,9 +216,7 @@ where
names: vec![], names: vec![],
strategy: LookupIpStrategy::default(), strategy: LookupIpStrategy::default(),
options: DnsRequestOptions::default(), options: DnsRequestOptions::default(),
query: future::err( query: future::err(ResolveErrorKind::Msg(format!("{}", error)).into()).boxed(),
ResolveErrorKind::Msg(format!("{}", error)).into(),
).boxed(),
hosts: None, hosts: None,
finally_ip_addr: None, finally_ip_addr: None,
} }
@@ -303,42 +301,46 @@ fn ipv4_and_ipv6<C: DnsHandle + 'static>(
client.clone(), client.clone(),
options.clone(), options.clone(),
hosts.clone(), hosts.clone(),
), ),
hosts_lookup( hosts_lookup(Query::query(name, RecordType::AAAA), client, options, hosts),
Query::query(name, RecordType::AAAA),
client,
options,
hosts,
)
) )
.then(|sel_res| { .then(|sel_res| {
let (ips, remaining_query) = match sel_res { let (ips, remaining_query) = match sel_res {
Either::Left(ips_and_remaining) => ips_and_remaining, Either::Left(ips_and_remaining) => ips_and_remaining,
Either::Right(ips_and_remaining) => ips_and_remaining, Either::Right(ips_and_remaining) => ips_and_remaining,
}; };
// Some ips returned, get the other record result, or else just return record // Some ips returned, get the other record result, or else just return record
// One failed, just return the other // One failed, just return the other
match ips { match ips {
Ok(ips) => remaining_query.then(move |remaining_ips| match remaining_ips { Ok(ips) => remaining_query
.then(move |remaining_ips| match remaining_ips {
// join AAAA and A results // join AAAA and A results
Ok(rem_ips) => { Ok(rem_ips) => {
// TODO: create a LookupIp enum with the ability to chain these together // TODO: create a LookupIp enum with the ability to chain these together
let ips = ips.append(rem_ips); let ips = ips.append(rem_ips);
future::ok(ips) future::ok(ips)
}, }
// One failed, just return the other // One failed, just return the other
Err(e) => { Err(e) => {
debug!("one of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy: {}", e); debug!(
"one of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy: {}",
e
);
future::ok(ips) future::ok(ips)
}, }
}).boxed(), })
Err(e) => { .boxed(),
debug!("one of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy: {}", e); Err(e) => {
remaining_query.boxed() debug!(
} "one of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy: {}",
e
);
remaining_query.boxed()
} }
}).boxed() }
})
.boxed()
} }
/// queries only for AAAA and on no results queries for A /// queries only for AAAA and on no results queries for A
@@ -386,34 +388,37 @@ fn rt_then_swap<C: DnsHandle + 'static>(
) -> Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> { ) -> Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> {
let or_client = client.clone(); let or_client = client.clone();
hosts_lookup( hosts_lookup(
Query::query(name.clone(), first_type), Query::query(name.clone(), first_type),
client, client,
options.clone(), options.clone(),
hosts.clone(), hosts.clone(),
) )
.then(move |res| { .then(move |res| {
match res { match res {
Ok(ips) => { Ok(ips) => {
if ips.is_empty() { if ips.is_empty() {
// no ips returns, NXDomain or Otherwise, doesn't matter // no ips returns, NXDomain or Otherwise, doesn't matter
hosts_lookup( hosts_lookup(
Query::query(name.clone(), second_type), Query::query(name.clone(), second_type),
or_client, or_client,
options, options,
hosts, hosts,
).boxed() )
} else { .boxed()
future::ok(ips).boxed() } else {
} future::ok(ips).boxed()
} }
Err(_) => hosts_lookup(
Query::query(name.clone(), second_type),
or_client,
options,
hosts,
).boxed(),
} }
}).boxed() Err(_) => hosts_lookup(
Query::query(name.clone(), second_type),
or_client,
options,
hosts,
)
.boxed(),
}
})
.boxed()
} }
#[cfg(test)] #[cfg(test)]
@@ -421,8 +426,8 @@ pub mod tests {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use futures::{future, Future};
use futures::executor::block_on; use futures::executor::block_on;
use futures::{future, Future};
use proto::error::{ProtoError, ProtoResult}; use proto::error::{ProtoError, ProtoResult};
use proto::op::Message; use proto::op::Message;
@@ -437,7 +442,8 @@ pub mod tests {
} }
impl DnsHandle for MockDnsHandle { impl DnsHandle for MockDnsHandle {
type Response = Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin>>; type Response =
Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin>>;
fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response { fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
Box::pin(future::ready( Box::pin(future::ready(

View File

@@ -39,13 +39,13 @@ lazy_static! {
} }
struct DepthTracker { struct DepthTracker {
query_depth: Arc<AtomicU8> query_depth: Arc<AtomicU8>,
} }
impl DepthTracker { impl DepthTracker {
fn track(query_depth: Arc<AtomicU8>) -> Self { fn track(query_depth: Arc<AtomicU8>) -> Self {
query_depth.fetch_add(1, Ordering::Release); query_depth.fetch_add(1, Ordering::Release);
Self{ query_depth } Self { query_depth }
} }
} }

View File

@@ -6,10 +6,10 @@
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::net::SocketAddr; use std::net::SocketAddr;
use std::task::Context;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::pin::Pin; use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::Context;
use std::time::Duration;
use futures::{Future, FutureExt, Poll, TryFutureExt}; use futures::{Future, FutureExt, Poll, TryFutureExt};
use tokio_executor::{DefaultExecutor, Executor}; use tokio_executor::{DefaultExecutor, Executor};
@@ -133,9 +133,12 @@ impl ConnectionHandleConnect {
let stream = UdpClientStream::<TokioUdpSocket>::with_timeout(socket_addr, timeout); let stream = UdpClientStream::<TokioUdpSocket>::with_timeout(socket_addr, timeout);
let (stream, handle) = DnsExchange::connect(stream); let (stream, handle) = DnsExchange::connect(stream);
let stream = stream.and_then(|stream| stream).map_err(|e| { let stream = stream
debug!("udp connection shutting down: {}", e); .and_then(|stream| stream)
}).map(|_| ()); .map_err(|e| {
debug!("udp connection shutting down: {}", e);
})
.map(|_| ());
let handle = BufDnsRequestStreamHandle::new(handle); let handle = BufDnsRequestStreamHandle::new(handle);
DefaultExecutor::current().spawn(stream.boxed())?; DefaultExecutor::current().spawn(stream.boxed())?;
@@ -156,9 +159,12 @@ impl ConnectionHandleConnect {
); );
let (stream, handle) = DnsExchange::connect(dns_conn); let (stream, handle) = DnsExchange::connect(dns_conn);
let stream = stream.and_then(|stream| stream).map_err(|e| { let stream = stream
debug!("tcp connection shutting down: {}", e); .and_then(|stream| stream)
}).map(|_|()); .map_err(|e| {
debug!("tcp connection shutting down: {}", e);
})
.map(|_| ());
let handle = BufDnsRequestStreamHandle::new(handle); let handle = BufDnsRequestStreamHandle::new(handle);
DefaultExecutor::current().spawn(stream.boxed())?; DefaultExecutor::current().spawn(stream.boxed())?;
@@ -179,9 +185,12 @@ impl ConnectionHandleConnect {
); );
let (stream, handle) = DnsExchange::connect(dns_conn); let (stream, handle) = DnsExchange::connect(dns_conn);
let stream = stream.and_then(|stream| stream).map_err(|e| { let stream = stream
debug!("tls connection shutting down: {}", e); .and_then(|stream| stream)
}).map(|_| ()); .map_err(|e| {
debug!("tls connection shutting down: {}", e);
})
.map(|_| ());
let handle = BufDnsRequestStreamHandle::new(handle); let handle = BufDnsRequestStreamHandle::new(handle);
DefaultExecutor::current().spawn(Box::pin(stream))?; DefaultExecutor::current().spawn(Box::pin(stream))?;
@@ -196,9 +205,12 @@ impl ConnectionHandleConnect {
} => { } => {
let (stream, handle) = crate::https::new_https_stream(socket_addr, tls_dns_name); let (stream, handle) = crate::https::new_https_stream(socket_addr, tls_dns_name);
let stream = stream.and_then(|stream| stream).map_err(|e| { let stream = stream
debug!("https connection shutting down: {}", e); .and_then(|stream| stream)
}).map(|_| ()); .map_err(|e| {
debug!("https connection shutting down: {}", e);
})
.map(|_| ());
DefaultExecutor::current().spawn(Box::pin(stream))?; DefaultExecutor::current().spawn(Box::pin(stream))?;
Ok(ConnectionHandleConnected::Https(handle)) Ok(ConnectionHandleConnected::Https(handle))
@@ -219,9 +231,12 @@ impl ConnectionHandleConnect {
); );
let (stream, handle) = DnsExchange::connect(dns_conn); let (stream, handle) = DnsExchange::connect(dns_conn);
let stream = stream.and_then(|stream| stream).map_err(|e| { let stream = stream
debug!("mdns connection shutting down: {}", e); .and_then(|stream| stream)
}).map(|_| ()); .map_err(|e| {
debug!("mdns connection shutting down: {}", e);
})
.map(|_| ());
let handle = BufDnsRequestStreamHandle::new(handle); let handle = BufDnsRequestStreamHandle::new(handle);
DefaultExecutor::current().spawn(Box::pin(stream))?; DefaultExecutor::current().spawn(Box::pin(stream))?;
@@ -243,7 +258,10 @@ enum ConnectionHandleConnected {
impl DnsHandle for ConnectionHandleConnected { impl DnsHandle for ConnectionHandleConnected {
type Response = ConnectionHandleResponseInner; type Response = ConnectionHandleResponseInner;
fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> ConnectionHandleResponseInner { fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(
&mut self,
request: R,
) -> ConnectionHandleResponseInner {
match self { match self {
ConnectionHandleConnected::Udp(ref mut conn) => { ConnectionHandleConnected::Udp(ref mut conn) => {
ConnectionHandleResponseInner::Udp(conn.send(request)) ConnectionHandleResponseInner::Udp(conn.send(request))
@@ -266,7 +284,10 @@ enum ConnectionHandleInner {
} }
impl ConnectionHandleInner { impl ConnectionHandleInner {
fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> ConnectionHandleResponseInner { fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(
&mut self,
request: R,
) -> ConnectionHandleResponseInner {
loop { loop {
let connected: Result<ConnectionHandleConnected, proto::error::ProtoError> = match self let connected: Result<ConnectionHandleConnected, proto::error::ProtoError> = match self
{ {
@@ -337,7 +358,9 @@ impl Future for ConnectionHandleResponseInner {
#[cfg(feature = "dns-over-https")] #[cfg(feature = "dns-over-https")]
Https(ref mut https) => return https.poll_unpin(cx), Https(ref mut https) => return https.poll_unpin(cx),
ProtoError(ref mut e) => { ProtoError(ref mut e) => {
return Poll::Ready(Err(e.take().expect("futures cannot be polled once complete"))); return Poll::Ready(Err(e
.take()
.expect("futures cannot be polled once complete")));
} }
}; };

View File

@@ -6,17 +6,17 @@
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
mod connection_provider; mod connection_provider;
#[allow(clippy::module_inception)]
mod name_server;
mod name_server_pool; mod name_server_pool;
mod name_server_state; mod name_server_state;
mod name_server_stats; mod name_server_stats;
#[allow(clippy::module_inception)]
mod name_server;
pub use self::connection_provider::ConnectionProvider;
pub(crate) use self::connection_provider::{ConnectionHandle, StandardConnection};
#[cfg(feature = "mdns")]
pub(crate) use self::name_server::mdns_nameserver;
pub use self::name_server::NameServer;
pub use self::name_server_pool::NameServerPool;
use self::name_server_state::NameServerState; use self::name_server_state::NameServerState;
use self::name_server_stats::NameServerStats; use self::name_server_stats::NameServerStats;
pub use self::name_server_pool::NameServerPool;
pub use self::connection_provider::ConnectionProvider;
pub(crate) use self::connection_provider::{StandardConnection, ConnectionHandle};
pub use self::name_server::NameServer;
#[cfg(feature = "mdns")]
pub(crate) use self::name_server::mdns_nameserver;

View File

@@ -7,9 +7,9 @@
use std::cmp::Ordering; use std::cmp::Ordering;
use std::fmt::{self, Debug, Formatter}; use std::fmt::{self, Debug, Formatter};
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use std::pin::Pin;
use futures::{future, Future, TryFutureExt}; use futures::{future, Future, TryFutureExt};
@@ -136,7 +136,8 @@ where
}; };
// Because a Poisoned lock error could have occurred, make sure to create a new Mutex... // Because a Poisoned lock error could have occurred, make sure to create a new Mutex...
Box::pin(client Box::pin(
client
.send(request) .send(request)
.and_then(move |response| { .and_then(move |response| {
// first we'll evaluate if the message succeeded // first we'll evaluate if the message succeeded
@@ -176,9 +177,8 @@ where
// These are connection failures, not lookup failures, that is handled in the resolver layer // These are connection failures, not lookup failures, that is handled in the resolver layer
future::err(error) future::err(error)
}) }),
) )
} }
} }
@@ -194,10 +194,7 @@ impl<C: DnsHandle, P: ConnectionProvider<ConnHandle = C>> Ord for NameServer<C,
// this will prefer established connections, we should try other connections after // this will prefer established connections, we should try other connections after
// some number to make sure that all are used. This is more important for when // some number to make sure that all are used. This is more important for when
// latency is started to be used. // latency is started to be used.
match self match self.state.cmp(&other.state) {
.state
.cmp(&other.state)
{
Ordering::Equal => (), Ordering::Equal => (),
o => { o => {
return o; return o;
@@ -266,10 +263,7 @@ mod tests {
}; };
let mut io_loop = Runtime::new().unwrap(); let mut io_loop = Runtime::new().unwrap();
let name_server = future::lazy(|_| { let name_server = future::lazy(|_| {
NameServer::<_, StandardConnection>::new( NameServer::<_, StandardConnection>::new(config, ResolverOpts::default())
config,
ResolverOpts::default(),
)
}); });
let name = Name::parse("www.example.com.", None).unwrap(); let name = Name::parse("www.example.com.", None).unwrap();

View File

@@ -166,19 +166,21 @@ where
let request = mdns.take_request(); let request = mdns.take_request();
// First try the UDP connections // First try the UDP connections
Box::pin(Self::try_send(opts, datagram_conns, request) Box::pin(
.and_then(move |response| { Self::try_send(opts, datagram_conns, request)
// handling promotion from datagram to stream base on truncation in message .and_then(move |response| {
if ResponseCode::NoError == response.response_code() && response.truncated() { // handling promotion from datagram to stream base on truncation in message
// TCP connections should not truncate if ResponseCode::NoError == response.response_code() && response.truncated() {
future::Either::Left(Self::try_send(opts, stream_conns1, tcp_message1)) // TCP connections should not truncate
} else { future::Either::Left(Self::try_send(opts, stream_conns1, tcp_message1))
// Return the result from the UDP connection } else {
future::Either::Right(future::ok(response)) // Return the result from the UDP connection
} future::Either::Right(future::ok(response))
}) }
// if UDP fails, try TCP })
.or_else(move |_| Self::try_send(opts, stream_conns2, tcp_message2))) // if UDP fails, try TCP
.or_else(move |_| Self::try_send(opts, stream_conns2, tcp_message2)),
)
} }
} }
@@ -345,7 +347,9 @@ impl Local {
/// # Panics /// # Panics
/// ///
/// Panics if this is in fact a Local::NotMdns /// Panics if this is in fact a Local::NotMdns
fn take_future(self) -> Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin>> { fn take_future(
self,
) -> Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin>> {
match self { match self {
Local::ResolveFuture(future) => future, Local::ResolveFuture(future) => future,
_ => panic!("non Local queries have no future, see take_message()"), _ => panic!("non Local queries have no future, see take_message()"),

View File

@@ -42,10 +42,10 @@ impl NameServerStateInner {
impl NameServerState { impl NameServerState {
/// Set at the new Init state /// Set at the new Init state
/// ///
/// If send_dns is some, this will be sent on the first request when it is established /// If send_dns is some, this will be sent on the first request when it is established
pub fn init(send_edns: Option<Edns>) -> Self { pub fn init(send_edns: Option<Edns>) -> Self {
NameServerState(RwLock::new(NameServerStateInner::Init{ send_edns })) NameServerState(RwLock::new(NameServerStateInner::Init { send_edns }))
} }
/// Transition to the Established state /// Transition to the Established state
@@ -54,15 +54,15 @@ impl NameServerState {
/// the remote's support. /// the remote's support.
pub fn establish(&self, remote_edns: Option<Edns>) { pub fn establish(&self, remote_edns: Option<Edns>) {
let mut state = self.0.write().expect("poisoned lock"); let mut state = self.0.write().expect("poisoned lock");
*state = NameServerStateInner::Established{ remote_edns }; *state = NameServerStateInner::Established { remote_edns };
} }
/// transition to the Failed state /// transition to the Failed state
/// ///
/// when is the time of the failure /// when is the time of the failure
pub fn fail(&self, when: Instant) { pub fn fail(&self, when: Instant) {
let mut state = self.0.write().expect("poisoned lock"); let mut state = self.0.write().expect("poisoned lock");
*state = NameServerStateInner::Failed{ when }; *state = NameServerStateInner::Failed { when };
} }
/// True if this is in the Failed state /// True if this is in the Failed state
@@ -127,7 +127,8 @@ impl PartialOrd for NameServerState {
impl PartialEq for NameServerState { impl PartialEq for NameServerState {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.0.read().expect("self poisoned").to_usize() == other.0.read().expect("self poisoned").to_usize() self.0.read().expect("self poisoned").to_usize()
== other.0.read().expect("self poisoned").to_usize()
} }
} }
@@ -142,9 +143,13 @@ mod tests {
fn test_state_cmp() { fn test_state_cmp() {
let init = NameServerState::init(None); let init = NameServerState::init(None);
let established = NameServerState(RwLock::new(NameServerStateInner::Established { remote_edns: None })); let established = NameServerState(RwLock::new(NameServerStateInner::Established {
remote_edns: None,
}));
let failed = NameServerState(RwLock::new(NameServerStateInner::Failed { when: Instant::now() })); let failed = NameServerState(RwLock::new(NameServerStateInner::Failed {
when: Instant::now(),
}));
assert_eq!(init.cmp(&init), Ordering::Equal); assert_eq!(init.cmp(&init), Ordering::Equal);
assert_eq!(init.cmp(&established), Ordering::Less); assert_eq!(init.cmp(&established), Ordering::Less);

View File

@@ -31,13 +31,18 @@ impl NameServerStats {
pub fn next_success(&self) { pub fn next_success(&self) {
self.successes.fetch_add(1, atomic::Ordering::Release); self.successes.fetch_add(1, atomic::Ordering::Release);
} }
pub fn next_failure(&self) { pub fn next_failure(&self) {
self.failures.fetch_add(1, atomic::Ordering::Release); self.failures.fetch_add(1, atomic::Ordering::Release);
} }
fn noload_eq(self_successes: usize, other_successes: usize, self_failures: usize, other_failures: usize) -> bool { fn noload_eq(
self_successes: usize,
other_successes: usize,
self_failures: usize,
other_failures: usize,
) -> bool {
self_successes == other_successes && self_failures == other_failures self_successes == other_successes && self_failures == other_failures
} }
} }
@@ -51,13 +56,17 @@ impl PartialEq for NameServerStats {
let other_failures = other.failures.load(atomic::Ordering::Acquire); let other_failures = other.failures.load(atomic::Ordering::Acquire);
// if they are literally equal, just return // if they are literally equal, just return
Self::noload_eq(self_successes, other_successes, self_failures, other_failures) Self::noload_eq(
self_successes,
other_successes,
self_failures,
other_failures,
)
} }
} }
impl Eq for NameServerStats {} impl Eq for NameServerStats {}
impl Ord for NameServerStats { impl Ord for NameServerStats {
/// Custom implementation of Ord for NameServer which incorporates the performance of the connection into it's ranking /// Custom implementation of Ord for NameServer which incorporates the performance of the connection into it's ranking
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
@@ -68,7 +77,12 @@ impl Ord for NameServerStats {
let other_failures = other.failures.load(atomic::Ordering::Acquire); let other_failures = other.failures.load(atomic::Ordering::Acquire);
// if they are literally equal, just return // if they are literally equal, just return
if Self::noload_eq(self_successes, other_successes, self_failures, other_failures) { if Self::noload_eq(
self_successes,
other_successes,
self_failures,
other_failures,
) {
return Ordering::Equal; return Ordering::Equal;
} }
@@ -107,9 +121,9 @@ mod tests {
#[test] #[test]
fn test_state_cmp() { fn test_state_cmp() {
let nil = NameServerStats::new(0,0); let nil = NameServerStats::new(0, 0);
let successes = NameServerStats::new(1,0); let successes = NameServerStats::new(1, 0);
let failures = NameServerStats::new(0,1); let failures = NameServerStats::new(0, 1);
assert_eq!(nil.cmp(&nil), Ordering::Equal); assert_eq!(nil.cmp(&nil), Ordering::Equal);
assert_eq!(nil.cmp(&successes), Ordering::Greater); assert_eq!(nil.cmp(&successes), Ordering::Greater);

View File

@@ -126,7 +126,7 @@ mod tests {
fn tests_dir() -> String { fn tests_dir() -> String {
let server_path = env::var("TDNS_SERVER_SRC_ROOT").unwrap_or_else(|_| ".".to_owned()); let server_path = env::var("TDNS_SERVER_SRC_ROOT").unwrap_or_else(|_| ".".to_owned());
format!{"{}/../resolver/tests", server_path} format! {"{}/../resolver/tests", server_path}
} }
#[test] #[test]

View File

@@ -13,9 +13,9 @@ use std::pin::Pin;
use futures::Future; use futures::Future;
use trust_dns_native_tls::{TlsClientStream, TlsClientStreamBuilder};
use proto::error::ProtoError; use proto::error::ProtoError;
use proto::BufDnsStreamHandle; use proto::BufDnsStreamHandle;
use trust_dns_native_tls::{TlsClientStream, TlsClientStreamBuilder};
pub(crate) fn new_tls_stream( pub(crate) fn new_tls_stream(
socket_addr: SocketAddr, socket_addr: SocketAddr,

View File

@@ -13,9 +13,9 @@ use std::pin::Pin;
use futures::Future; use futures::Future;
use trust_dns_openssl::{TlsClientStream, TlsClientStreamBuilder};
use proto::error::ProtoError; use proto::error::ProtoError;
use proto::BufDnsStreamHandle; use proto::BufDnsStreamHandle;
use trust_dns_openssl::{TlsClientStream, TlsClientStreamBuilder};
pub(crate) fn new_tls_stream( pub(crate) fn new_tls_stream(
socket_addr: SocketAddr, socket_addr: SocketAddr,

View File

@@ -12,8 +12,8 @@ extern crate rustls;
extern crate webpki_roots; extern crate webpki_roots;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use self::rustls::{ClientConfig, ProtocolVersion, RootCertStore}; use self::rustls::{ClientConfig, ProtocolVersion, RootCertStore};
use futures::Future; use futures::Future;

View File

@@ -71,5 +71,4 @@ mod tests {
fn test_quad9_tls() { fn test_quad9_tls() {
tls_test(ResolverConfig::quad9_tls()) tls_test(ResolverConfig::quad9_tls())
} }
} }

View File

@@ -20,10 +20,10 @@ extern crate futures;
extern crate rustls; extern crate rustls;
#[cfg(test)] #[cfg(test)]
extern crate tokio; extern crate tokio;
extern crate trust_dns_proto;
extern crate tokio_io; extern crate tokio_io;
extern crate tokio_rustls;
extern crate tokio_net; extern crate tokio_net;
extern crate tokio_rustls;
extern crate trust_dns_proto;
extern crate webpki; extern crate webpki;
#[macro_use] #[macro_use]
extern crate log; extern crate log;

View File

@@ -84,7 +84,8 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
} }
panic!("timeout"); panic!("timeout");
}).unwrap(); })
.unwrap();
let server_path = env::var("TDNS_SERVER_SRC_ROOT").unwrap_or_else(|_| "../server".to_owned()); let server_path = env::var("TDNS_SERVER_SRC_ROOT").unwrap_or_else(|_| "../server".to_owned());
println!("using server src path: {}", server_path); println!("using server src path: {}", server_path);
@@ -186,7 +187,8 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
// println!("wrote bytes iter: {}", i); // println!("wrote bytes iter: {}", i);
std::thread::yield_now(); std::thread::yield_now();
} }
}).unwrap(); })
.unwrap();
// let the server go first // let the server go first
std::thread::yield_now(); std::thread::yield_now();
@@ -201,7 +203,10 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
let trust_chain = Certificate(root_cert_der); let trust_chain = Certificate(root_cert_der);
let mut config = ClientConfig::new(); let mut config = ClientConfig::new();
config.root_store.add(&trust_chain).expect("bad certificate!"); config
.root_store
.add(&trust_chain)
.expect("bad certificate!");
// barrier.wait(); // barrier.wait();
// fix MTLS // fix MTLS
@@ -219,10 +224,11 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
sender sender
.unbounded_send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr)) .unbounded_send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr))
.expect("send failed"); .expect("send failed");
let (buffer, stream_tmp) = io_loop let (buffer, stream_tmp) = io_loop.block_on(stream.into_future());
.block_on(stream.into_future());
stream = stream_tmp; stream = stream_tmp;
let message = buffer.expect("no buffer received").expect("error receiving buffer"); let message = buffer
.expect("no buffer received")
.expect("error receiving buffer");
assert_eq!(message.bytes(), TEST_BYTES); assert_eq!(message.bytes(), TEST_BYTES);
} }

View File

@@ -6,8 +6,8 @@
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use futures::{Future, TryFutureExt}; use futures::{Future, TryFutureExt};
use rustls::ClientConfig; use rustls::ClientConfig;

View File

@@ -7,15 +7,15 @@
use std::io; use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use futures::channel::mpsc::{unbounded, UnboundedReceiver}; use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use futures::{Future, TryFutureExt}; use futures::{Future, TryFutureExt};
use rustls::ClientConfig; use rustls::ClientConfig;
use tokio_io; use tokio_io;
use tokio_rustls::TlsConnector;
use tokio_net::tcp::TcpStream as TokioTcpStream; use tokio_net::tcp::TcpStream as TokioTcpStream;
use tokio_rustls::TlsConnector;
use webpki::{DNSName, DNSNameRef}; use webpki::{DNSName, DNSNameRef};
use trust_dns_proto::tcp::TcpStream; use trust_dns_proto::tcp::TcpStream;
@@ -29,7 +29,10 @@ pub type TlsStream<S> = TcpStream<S>;
/// Initializes a TlsStream with an existing tokio_tls::TlsStream. /// Initializes a TlsStream with an existing tokio_tls::TlsStream.
/// ///
/// This is intended for use with a TlsListener and Incoming connections /// This is intended for use with a TlsListener and Incoming connections
pub fn tls_from_stream<S: tokio_io::AsyncRead + tokio_io::AsyncWrite>(stream: S, peer_addr: SocketAddr) -> (TlsStream<S>, BufStreamHandle) { pub fn tls_from_stream<S: tokio_io::AsyncRead + tokio_io::AsyncWrite>(
stream: S,
peer_addr: SocketAddr,
) -> (TlsStream<S>, BufStreamHandle) {
let (message_sender, outbound_messages) = unbounded(); let (message_sender, outbound_messages) = unbounded();
let message_sender = BufStreamHandle::new(message_sender); let message_sender = BufStreamHandle::new(message_sender);
@@ -75,26 +78,44 @@ pub fn tls_connect(
let message_sender = BufStreamHandle::new(message_sender); let message_sender = BufStreamHandle::new(message_sender);
let tls_connector = TlsConnector::from(client_config); let tls_connector = TlsConnector::from(client_config);
// This set of futures collapses the next tcp socket into a stream which can be used for // This set of futures collapses the next tcp socket into a stream which can be used for
// sending and receiving tcp packets. // sending and receiving tcp packets.
let stream = Box::pin(connect_tls(tls_connector, name_server, dns_name, outbound_messages)); let stream = Box::pin(connect_tls(
tls_connector,
name_server,
dns_name,
outbound_messages,
));
(stream, message_sender) (stream, message_sender)
} }
async fn connect_tls(tls_connector: TlsConnector, name_server: SocketAddr, dns_name: String, outbound_messages: UnboundedReceiver<SerialMessage>) -> io::Result<TcpStream<TokioTlsClientStream>> { async fn connect_tls(
tls_connector: TlsConnector,
name_server: SocketAddr,
dns_name: String,
outbound_messages: UnboundedReceiver<SerialMessage>,
) -> io::Result<TcpStream<TokioTlsClientStream>> {
let tcp = TokioTcpStream::connect(&name_server).await?; let tcp = TokioTcpStream::connect(&name_server).await?;
let dns_name = DNSNameRef::try_from_ascii_str(&dns_name).map(DNSName::from) let dns_name = DNSNameRef::try_from_ascii_str(&dns_name)
.map(DNSName::from)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bad dns_name"))?; .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bad dns_name"))?;
let s = tls_connector.connect(dns_name.as_ref(), tcp).map_err(|e| { let s = tls_connector
io::Error::new( .connect(dns_name.as_ref(), tcp)
io::ErrorKind::ConnectionRefused, .map_err(|e| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e), format!("tls error: {}", e),
) )
}).await?; })
.await?;
Ok(TcpStream::from_stream_with_receiver(s, name_server, outbound_messages)) Ok(TcpStream::from_stream_with_receiver(
} s,
name_server,
outbound_messages,
))
}

View File

@@ -15,7 +15,6 @@ use crate::authority::LookupObject;
use crate::proto::rr::dnssec::SupportedAlgorithms; use crate::proto::rr::dnssec::SupportedAlgorithms;
use crate::proto::rr::{Record, RecordSet, RecordType, RrsetRecords}; use crate::proto::rr::{Record, RecordSet, RecordType, RrsetRecords};
/// The result of a lookup on an Authority /// The result of a lookup on an Authority
/// ///
/// # Lifetimes /// # Lifetimes

View File

@@ -77,7 +77,11 @@ pub trait Authority: Send {
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>>; ) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>>;
/// Get the NS, NameServer, record for the zone /// Get the NS, NameServer, record for the zone
fn ns(&self, is_secure: bool, supported_algorithms: SupportedAlgorithms) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { fn ns(
&self,
is_secure: bool,
supported_algorithms: SupportedAlgorithms,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> {
self.lookup( self.lookup(
self.origin(), self.origin(),
RecordType::NS, RecordType::NS,

View File

@@ -20,17 +20,19 @@ use std::borrow::Borrow;
use std::collections::HashMap; use std::collections::HashMap;
use std::io; use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::task::Context;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::task::Context;
use futures::{Future, FutureExt, Poll, TryFutureExt, ready}; use futures::{ready, Future, FutureExt, Poll, TryFutureExt};
use trust_dns::op::{Edns, Header, LowerQuery, MessageType, OpCode, ResponseCode}; use trust_dns::op::{Edns, Header, LowerQuery, MessageType, OpCode, ResponseCode};
use trust_dns::rr::dnssec::{Algorithm, SupportedAlgorithms}; use trust_dns::rr::dnssec::{Algorithm, SupportedAlgorithms};
use trust_dns::rr::rdata::opt::{EdnsCode, EdnsOption}; use trust_dns::rr::rdata::opt::{EdnsCode, EdnsOption};
use trust_dns::rr::{LowerName, RecordType}; use trust_dns::rr::{LowerName, RecordType};
use crate::authority::{AuthLookup, MessageRequest, MessageResponse, MessageResponseBuilder, ZoneType}; use crate::authority::{
AuthLookup, MessageRequest, MessageResponse, MessageResponseBuilder, ZoneType,
};
use crate::authority::{AuthorityObject, BoxedLookupFuture, LookupError, LookupObject}; use crate::authority::{AuthorityObject, BoxedLookupFuture, LookupError, LookupObject};
use crate::server::{Request, RequestHandler, ResponseHandler}; use crate::server::{Request, RequestHandler, ResponseHandler};
@@ -173,7 +175,8 @@ pub enum HandleRequest {
impl HandleRequest { impl HandleRequest {
fn lookup<R: ResponseHandler + Unpin>(lookup_future: LookupFuture<R>) -> Self { fn lookup<R: ResponseHandler + Unpin>(lookup_future: LookupFuture<R>) -> Self {
let lookup = Box::pin(lookup_future) as Pin<Box<dyn Future<Output = Result<(), ()>> + Send>>; let lookup =
Box::pin(lookup_future) as Pin<Box<dyn Future<Output = Result<(), ()>> + Send>>;
HandleRequest::LookupFuture(lookup) HandleRequest::LookupFuture(lookup)
} }
@@ -184,7 +187,7 @@ impl HandleRequest {
impl Future for HandleRequest { impl Future for HandleRequest {
// TODO: return () // TODO: return ()
type Output = Result<(),()>; type Output = Result<(), ()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match *self { match *self {
@@ -478,7 +481,7 @@ impl<R: ResponseHandler + Unpin> LookupFuture<R> {
impl<R: ResponseHandler + Unpin> Future for LookupFuture<R> { impl<R: ResponseHandler + Unpin> Future for LookupFuture<R> {
// TODO: return () // TODO: return ()
type Output = Result<(),()>; type Output = Result<(), ()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
loop { loop {
@@ -646,14 +649,18 @@ impl<R: ResponseHandler> AuthorityLookup<R> {
} }
impl<R: ResponseHandler> AuthorityLookup<R> { impl<R: ResponseHandler> AuthorityLookup<R> {
fn split(&mut self) -> (&mut ResponseParams<R>, fn split(
&RequestParams, &mut self,
&Arc<RwLock<Box<dyn AuthorityObject>>>, ) -> (
&mut AuthOrResolve) { &mut ResponseParams<R>,
&RequestParams,
&Arc<RwLock<Box<dyn AuthorityObject>>>,
&mut AuthOrResolve,
) {
( (
self.response_params self.response_params
.as_mut() .as_mut()
.expect("bad state, response_params should not be none here"), .expect("bad state, response_params should not be none here"),
&self.request_params, &self.request_params,
&self.authority, &self.authority,
&mut self.state, &mut self.state,
@@ -663,17 +670,12 @@ impl<R: ResponseHandler> AuthorityLookup<R> {
impl<R: ResponseHandler> Future for AuthorityLookup<R> { impl<R: ResponseHandler> Future for AuthorityLookup<R> {
// TODO: return () // TODO: return ()
type Output = Result<(),()>; type Output = Result<(), ()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let (response_params, request_params, authority, state) = self.split(); let (response_params, request_params, authority, state) = self.split();
let sections = ready!(state.poll( let sections = ready!(state.poll(cx, request_params, response_params, authority))?;
cx,
request_params,
response_params,
authority
))?;
let records = sections.answers; let records = sections.answers;
let soa = sections.soa; let soa = sections.soa;

Some files were not shown because too many files have changed in this diff Show More