Move to async-trait for server to simplify function calls

This commit is contained in:
Benjamin Fry
2021-09-05 11:00:16 -07:00
parent 6dde9938d9
commit b8ad0d68ca
27 changed files with 631 additions and 669 deletions

1
Cargo.lock generated
View File

@@ -1676,6 +1676,7 @@ dependencies = [
name = "trust-dns-integration" name = "trust-dns-integration"
version = "0.21.0-alpha.2" version = "0.21.0-alpha.2"
dependencies = [ dependencies = [
"async-trait",
"chrono", "chrono",
"env_logger", "env_logger",
"futures", "futures",

View File

@@ -40,46 +40,48 @@ extern crate clap;
#[macro_use] #[macro_use]
extern crate log; extern crate log;
#[cfg(feature = "dnssec")] use std::{
use std::future::Future; net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs},
use std::net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs}; path::{Path, PathBuf},
use std::path::{Path, PathBuf}; sync::Arc,
use std::sync::{Arc, RwLock}; };
use clap::{Arg, ArgMatches}; use clap::{Arg, ArgMatches};
use tokio::net::TcpListener; use futures::lock::Mutex;
use tokio::net::UdpSocket; use tokio::{
use tokio::runtime::{self, Runtime}; net::{TcpListener, UdpSocket},
runtime::{self, Runtime},
};
use trust_dns_client::rr::Name; use trust_dns_client::rr::Name;
use trust_dns_server::authority::{AuthorityObject, Catalog, ZoneType};
#[cfg(feature = "dns-over-tls")] #[cfg(feature = "dns-over-tls")]
use trust_dns_server::config::dnssec::{self, TlsCertConfig}; use trust_dns_server::config::dnssec::{self, TlsCertConfig};
use trust_dns_server::config::{Config, ZoneConfig};
use trust_dns_server::logger;
use trust_dns_server::server::ServerFuture;
use trust_dns_server::store::file::{FileAuthority, FileConfig};
#[cfg(feature = "resolver")] #[cfg(feature = "resolver")]
use trust_dns_server::store::forwarder::ForwardAuthority; use trust_dns_server::store::forwarder::ForwardAuthority;
#[cfg(feature = "sqlite")] #[cfg(feature = "sqlite")]
use trust_dns_server::store::sqlite::{SqliteAuthority, SqliteConfig}; use trust_dns_server::store::sqlite::{SqliteAuthority, SqliteConfig};
use trust_dns_server::store::StoreConfig; use trust_dns_server::{
#[cfg(feature = "dnssec")] authority::{AuthorityObject, Catalog, ZoneType},
use { config::{Config, ZoneConfig},
trust_dns_client::rr::rdata::key::KeyUsage, trust_dns_server::authority::DnssecAuthority, store::{
trust_dns_server::authority::LookupError, file::{FileAuthority, FileConfig},
StoreConfig,
},
}; };
use trust_dns_server::{logger, server::ServerFuture};
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
fn load_keys<A, L, LF>( use {trust_dns_client::rr::rdata::key::KeyUsage, trust_dns_server::authority::DnssecAuthority};
#[cfg(feature = "dnssec")]
fn load_keys<A, L>(
authority: &mut A, authority: &mut A,
zone_name: Name, zone_name: Name,
zone_config: &ZoneConfig, zone_config: &ZoneConfig,
) -> Result<(), String> ) -> Result<(), String>
where where
A: DnssecAuthority<Lookup = L, LookupFuture = LF>, A: DnssecAuthority<Lookup = L>,
L: Send + Sized + 'static, L: Send + Sized + 'static,
LF: Future<Output = Result<L, LookupError>> + Send,
{ {
if zone_config.is_dnssec_enabled() { if zone_config.is_dnssec_enabled() {
for key_config in zone_config.get_keys() { for key_config in zone_config.get_keys() {
@@ -168,7 +170,7 @@ fn load_zone(
// load any keys for the Zone, if it is a dynamic update zone, then keys are required // load any keys for the Zone, if it is a dynamic update zone, then keys are required
load_keys(&mut authority, zone_name_for_signer, zone_config)?; load_keys(&mut authority, zone_name_for_signer, zone_config)?;
Box::new(Arc::new(RwLock::new(authority))) Box::new(Arc::new(Mutex::new(authority)))
} }
Some(StoreConfig::File(ref config)) => { Some(StoreConfig::File(ref config)) => {
if zone_path.is_some() { if zone_path.is_some() {
@@ -185,14 +187,14 @@ fn load_zone(
// load any keys for the Zone, if it is a dynamic update zone, then keys are required // load any keys for the Zone, if it is a dynamic update zone, then keys are required
load_keys(&mut authority, zone_name_for_signer, zone_config)?; load_keys(&mut authority, zone_name_for_signer, zone_config)?;
Box::new(Arc::new(RwLock::new(authority))) Box::new(Arc::new(Mutex::new(authority)))
} }
#[cfg(feature = "resolver")] #[cfg(feature = "resolver")]
Some(StoreConfig::Forward(ref config)) => { Some(StoreConfig::Forward(ref config)) => {
let forwarder = ForwardAuthority::try_from_config(zone_name, zone_type, config); let forwarder = ForwardAuthority::try_from_config(zone_name, zone_type, config);
let authority = runtime.block_on(forwarder)?; let authority = runtime.block_on(forwarder)?;
Box::new(Arc::new(RwLock::new(authority))) Box::new(Arc::new(Mutex::new(authority)))
} }
#[cfg(feature = "sqlite")] #[cfg(feature = "sqlite")]
None if zone_config.is_update_allowed() => { None if zone_config.is_update_allowed() => {
@@ -223,7 +225,7 @@ fn load_zone(
// load any keys for the Zone, if it is a dynamic update zone, then keys are required // load any keys for the Zone, if it is a dynamic update zone, then keys are required
load_keys(&mut authority, zone_name_for_signer, zone_config)?; load_keys(&mut authority, zone_name_for_signer, zone_config)?;
Box::new(Arc::new(RwLock::new(authority))) Box::new(Arc::new(Mutex::new(authority)))
} }
None => { None => {
let config = FileConfig { let config = FileConfig {
@@ -240,7 +242,7 @@ fn load_zone(
// load any keys for the Zone, if it is a dynamic update zone, then keys are required // load any keys for the Zone, if it is a dynamic update zone, then keys are required
load_keys(&mut authority, zone_name_for_signer, zone_config)?; load_keys(&mut authority, zone_name_for_signer, zone_config)?;
Box::new(Arc::new(RwLock::new(authority))) Box::new(Arc::new(Mutex::new(authority)))
} }
Some(_) => { Some(_) => {
panic!("unrecognized authority type, check enabled features"); panic!("unrecognized authority type, check enabled features");

View File

@@ -65,7 +65,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
async-std = "1.6" async-std = "1.6"
async-trait = "0.1.36" async-trait = "0.1.42"
futures-io = { version = "0.3.5", default-features = false, features = ["std"] } futures-io = { version = "0.3.5", default-features = false, features = ["std"] }
futures-util = { version = "0.3.5", default-features = false, features = ["std"] } futures-util = { version = "0.3.5", default-features = false, features = ["std"] }
pin-utils = "0.1.0" pin-utils = "0.1.0"

View File

@@ -68,7 +68,7 @@ name = "trust_dns_proto"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
async-trait = "0.1.36" async-trait = "0.1.42"
backtrace = { version = "0.3.50", optional = true } backtrace = { version = "0.3.50", optional = true }
bytes = { version = "1", optional = true } bytes = { version = "1", optional = true }
cfg-if = "1" cfg-if = "1"

View File

@@ -6,21 +6,23 @@
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
//! All authority related types //! All authority related types
use std::future::Future;
use std::pin::Pin;
use cfg_if::cfg_if; use cfg_if::cfg_if;
use crate::authority::{LookupError, MessageRequest, UpdateResult, ZoneType};
use crate::client::op::LowerQuery;
use crate::client::rr::{LowerName, RecordSet, RecordType};
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use crate::client::{ use crate::client::{
proto::rr::dnssec::rdata::key::KEY, proto::rr::dnssec::rdata::key::KEY,
rr::dnssec::{DnsSecResult, SigSigner, SupportedAlgorithms}, rr::dnssec::{DnsSecResult, SigSigner, SupportedAlgorithms},
rr::Name, rr::Name,
}; };
use crate::proto::rr::RrsetRecords; use crate::{
authority::{LookupError, MessageRequest, UpdateResult, ZoneType},
client::{
op::LowerQuery,
rr::{LowerName, RecordSet, RecordType},
},
proto::rr::RrsetRecords,
};
/// LookupOptions that specify different options from the client to include or exclude various records in the response. /// LookupOptions that specify different options from the client to include or exclude various records in the response.
/// ///
@@ -95,11 +97,10 @@ impl LookupOptions {
} }
/// Authority implementations can be used with a `Catalog` /// Authority implementations can be used with a `Catalog`
pub trait Authority: Send { #[async_trait::async_trait]
pub trait Authority: Send + Sync {
/// Result of a lookup /// Result of a lookup
type Lookup: Send + Sized + 'static; type Lookup: Send + Sync + Sized + 'static;
/// The future type that will resolve to a Lookup
type LookupFuture: Future<Output = Result<Self::Lookup, LookupError>> + Send;
/// What type is this zone /// What type is this zone
fn zone_type(&self) -> ZoneType; fn zone_type(&self) -> ZoneType;
@@ -108,7 +109,7 @@ pub trait Authority: Send {
fn is_axfr_allowed(&self) -> bool; fn is_axfr_allowed(&self) -> bool;
/// Perform a dynamic update of a zone /// Perform a dynamic update of a zone
fn update(&mut self, update: &MessageRequest) -> UpdateResult<bool>; async fn update(&mut self, update: &MessageRequest) -> UpdateResult<bool>;
/// Get the origin of this zone, i.e. example.com is the origin for www.example.com /// Get the origin of this zone, i.e. example.com is the origin for www.example.com
fn origin(&self) -> &LowerName; fn origin(&self) -> &LowerName;
@@ -127,12 +128,12 @@ pub trait Authority: Send {
/// # Return value /// # Return value
/// ///
/// None if there are no matching records, otherwise a `Vec` containing the found records. /// None if there are no matching records, otherwise a `Vec` containing the found records.
fn lookup( async fn lookup(
&self, &self,
name: &LowerName, name: &LowerName,
rtype: RecordType, rtype: RecordType,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>>; ) -> Result<Self::Lookup, LookupError>;
/// Using the specified query, perform a lookup against this zone. /// Using the specified query, perform a lookup against this zone.
/// ///
@@ -145,18 +146,16 @@ pub trait Authority: Send {
/// ///
/// Returns a vectory containing the results of the query, it will be empty if not found. If /// Returns a vectory containing the results of the query, it will be empty if not found. If
/// `is_secure` is true, in the case of no records found then NSEC records will be returned. /// `is_secure` is true, in the case of no records found then NSEC records will be returned.
fn search( async fn search(
&self, &self,
query: &LowerQuery, query: &LowerQuery,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>>; ) -> Result<Self::Lookup, LookupError>;
/// Get the NS, NameServer, record for the zone /// Get the NS, NameServer, record for the zone
fn ns( async fn ns(&self, lookup_options: LookupOptions) -> Result<Self::Lookup, LookupError> {
&self,
lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> {
self.lookup(self.origin(), RecordType::NS, lookup_options) self.lookup(self.origin(), RecordType::NS, lookup_options)
.await
} }
/// Return the NSEC records based on the given name /// Return the NSEC records based on the given name
@@ -166,27 +165,26 @@ pub trait Authority: Send {
/// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than /// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than
/// this /// this
/// * `is_secure` - if true then it will return RRSIG records as well /// * `is_secure` - if true then it will return RRSIG records as well
fn get_nsec_records( async fn get_nsec_records(
&self, &self,
name: &LowerName, name: &LowerName,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>>; ) -> Result<Self::Lookup, LookupError>;
/// Returns the SOA of the authority. /// Returns the SOA of the authority.
/// ///
/// *Note*: This will only return the SOA, if this is fulfilling a request, a standard lookup /// *Note*: This will only return the SOA, if this is fulfilling a request, a standard lookup
/// should be used, see `soa_secure()`, which will optionally return RRSIGs. /// should be used, see `soa_secure()`, which will optionally return RRSIGs.
fn soa(&self) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { async fn soa(&self) -> Result<Self::Lookup, LookupError> {
// SOA should be origin|SOA // SOA should be origin|SOA
self.lookup(self.origin(), RecordType::SOA, LookupOptions::default()) self.lookup(self.origin(), RecordType::SOA, LookupOptions::default())
.await
} }
/// Returns the SOA record for the zone /// Returns the SOA record for the zone
fn soa_secure( async fn soa_secure(&self, lookup_options: LookupOptions) -> Result<Self::Lookup, LookupError> {
&self,
lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> {
self.lookup(self.origin(), RecordType::SOA, lookup_options) self.lookup(self.origin(), RecordType::SOA, lookup_options)
.await
} }
} }

View File

@@ -1,4 +1,4 @@
// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com> // Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
// //
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
@@ -7,36 +7,36 @@
//! All authority related types //! All authority related types
use std::future::Future; use std::sync::Arc;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use futures_util::{future, TryFutureExt}; use futures_util::lock::Mutex;
use log::debug; use log::debug;
use crate::authority::{ use crate::{
Authority, LookupError, LookupOptions, MessageRequest, UpdateResult, ZoneType, authority::{Authority, LookupError, LookupOptions, MessageRequest, UpdateResult, ZoneType},
client::{
op::LowerQuery,
rr::{LowerName, Record, RecordType},
},
}; };
use crate::client::op::LowerQuery;
use crate::client::rr::{LowerName, Record, RecordType};
/// An Object safe Authority /// An Object safe Authority
#[async_trait::async_trait]
pub trait AuthorityObject: Send + Sync { pub trait AuthorityObject: Send + Sync {
/// Clone the object /// Clone the object
fn box_clone(&self) -> Box<dyn AuthorityObject>; fn box_clone(&self) -> Box<dyn AuthorityObject>;
/// What type is this zone /// What type is this zone
fn zone_type(&self) -> ZoneType; async fn zone_type(&self) -> ZoneType;
/// Return true if AXFR is allowed /// Return true if AXFR is allowed
fn is_axfr_allowed(&self) -> bool; async fn is_axfr_allowed(&self) -> bool;
/// Perform a dynamic update of a zone /// Perform a dynamic update of a zone
fn update(&self, update: &MessageRequest) -> UpdateResult<bool>; async fn update(&self, update: &MessageRequest) -> UpdateResult<bool>;
/// Get the origin of this zone, i.e. example.com is the origin for www.example.com /// Get the origin of this zone, i.e. example.com is the origin for www.example.com
fn origin(&self) -> LowerName; async fn origin(&self) -> LowerName;
/// Looks up all Resource Records matching the giving `Name` and `RecordType`. /// Looks up all Resource Records matching the giving `Name` and `RecordType`.
/// ///
@@ -52,12 +52,12 @@ pub trait AuthorityObject: Send + Sync {
/// # Return value /// # Return value
/// ///
/// None if there are no matching records, otherwise a `Vec` containing the found records. /// None if there are no matching records, otherwise a `Vec` containing the found records.
fn lookup( async fn lookup(
&self, &self,
name: &LowerName, name: &LowerName,
rtype: RecordType, rtype: RecordType,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> BoxedLookupFuture; ) -> Result<Box<dyn LookupObject>, LookupError>;
/// Using the specified query, perform a lookup against this zone. /// Using the specified query, perform a lookup against this zone.
/// ///
@@ -70,11 +70,19 @@ pub trait AuthorityObject: Send + Sync {
/// ///
/// Returns a vectory containing the results of the query, it will be empty if not found. If /// Returns a vectory containing the results of the query, it will be empty if not found. If
/// `is_secure` is true, in the case of no records found then NSEC records will be returned. /// `is_secure` is true, in the case of no records found then NSEC records will be returned.
fn search(&self, query: &LowerQuery, lookup_options: LookupOptions) -> BoxedLookupFuture; async fn search(
&self,
query: &LowerQuery,
lookup_options: LookupOptions,
) -> Result<Box<dyn LookupObject>, LookupError>;
/// Get the NS, NameServer, record for the zone /// Get the NS, NameServer, record for the zone
fn ns(&self, lookup_options: LookupOptions) -> BoxedLookupFuture { async fn ns(
self.lookup(&self.origin(), RecordType::NS, lookup_options) &self,
lookup_options: LookupOptions,
) -> Result<Box<dyn LookupObject>, LookupError> {
self.lookup(&self.origin().await, RecordType::NS, lookup_options)
.await
} }
/// Return the NSEC records based on the given name /// Return the NSEC records based on the given name
@@ -84,55 +92,64 @@ pub trait AuthorityObject: Send + Sync {
/// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than /// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than
/// this /// this
/// * `is_secure` - if true then it will return RRSIG records as well /// * `is_secure` - if true then it will return RRSIG records as well
fn get_nsec_records( async fn get_nsec_records(
&self, &self,
name: &LowerName, name: &LowerName,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> BoxedLookupFuture; ) -> Result<Box<dyn LookupObject>, LookupError>;
/// Returns the SOA of the authority. /// Returns the SOA of the authority.
/// ///
/// *Note*: This will only return the SOA, if this is fulfilling a request, a standard lookup /// *Note*: This will only return the SOA, if this is fulfilling a request, a standard lookup
/// should be used, see `soa_secure()`, which will optionally return RRSIGs. /// should be used, see `soa_secure()`, which will optionally return RRSIGs.
fn soa(&self) -> BoxedLookupFuture { async fn soa(&self) -> Result<Box<dyn LookupObject>, LookupError> {
// SOA should be origin|SOA // SOA should be origin|SOA
self.lookup(&self.origin(), RecordType::SOA, LookupOptions::default()) self.lookup(
&self.origin().await,
RecordType::SOA,
LookupOptions::default(),
)
.await
} }
/// Returns the SOA record for the zone /// Returns the SOA record for the zone
fn soa_secure(&self, lookup_options: LookupOptions) -> BoxedLookupFuture { async fn soa_secure(
self.lookup(&self.origin(), RecordType::SOA, lookup_options) &self,
lookup_options: LookupOptions,
) -> Result<Box<dyn LookupObject>, LookupError> {
self.lookup(&self.origin().await, RecordType::SOA, lookup_options)
.await
} }
} }
impl<A, L> AuthorityObject for Arc<RwLock<A>> #[async_trait::async_trait]
impl<A, L> AuthorityObject for Arc<Mutex<A>>
where where
A: Authority<Lookup = L> + Send + Sync + 'static, A: Authority<Lookup = L> + Send + Sync + 'static,
A::LookupFuture: Send + 'static, L: LookupObject + Send + Sync + 'static,
L: LookupObject + Send + 'static,
{ {
fn box_clone(&self) -> Box<dyn AuthorityObject> { fn box_clone(&self) -> Box<dyn AuthorityObject> {
Box::new(self.clone()) Box::new(self.clone())
} }
/// What type is this zone /// What type is this zone
fn zone_type(&self) -> ZoneType { async fn zone_type(&self) -> ZoneType {
Authority::zone_type(&*self.read().expect("poisoned")) Authority::zone_type(&*self.lock().await)
} }
/// Return true if AXFR is allowed /// Return true if AXFR is allowed
fn is_axfr_allowed(&self) -> bool { async fn is_axfr_allowed(&self) -> bool {
Authority::is_axfr_allowed(&*self.read().expect("poisoned")) Authority::is_axfr_allowed(&*self.lock().await)
} }
/// Perform a dynamic update of a zone /// Perform a dynamic update of a zone
fn update(&self, update: &MessageRequest) -> UpdateResult<bool> { async fn update(&self, update: &MessageRequest) -> UpdateResult<bool> {
Authority::update(&mut *self.write().expect("poisoned"), update) Authority::update(&mut *self.lock().await, update).await
} }
/// Get the origin of this zone, i.e. example.com is the origin for www.example.com /// Get the origin of this zone, i.e. example.com is the origin for www.example.com
fn origin(&self) -> LowerName { async fn origin(&self) -> LowerName {
Authority::origin(&*self.read().expect("poisoned")).clone() Authority::origin(&*self.lock().await).clone()
} }
/// Looks up all Resource Records matching the giving `Name` and `RecordType`. /// Looks up all Resource Records matching the giving `Name` and `RecordType`.
@@ -149,15 +166,15 @@ where
/// # Return value /// # Return value
/// ///
/// None if there are no matching records, otherwise a `Vec` containing the found records. /// None if there are no matching records, otherwise a `Vec` containing the found records.
fn lookup( async fn lookup(
&self, &self,
name: &LowerName, name: &LowerName,
rtype: RecordType, rtype: RecordType,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> BoxedLookupFuture { ) -> Result<Box<dyn LookupObject>, LookupError> {
let this = self.read().expect("poisoned"); let this = self.lock().await;
let lookup = Authority::lookup(&*this, name, rtype, lookup_options); let lookup = Authority::lookup(&*this, name, rtype, lookup_options).await;
BoxedLookupFuture::from(lookup.map_ok(|l| Box::new(l) as Box<dyn LookupObject>)) lookup.map(|l| Box::new(l) as Box<dyn LookupObject>)
} }
/// Using the specified query, perform a lookup against this zone. /// Using the specified query, perform a lookup against this zone.
@@ -171,11 +188,15 @@ where
/// ///
/// Returns a vectory containing the results of the query, it will be empty if not found. If /// Returns a vectory containing the results of the query, it will be empty if not found. If
/// `is_secure` is true, in the case of no records found then NSEC records will be returned. /// `is_secure` is true, in the case of no records found then NSEC records will be returned.
fn search(&self, query: &LowerQuery, lookup_options: LookupOptions) -> BoxedLookupFuture { async fn search(
let this = self.read().expect("poisoned"); &self,
query: &LowerQuery,
lookup_options: LookupOptions,
) -> Result<Box<dyn LookupObject>, LookupError> {
let this = self.lock().await;
debug!("performing {} on {}", query, this.origin()); debug!("performing {} on {}", query, this.origin());
let lookup = Authority::search(&*this, query, lookup_options); let lookup = Authority::search(&*this, query, lookup_options).await;
BoxedLookupFuture::from(lookup.map_ok(|l| Box::new(l) as Box<dyn LookupObject>)) lookup.map(|l| Box::new(l) as Box<dyn LookupObject>)
} }
/// Return the NSEC records based on the given name /// Return the NSEC records based on the given name
@@ -185,14 +206,13 @@ where
/// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than /// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than
/// this /// this
/// * `is_secure` - if true then it will return RRSIG records as well /// * `is_secure` - if true then it will return RRSIG records as well
fn get_nsec_records( async fn get_nsec_records(
&self, &self,
name: &LowerName, name: &LowerName,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> BoxedLookupFuture { ) -> Result<Box<dyn LookupObject>, LookupError> {
let lookup = let lookup = Authority::get_nsec_records(&*self.lock().await, name, lookup_options).await;
Authority::get_nsec_records(&*self.read().expect("poisoned"), name, lookup_options); lookup.map(|l| Box::new(l) as Box<dyn LookupObject>)
BoxedLookupFuture::from(lookup.map_ok(|l| Box::new(l) as Box<dyn LookupObject>))
} }
} }
@@ -227,34 +247,3 @@ impl LookupObject for EmptyLookup {
None None
} }
} }
/// A boxed lookup future
#[allow(clippy::type_complexity)]
pub struct BoxedLookupFuture(
Pin<Box<dyn Future<Output = Result<Box<dyn LookupObject>, LookupError>> + Send>>,
);
impl BoxedLookupFuture {
/// Performs a conversion (boxes) into the future
pub fn from<T>(future: T) -> Self
where
T: Future<Output = Result<Box<dyn LookupObject>, LookupError>> + Send + Sized + 'static,
{
BoxedLookupFuture(Box::pin(future))
}
/// Creates an empty (i.e. no records) lookup future
pub fn empty() -> Self {
BoxedLookupFuture(Box::pin(future::ok(
Box::new(EmptyLookup) as Box<dyn LookupObject>
)))
}
}
impl Future for BoxedLookupFuture {
type Output = Result<Box<dyn LookupObject>, LookupError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.as_mut().poll(cx)
}
}

View File

@@ -8,29 +8,27 @@
// TODO, I've implemented this as a separate entity from the cache, but I wonder if the cache // TODO, I've implemented this as a separate entity from the cache, but I wonder if the cache
// should be the only "front-end" for lookups, where if that misses, then we go to the catalog // should be the only "front-end" for lookups, where if that misses, then we go to the catalog
// then, if requested, do a recursive lookup... i.e. the catalog would only point to files. // then, if requested, do a recursive lookup... i.e. the catalog would only point to files.
use std::borrow::Borrow; use std::{borrow::Borrow, collections::HashMap, future::Future, io};
use std::collections::HashMap;
use std::future::Future;
use std::io;
use std::pin::Pin;
use cfg_if::cfg_if; use cfg_if::cfg_if;
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace, warn};
use crate::authority::{
AuthLookup, MessageRequest, MessageResponse, MessageResponseBuilder, ZoneType,
};
use crate::authority::{
AuthorityObject, BoxedLookupFuture, EmptyLookup, LookupError, LookupObject, LookupOptions,
};
use crate::client::op::{Edns, Header, LowerQuery, MessageType, OpCode, ResponseCode};
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use crate::client::rr::{ use crate::client::rr::{
dnssec::{Algorithm, SupportedAlgorithms}, dnssec::{Algorithm, SupportedAlgorithms},
rdata::opt::{EdnsCode, EdnsOption}, rdata::opt::{EdnsCode, EdnsOption},
}; };
use crate::client::rr::{LowerName, RecordType}; use crate::{
use crate::server::{Request, RequestHandler, ResponseHandler}; authority::{
AuthLookup, AuthorityObject, EmptyLookup, LookupError, LookupObject, LookupOptions,
MessageRequest, MessageResponse, MessageResponseBuilder, ZoneType,
},
client::{
op::{Edns, Header, LowerQuery, MessageType, OpCode, ResponseCode},
rr::{LowerName, RecordType},
},
server::{Request, RequestHandler, ResponseHandler},
};
/// Set of authorities, zones, available to this server. /// Set of authorities, zones, available to this server.
#[derive(Default)] #[derive(Default)]
@@ -39,7 +37,7 @@ pub struct Catalog {
} }
#[allow(unused_mut, unused_variables)] #[allow(unused_mut, unused_variables)]
fn send_response<R: ResponseHandler>( async fn send_response<R: ResponseHandler>(
response_edns: Option<Edns>, response_edns: Option<Edns>,
mut response: MessageResponse<'_, '_>, mut response: MessageResponse<'_, '_>,
mut response_handle: R, mut response_handle: R,
@@ -63,23 +61,18 @@ fn send_response<R: ResponseHandler>(
response.set_edns(resp_edns); response.set_edns(resp_edns);
} }
response_handle.send_response(response) response_handle.send_response(response).await
} }
#[async_trait::async_trait]
impl RequestHandler for Catalog { impl RequestHandler for Catalog {
type ResponseFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Determines what needs to happen given the type of request, i.e. Query or Update. /// Determines what needs to happen given the type of request, i.e. Query or Update.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `request` - the requested action to perform. /// * `request` - the requested action to perform.
/// * `response_handle` - sink for the response message to be sent /// * `response_handle` - sink for the response message to be sent
fn handle_request<R: ResponseHandler>( async fn handle_request<R: ResponseHandler>(&self, request: Request, mut response_handle: R) {
&self,
request: Request,
mut response_handle: R,
) -> Self::ResponseFuture {
let request_message = request.message; let request_message = request.message;
trace!("request: {:?}", request_message); trace!("request: {:?}", request_message);
@@ -109,12 +102,13 @@ impl RequestHandler for Catalog {
response.edns(resp_edns); response.edns(resp_edns);
// TODO: should ResponseHandle consume self? // TODO: should ResponseHandle consume self?
let result = let result = response_handle
response_handle.send_response(response.build_no_records(response_header)); .send_response(response.build_no_records(response_header))
.await;
if let Err(e) = result { if let Err(e) = result {
error!("request error: {}", e); error!("request error: {}", e);
} }
return Box::pin(async {}); return;
} }
response_edns = Some(resp_edns); response_edns = Some(resp_edns);
@@ -128,19 +122,24 @@ impl RequestHandler for Catalog {
MessageType::Query => match request_message.op_code() { MessageType::Query => match request_message.op_code() {
OpCode::Query => { OpCode::Query => {
debug!("query received: {}", request_message.id()); debug!("query received: {}", request_message.id());
return Box::pin(self.lookup(request_message, response_edns, response_handle)); self.lookup(request_message, response_edns, response_handle)
.await;
Ok(())
} }
OpCode::Update => { OpCode::Update => {
debug!("update received: {}", request_message.id()); debug!("update received: {}", request_message.id());
// TODO: this should be a future
self.update(&request_message, response_edns, response_handle) self.update(&request_message, response_edns, response_handle)
.await
} }
c => { c => {
warn!("unimplemented op_code: {:?}", c); warn!("unimplemented op_code: {:?}", c);
let response = MessageResponseBuilder::new(Some(request_message.raw_queries())); let response = MessageResponseBuilder::new(Some(request_message.raw_queries()));
response_handle.send_response( response_handle
response.error_msg(request_message.header(), ResponseCode::NotImp), .send_response(
) response.error_msg(request_message.header(), ResponseCode::NotImp),
)
.await
} }
}, },
MessageType::Response => { MessageType::Response => {
@@ -149,16 +148,17 @@ impl RequestHandler for Catalog {
request_message.id() request_message.id()
); );
let response = MessageResponseBuilder::new(Some(request_message.raw_queries())); let response = MessageResponseBuilder::new(Some(request_message.raw_queries()));
response_handle.send_response( response_handle
response.error_msg(request_message.header(), ResponseCode::FormErr), .send_response(
) response.error_msg(request_message.header(), ResponseCode::FormErr),
)
.await
} }
}; };
if let Err(e) = result { if let Err(e) = result {
error!("request failed: {}", e); error!("request failed: {}", e);
} }
Box::pin(async {})
} }
} }
@@ -234,7 +234,7 @@ impl Catalog {
/// ///
/// * `request` - an update message /// * `request` - an update message
/// * `response_handle` - sink for the response message to be sent /// * `response_handle` - sink for the response message to be sent
pub fn update<R: ResponseHandler + 'static>( pub async fn update<R: ResponseHandler + 'static>(
&self, &self,
update: &MessageRequest, update: &MessageRequest,
response_edns: Option<Edns>, response_edns: Option<Edns>,
@@ -270,23 +270,14 @@ impl Catalog {
let response_code = match result { let response_code = match result {
Ok(authority) => { Ok(authority) => {
// Ask for Master/Slave terms to be replaced
#[allow(deprecated)] #[allow(deprecated)]
match authority.zone_type() { match authority.zone_type().await {
ZoneType::Slave | ZoneType::Master => {
warn!("Consider replacing the usage of master/slave with primary/secondary, see Juneteenth.");
}
_ => (),
}
#[allow(deprecated)]
match authority.zone_type() {
ZoneType::Secondary | ZoneType::Slave => { ZoneType::Secondary | ZoneType::Slave => {
error!("secondary forwarding for update not yet implemented"); error!("secondary forwarding for update not yet implemented");
ResponseCode::NotImp ResponseCode::NotImp
} }
ZoneType::Primary | ZoneType::Master => { ZoneType::Primary | ZoneType::Master => {
let update_result = authority.update(update); let update_result = authority.update(update).await;
match update_result { match update_result {
// successful update // successful update
Ok(..) => ResponseCode::NoError, Ok(..) => ResponseCode::NoError,
@@ -311,6 +302,7 @@ impl Catalog {
response.build_no_records(response_header), response.build_no_records(response_header),
response_handle, response_handle,
) )
.await
} }
/// Checks whether the `Catalog` contains DNS records for `name` /// Checks whether the `Catalog` contains DNS records for `name`
@@ -332,12 +324,12 @@ impl Catalog {
/// ///
/// * `request` - the query message. /// * `request` - the query message.
/// * `response_handle` - sink for the response message to be sent /// * `response_handle` - sink for the response message to be sent
pub fn lookup<R: ResponseHandler>( pub async fn lookup<R: ResponseHandler>(
&self, &self,
request: MessageRequest, request: MessageRequest,
response_edns: Option<Edns>, response_edns: Option<Edns>,
response_handle: R, response_handle: R,
) -> impl Future<Output = ()> + 'static { ) {
// find matching authorities for the request // find matching authorities for the request
let queries_and_authorities = request let queries_and_authorities = request
.queries() .queries()
@@ -360,6 +352,7 @@ impl Catalog {
response.error_msg(request.header(), ResponseCode::Refused), response.error_msg(request.header(), ResponseCode::Refused),
response_handle.clone(), response_handle.clone(),
) )
.await
.map_err(|e| error!("failed to send response: {}", e)) .map_err(|e| error!("failed to send response: {}", e))
.ok(); .ok();
} }
@@ -370,6 +363,7 @@ impl Catalog {
response_edns, response_edns,
response_handle, response_handle,
) )
.await
} }
/// Recursively searches the catalog for a matching authority /// Recursively searches the catalog for a matching authority
@@ -403,7 +397,7 @@ async fn lookup<R: ResponseHandler + Unpin>(
info!( info!(
"request: {} found authority: {}", "request: {} found authority: {}",
request.id(), request.id(),
authority.origin() authority.origin().await
); );
let (response_header, sections) = build_response( let (response_header, sections) = build_response(
@@ -423,7 +417,7 @@ async fn lookup<R: ResponseHandler + Unpin>(
sections.additionals.iter(), sections.additionals.iter(),
); );
let result = send_response(response_edns.clone(), response, response_handle.clone()); let result = send_response(response_edns.clone(), response, response_handle.clone()).await;
if let Err(e) = result { if let Err(e) = result {
error!("error sending response: {}", e); error!("error sending response: {}", e);
} }
@@ -472,13 +466,13 @@ async fn build_response(
} }
let mut response_header = Header::response_from_request(request_header); let mut response_header = Header::response_from_request(request_header);
response_header.set_authoritative(authority.zone_type().is_authoritative()); response_header.set_authoritative(authority.zone_type().await.is_authoritative());
debug!("performing {} on {}", query, authority.origin()); debug!("performing {} on {}", query, authority.origin().await);
let future = authority.search(query, lookup_options); let future = authority.search(query, lookup_options);
#[allow(deprecated)] #[allow(deprecated)]
let sections = match authority.zone_type() { let sections = match authority.zone_type().await {
ZoneType::Primary | ZoneType::Secondary | ZoneType::Master | ZoneType::Slave => { ZoneType::Primary | ZoneType::Secondary | ZoneType::Master | ZoneType::Slave => {
send_authoritative_response( send_authoritative_response(
future, future,
@@ -499,7 +493,7 @@ async fn build_response(
} }
async fn send_authoritative_response( async fn send_authoritative_response(
future: BoxedLookupFuture, future: impl Future<Output = Result<Box<dyn LookupObject>, LookupError>>,
authority: &dyn AuthorityObject, authority: &dyn AuthorityObject,
response_header: &mut Header, response_header: &mut Header,
lookup_options: LookupOptions, lookup_options: LookupOptions,
@@ -603,7 +597,7 @@ async fn send_authoritative_response(
} }
async fn send_forwarded_response( async fn send_forwarded_response(
future: BoxedLookupFuture, future: impl Future<Output = Result<Box<dyn LookupObject>, LookupError>>,
request_header: &Header, request_header: &Header,
response_header: &mut Header, response_header: &mut Header,
) -> LookupSections { ) -> LookupSections {

View File

@@ -26,7 +26,7 @@ pub use self::auth_lookup::{
AnyRecords, AuthLookup, AuthLookupIter, LookupRecords, LookupRecordsIter, AnyRecords, AuthLookup, AuthLookupIter, LookupRecords, LookupRecordsIter,
}; };
pub use self::authority::{Authority, LookupOptions}; pub use self::authority::{Authority, LookupOptions};
pub use self::authority_object::{AuthorityObject, BoxedLookupFuture, EmptyLookup, LookupObject}; pub use self::authority_object::{AuthorityObject, EmptyLookup, LookupObject};
pub use self::catalog::Catalog; pub use self::catalog::Catalog;
pub use self::error::{LookupError, LookupResult}; pub use self::error::{LookupError, LookupResult};
pub use self::message_request::{MessageRequest, Queries, UpdateRequest}; pub use self::message_request::{MessageRequest, Queries, UpdateRequest};

View File

@@ -5,21 +5,19 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::io; use std::{io, net::SocketAddr, sync::Arc};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures_util::lock::Mutex;
use h2::server; use h2::server;
use log::{debug, warn}; use log::{debug, warn};
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use crate::authority::{MessageRequest, MessageResponse}; use crate::{
use crate::proto::https::https_server; authority::{MessageRequest, MessageResponse},
use crate::proto::serialize::binary::BinDecodable; proto::{https::https_server, serialize::binary::BinDecodable},
use crate::server::request_handler::RequestHandler; server::{request_handler::RequestHandler, response_handler::ResponseHandler, server_future},
use crate::server::response_handler::ResponseHandler; };
use crate::server::server_future;
pub(crate) async fn h2_handler<T, I>( pub(crate) async fn h2_handler<T, I>(
handler: Arc<Mutex<T>>, handler: Arc<Mutex<T>>,
@@ -90,8 +88,9 @@ async fn handle_request<T>(
#[derive(Clone)] #[derive(Clone)]
struct HttpsResponseHandle(Arc<Mutex<::h2::server::SendResponse<Bytes>>>); struct HttpsResponseHandle(Arc<Mutex<::h2::server::SendResponse<Bytes>>>);
#[async_trait::async_trait]
impl ResponseHandler for HttpsResponseHandle { impl ResponseHandler for HttpsResponseHandle {
fn send_response(&mut self, response: MessageResponse<'_, '_>) -> io::Result<()> { async fn send_response(&mut self, response: MessageResponse<'_, '_>) -> io::Result<()> {
use crate::proto::https::response; use crate::proto::https::response;
use crate::proto::https::HttpsError; use crate::proto::https::HttpsError;
use crate::proto::serialize::binary::BinEncoder; use crate::proto::serialize::binary::BinEncoder;
@@ -109,7 +108,7 @@ impl ResponseHandler for HttpsResponseHandle {
let mut stream = self let mut stream = self
.0 .0
.lock() .lock()
.expect("https poisoned") .await
.send_response(response, false) .send_response(response, false)
.map_err(HttpsError::from)?; .map_err(HttpsError::from)?;
stream.send_data(bytes, true).map_err(HttpsError::from)?; stream.send_data(bytes, true).map_err(HttpsError::from)?;

View File

@@ -1,4 +1,4 @@
// Copyright 2015-2017 Benjamin Fry <benjaminfry@me.com> // Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
// //
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
@@ -7,11 +7,9 @@
//! Request Handler for incoming requests //! Request Handler for incoming requests
use std::future::Future;
use std::net::SocketAddr; use std::net::SocketAddr;
use crate::authority::MessageRequest; use crate::{authority::MessageRequest, server::ResponseHandler};
use crate::server::ResponseHandler;
/// An incoming request to the DNS catalog /// An incoming request to the DNS catalog
pub struct Request { pub struct Request {
@@ -22,19 +20,13 @@ pub struct Request {
} }
/// Trait for handling incoming requests, and providing a message response. /// Trait for handling incoming requests, and providing a message response.
pub trait RequestHandler: Send + Unpin + 'static { #[async_trait::async_trait]
/// A future for execution of the request pub trait RequestHandler: Send + Sync + Unpin + 'static {
type ResponseFuture: Future<Output = ()> + Send + Unpin + 'static;
/// Determines what needs to happen given the type of request, i.e. Query or Update. /// Determines what needs to happen given the type of request, i.e. Query or Update.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `request` - the requested action to perform. /// * `request` - the requested action to perform.
/// * `response_handle` - handle to which a return message should be sent /// * `response_handle` - handle to which a return message should be sent
fn handle_request<R: ResponseHandler>( async fn handle_request<R: ResponseHandler>(&self, request: Request, response_handle: R);
&self,
request: Request,
response_handle: R,
) -> Self::ResponseFuture;
} }

View File

@@ -16,14 +16,15 @@ use crate::proto::xfer::SerialMessage;
use crate::proto::{BufDnsStreamHandle, DnsStreamHandle}; use crate::proto::{BufDnsStreamHandle, DnsStreamHandle};
/// A handler for send a response to a client /// A handler for send a response to a client
pub trait ResponseHandler: Clone + Send + Unpin + 'static { #[async_trait::async_trait]
pub trait ResponseHandler: Clone + Send + Sync + Unpin + 'static {
// TODO: add associated error type // TODO: add associated error type
//type Error; //type Error;
/// Serializes and sends a message to to the wrapped handle /// Serializes and sends a message to to the wrapped handle
/// ///
/// self is consumed as only one message should ever be sent in response to a Request /// self is consumed as only one message should ever be sent in response to a Request
fn send_response(&mut self, response: MessageResponse<'_, '_>) -> io::Result<()>; async fn send_response(&mut self, response: MessageResponse<'_, '_>) -> io::Result<()>;
} }
/// A handler for wrapping a BufStreamHandle, which will properly serialize the message and add the /// A handler for wrapping a BufStreamHandle, which will properly serialize the message and add the
@@ -41,11 +42,12 @@ impl ResponseHandle {
} }
} }
#[async_trait::async_trait]
impl ResponseHandler for ResponseHandle { impl ResponseHandler for ResponseHandle {
/// Serializes and sends a message to to the wrapped handle /// Serializes and sends a message to to the wrapped handle
/// ///
/// self is consumed as only one message should ever be sent in response to a Request /// self is consumed as only one message should ever be sent in response to a Request
fn send_response(&mut self, response: MessageResponse<'_, '_>) -> io::Result<()> { async fn send_response(&mut self, response: MessageResponse<'_, '_>) -> io::Result<()> {
info!( info!(
"response: {} response_code: {}", "response: {} response_code: {}",
response.header().id(), response.header().id(),

View File

@@ -1,23 +1,16 @@
// Copyright 2015-2017 Benjamin Fry <benjaminfry@me.com> // Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
// //
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::future::Future; use std::{io, net::SocketAddr, sync::Arc, time::Duration};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;
use futures_util::{future, FutureExt, StreamExt}; use futures_util::{future, lock::Mutex, StreamExt};
use log::{debug, info, warn}; use log::{debug, info, warn};
#[cfg(feature = "dns-over-rustls")] #[cfg(feature = "dns-over-rustls")]
use rustls::{Certificate, PrivateKey}; use rustls::{Certificate, PrivateKey};
use tokio::net; use tokio::{net, task::JoinHandle};
use tokio::task::JoinHandle;
use crate::authority::MessageRequest; use crate::authority::MessageRequest;
use crate::proto::error::ProtoError; use crate::proto::error::ProtoError;
@@ -78,7 +71,7 @@ impl<T: RequestHandler> ServerFuture<T> {
let stream_handle = stream_handle.with_remote_addr(src_addr); let stream_handle = stream_handle.with_remote_addr(src_addr);
tokio::spawn(async move { tokio::spawn(async move {
self::handle_raw_request(message, handler, stream_handle).await; self::handle_raw_request(message, handler, stream_handle).await
}); });
} }
@@ -557,11 +550,11 @@ impl<T: RequestHandler> ServerFuture<T> {
} }
} }
pub(crate) fn handle_raw_request<T: RequestHandler>( pub(crate) async fn handle_raw_request<T: RequestHandler>(
message: SerialMessage, message: SerialMessage,
request_handler: Arc<Mutex<T>>, request_handler: Arc<Mutex<T>>,
response_handler: BufDnsStreamHandle, response_handler: BufDnsStreamHandle,
) -> HandleRawRequest<T::ResponseFuture> { ) {
let src_addr = message.addr(); let src_addr = message.addr();
let response_handler = ResponseHandle::new(message.addr(), response_handler); let response_handler = ResponseHandle::new(message.addr(), response_handler);
@@ -572,20 +565,19 @@ pub(crate) fn handle_raw_request<T: RequestHandler>(
let mut decoder = BinDecoder::new(message.bytes()); let mut decoder = BinDecoder::new(message.bytes());
match MessageRequest::read(&mut decoder) { match MessageRequest::read(&mut decoder) {
Ok(message) => { Ok(message) => {
let handle_request = self::handle_request(message, src_addr, request_handler, response_handler).await
self::handle_request(message, src_addr, request_handler, response_handler);
HandleRawRequest::HandleRequest(handle_request)
} }
Err(e) => HandleRawRequest::Result(e.into()), // FIXME: return the error and properly log it in handle_request?
Err(e) => warn!("failed to handle message: {}", e),
} }
} }
pub(crate) fn handle_request<R: ResponseHandler, T: RequestHandler>( pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
message: MessageRequest, message: MessageRequest,
src_addr: SocketAddr, src_addr: SocketAddr,
request_handler: Arc<Mutex<T>>, request_handler: Arc<Mutex<T>>,
response_handler: R, response_handler: R,
) -> T::ResponseFuture { ) {
let request = Request { let request = Request {
message, message,
src: src_addr, src: src_addr,
@@ -607,26 +599,7 @@ pub(crate) fn handle_request<R: ResponseHandler, T: RequestHandler>(
request_handler request_handler
.lock() .lock()
.expect("poisoned lock") .await
.handle_request(request, response_handler) .handle_request(request, response_handler)
} .await
#[must_use = "futures do nothing unless polled"]
pub(crate) enum HandleRawRequest<F: Future<Output = ()>> {
HandleRequest(F),
Result(io::Error),
}
impl<F: Future<Output = ()> + Unpin> Future for HandleRawRequest<F> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match *self {
HandleRawRequest::HandleRequest(ref mut f) => f.poll_unpin(cx),
HandleRawRequest::Result(ref res) => {
warn!("failed to handle message: {}", res);
Poll::Ready(())
}
}
}
} }

View File

@@ -1,4 +1,4 @@
// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com> // Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
// //
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
@@ -7,31 +7,33 @@
//! All authority related types //! All authority related types
use std::collections::BTreeMap; use std::{
use std::fs::File; collections::BTreeMap,
use std::future::Future; fs::File,
use std::io::{BufRead, BufReader}; io::{BufRead, BufReader},
use std::ops::{Deref, DerefMut}; ops::{Deref, DerefMut},
use std::path::{Path, PathBuf}; path::{Path, PathBuf},
use std::pin::Pin; };
use log::{debug, info}; use log::{debug, info};
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use crate::authority::DnssecAuthority; use crate::{
use crate::authority::{ authority::DnssecAuthority,
Authority, LookupError, LookupOptions, MessageRequest, UpdateResult, ZoneType, client::{
proto::rr::dnssec::rdata::key::KEY,
rr::dnssec::{DnsSecResult, SigSigner},
},
}; };
use crate::client::op::LowerQuery; use crate::{
use crate::client::rr::{LowerName, Name, RecordSet, RecordType, RrKey}; authority::{Authority, LookupError, LookupOptions, MessageRequest, UpdateResult, ZoneType},
use crate::client::serialize::txt::{Lexer, Parser, Token}; client::{
#[cfg(feature = "dnssec")] op::LowerQuery,
use crate::client::{ rr::{LowerName, Name, RecordSet, RecordType, RrKey},
proto::rr::dnssec::rdata::key::KEY, serialize::txt::{Lexer, Parser, Token},
rr::dnssec::{DnsSecResult, SigSigner}, },
store::{file::FileConfig, in_memory::InMemoryAuthority},
}; };
use crate::store::file::FileConfig;
use crate::store::in_memory::InMemoryAuthority;
/// FileAuthority is responsible for storing the resource records for a particular zone. /// FileAuthority is responsible for storing the resource records for a particular zone.
/// ///
@@ -232,9 +234,9 @@ impl DerefMut for FileAuthority {
} }
} }
#[async_trait::async_trait]
impl Authority for FileAuthority { impl Authority for FileAuthority {
type Lookup = <InMemoryAuthority as Authority>::Lookup; type Lookup = <InMemoryAuthority as Authority>::Lookup;
type LookupFuture = <InMemoryAuthority as Authority>::LookupFuture;
/// What type is this zone /// What type is this zone
fn zone_type(&self) -> ZoneType { fn zone_type(&self) -> ZoneType {
@@ -247,7 +249,7 @@ impl Authority for FileAuthority {
} }
/// Perform a dynamic update of a zone /// Perform a dynamic update of a zone
fn update(&mut self, _update: &MessageRequest) -> UpdateResult<bool> { async fn update(&mut self, _update: &MessageRequest) -> UpdateResult<bool> {
use crate::proto::op::ResponseCode; use crate::proto::op::ResponseCode;
Err(ResponseCode::NotImp) Err(ResponseCode::NotImp)
} }
@@ -271,13 +273,13 @@ impl Authority for FileAuthority {
/// # Return value /// # Return value
/// ///
/// None if there are no matching records, otherwise a `Vec` containing the found records. /// None if there are no matching records, otherwise a `Vec` containing the found records.
fn lookup( async fn lookup(
&self, &self,
name: &LowerName, name: &LowerName,
rtype: RecordType, rtype: RecordType,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
Box::pin(self.0.lookup(name, rtype, lookup_options)) self.0.lookup(name, rtype, lookup_options).await
} }
/// Using the specified query, perform a lookup against this zone. /// Using the specified query, perform a lookup against this zone.
@@ -291,20 +293,17 @@ impl Authority for FileAuthority {
/// ///
/// Returns a vectory containing the results of the query, it will be empty if not found. If /// Returns a vectory containing the results of the query, it will be empty if not found. If
/// `is_secure` is true, in the case of no records found then NSEC records will be returned. /// `is_secure` is true, in the case of no records found then NSEC records will be returned.
fn search( async fn search(
&self, &self,
query: &LowerQuery, query: &LowerQuery,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
Box::pin(self.0.search(query, lookup_options)) self.0.search(query, lookup_options).await
} }
/// Get the NS, NameServer, record for the zone /// Get the NS, NameServer, record for the zone
fn ns( async fn ns(&self, lookup_options: LookupOptions) -> Result<Self::Lookup, LookupError> {
&self, self.0.ns(lookup_options).await
lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> {
self.0.ns(lookup_options)
} }
/// Return the NSEC records based on the given name /// Return the NSEC records based on the given name
@@ -314,28 +313,25 @@ impl Authority for FileAuthority {
/// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than /// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than
/// this /// this
/// * `is_secure` - if true then it will return RRSIG records as well /// * `is_secure` - if true then it will return RRSIG records as well
fn get_nsec_records( async fn get_nsec_records(
&self, &self,
name: &LowerName, name: &LowerName,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
self.0.get_nsec_records(name, lookup_options) self.0.get_nsec_records(name, lookup_options).await
} }
/// Returns the SOA of the authority. /// Returns the SOA of the authority.
/// ///
/// *Note*: This will only return the SOA, if this is fulfilling a request, a standard lookup /// *Note*: This will only return the SOA, if this is fulfilling a request, a standard lookup
/// should be used, see `soa_secure()`, which will optionally return RRSIGs. /// should be used, see `soa_secure()`, which will optionally return RRSIGs.
fn soa(&self) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { async fn soa(&self) -> Result<Self::Lookup, LookupError> {
self.0.soa() self.0.soa().await
} }
/// Returns the SOA record for the zone /// Returns the SOA record for the zone
fn soa_secure( async fn soa_secure(&self, lookup_options: LookupOptions) -> Result<Self::Lookup, LookupError> {
&self, self.0.soa_secure(lookup_options).await
lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> {
self.0.soa_secure(lookup_options)
} }
} }

View File

@@ -1,30 +1,27 @@
// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com> // Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
// //
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use std::future::Future;
use std::io; use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_util::{future, FutureExt};
use log::info; use log::info;
use crate::client::op::LowerQuery; use crate::{
use crate::client::op::ResponseCode; authority::{
use crate::client::rr::{LowerName, Name, Record, RecordType}; Authority, LookupError, LookupObject, LookupOptions, MessageRequest, UpdateResult, ZoneType,
use crate::resolver::config::ResolverConfig; },
use crate::resolver::error::ResolveError; client::{
use crate::resolver::lookup::Lookup as ResolverLookup; op::{LowerQuery, ResponseCode},
use crate::resolver::{TokioAsyncResolver, TokioHandle}; rr::{LowerName, Name, Record, RecordType},
},
use crate::authority::{ resolver::{
Authority, LookupError, LookupObject, LookupOptions, MessageRequest, UpdateResult, ZoneType, config::ResolverConfig, lookup::Lookup as ResolverLookup, TokioAsyncResolver, TokioHandle,
},
store::forwarder::ForwardConfig,
}; };
use crate::store::forwarder::ForwardConfig;
/// An authority that will forward resolutions to upstream resolvers. /// An authority that will forward resolutions to upstream resolvers.
/// ///
@@ -73,9 +70,9 @@ impl ForwardAuthority {
} }
} }
#[async_trait::async_trait]
impl Authority for ForwardAuthority { impl Authority for ForwardAuthority {
type Lookup = ForwardLookup; type Lookup = ForwardLookup;
type LookupFuture = Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>>;
/// Always Forward /// Always Forward
fn zone_type(&self) -> ZoneType { fn zone_type(&self) -> ZoneType {
@@ -87,7 +84,7 @@ impl Authority for ForwardAuthority {
false false
} }
fn update(&mut self, _update: &MessageRequest) -> UpdateResult<bool> { async fn update(&mut self, _update: &MessageRequest) -> UpdateResult<bool> {
Err(ResponseCode::NotImp) Err(ResponseCode::NotImp)
} }
@@ -101,41 +98,40 @@ impl Authority for ForwardAuthority {
} }
/// Forwards a lookup given the resolver configuration for this Forwarded zone /// Forwards a lookup given the resolver configuration for this Forwarded zone
fn lookup( async fn lookup(
&self, &self,
name: &LowerName, name: &LowerName,
rtype: RecordType, rtype: RecordType,
_lookup_options: LookupOptions, _lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
// TODO: make this an error? // TODO: make this an error?
assert!(self.origin.zone_of(name)); assert!(self.origin.zone_of(name));
info!("forwarding lookup: {} {}", name, rtype); info!("forwarding lookup: {} {}", name, rtype);
let name: LowerName = name.clone(); let name: LowerName = name.clone();
Box::pin(ForwardLookupFuture(self.resolver.lookup( let resolve = self.resolver.lookup(name, rtype, Default::default()).await;
name,
rtype, resolve.map(ForwardLookup).map_err(LookupError::from)
Default::default(),
)))
} }
fn search( async fn search(
&self, &self,
query: &LowerQuery, query: &LowerQuery,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
Box::pin(self.lookup(query.name(), query.query_type(), lookup_options)) self.lookup(query.name(), query.query_type(), lookup_options)
.await
} }
fn get_nsec_records( async fn get_nsec_records(
&self, &self,
_name: &LowerName, _name: &LowerName,
_lookup_options: LookupOptions, _lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
Box::pin(future::err(LookupError::from(io::Error::new( Err(LookupError::from(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"Getting NSEC records is unimplemented for the forwarder", "Getting NSEC records is unimplemented for the forwarder",
)))) )))
} }
} }
@@ -154,21 +150,3 @@ impl LookupObject for ForwardLookup {
None None
} }
} }
pub(crate) struct ForwardLookupFuture<
F: Future<Output = Result<ResolverLookup, ResolveError>> + Send + Unpin + 'static,
>(F);
impl<F: Future<Output = Result<ResolverLookup, ResolveError>> + Send + Unpin> Future
for ForwardLookupFuture<F>
{
type Output = Result<ForwardLookup, LookupError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.0.poll_unpin(cx) {
Poll::Ready(Ok(f)) => Poll::Ready(Ok(ForwardLookup(f))),
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
}
}
}

View File

@@ -1,4 +1,4 @@
// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com> // Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
// //
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
@@ -7,11 +7,7 @@
//! All authority related types //! All authority related types
use std::borrow::Borrow; use std::{borrow::Borrow, collections::BTreeMap, sync::Arc};
use std::collections::BTreeMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use cfg_if::cfg_if; use cfg_if::cfg_if;
use futures_util::future::{self, TryFutureExt}; use futures_util::future::{self, TryFutureExt};
@@ -709,9 +705,9 @@ fn maybe_next_name(
} }
} }
#[async_trait::async_trait]
impl Authority for InMemoryAuthority { impl Authority for InMemoryAuthority {
type Lookup = AuthLookup; type Lookup = AuthLookup;
type LookupFuture = future::Ready<Result<Self::Lookup, LookupError>>;
/// What type is this zone /// What type is this zone
fn zone_type(&self) -> ZoneType { fn zone_type(&self) -> ZoneType {
@@ -780,7 +776,7 @@ impl Authority for InMemoryAuthority {
/// ///
/// true if any of additions, updates or deletes were made to the zone, false otherwise. Err is /// true if any of additions, updates or deletes were made to the zone, false otherwise. Err is
/// returned in the case of bad data, etc. /// returned in the case of bad data, etc.
fn update(&mut self, _update: &MessageRequest) -> UpdateResult<bool> { async fn update(&mut self, _update: &MessageRequest) -> UpdateResult<bool> {
Err(ResponseCode::NotImp) Err(ResponseCode::NotImp)
} }
@@ -803,12 +799,12 @@ impl Authority for InMemoryAuthority {
/// # Return value /// # Return value
/// ///
/// None if there are no matching records, otherwise a `Vec` containing the found records. /// None if there are no matching records, otherwise a `Vec` containing the found records.
fn lookup( async fn lookup(
&self, &self,
name: &LowerName, name: &LowerName,
query_type: RecordType, query_type: RecordType,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
// Collect the records from each rr_set // Collect the records from each rr_set
let (result, additionals): (LookupResult<LookupRecords>, Option<LookupRecords>) = let (result, additionals): (LookupResult<LookupRecords>, Option<LookupRecords>) =
match query_type { match query_type {
@@ -944,30 +940,28 @@ impl Authority for InMemoryAuthority {
.keys() .keys()
.any(|key| key.name() == name || name.zone_of(key.name())) .any(|key| key.name() == name || name.zone_of(key.name()))
{ {
return Box::pin(future::err(LookupError::NameExists)); return Err(LookupError::NameExists);
} else { } else {
let code = if self.origin().zone_of(name) { let code = if self.origin().zone_of(name) {
ResponseCode::NXDomain ResponseCode::NXDomain
} else { } else {
ResponseCode::Refused ResponseCode::Refused
}; };
return Box::pin(future::err(LookupError::from(code))); return Err(LookupError::from(code));
} }
} }
Err(e) => return Box::pin(future::err(e)), Err(e) => return Err(e),
o => o, o => o,
}; };
Box::pin(future::ready( result.map(|answers| AuthLookup::answers(answers, additionals))
result.map(|answers| AuthLookup::answers(answers, additionals)),
))
} }
fn search( async fn search(
&self, &self,
query: &LowerQuery, query: &LowerQuery,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
debug!("searching InMemoryAuthority for: {}", query); debug!("searching InMemoryAuthority for: {}", query);
let lookup_name = query.name(); let lookup_name = query.name();
@@ -978,20 +972,23 @@ impl Authority for InMemoryAuthority {
if RecordType::AXFR == record_type { if RecordType::AXFR == record_type {
// TODO: support more advanced AXFR options // TODO: support more advanced AXFR options
if !self.is_axfr_allowed() { if !self.is_axfr_allowed() {
return Box::pin(future::err(LookupError::from(ResponseCode::Refused))); return Err(LookupError::from(ResponseCode::Refused));
} }
#[allow(deprecated)] #[allow(deprecated)]
match self.zone_type() { match self.zone_type() {
ZoneType::Primary | ZoneType::Secondary | ZoneType::Master | ZoneType::Slave => (), ZoneType::Primary | ZoneType::Secondary | ZoneType::Master | ZoneType::Slave => (),
// TODO: Forward? // TODO: Forward?
_ => return Box::pin(future::err(LookupError::from(ResponseCode::NXDomain))), _ => return Err(LookupError::from(ResponseCode::NXDomain)),
} }
} }
// perform the actual lookup // perform the actual lookup
match record_type { match record_type {
RecordType::SOA => Box::pin(self.lookup(self.origin(), record_type, lookup_options)), RecordType::SOA => {
self.lookup(self.origin(), record_type, lookup_options)
.await
}
RecordType::AXFR => { RecordType::AXFR => {
// TODO: shouldn't these SOA's be secure? at least the first, perhaps not the last? // TODO: shouldn't these SOA's be secure? at least the first, perhaps not the last?
let lookup = future::try_join3( let lookup = future::try_join3(
@@ -1009,10 +1006,10 @@ impl Authority for InMemoryAuthority {
}, },
}); });
Box::pin(lookup) lookup.await
} }
// A standard Lookup path // A standard Lookup path
_ => Box::pin(self.lookup(lookup_name, record_type, lookup_options)), _ => self.lookup(lookup_name, record_type, lookup_options).await,
} }
} }
@@ -1024,11 +1021,11 @@ impl Authority for InMemoryAuthority {
/// this /// this
/// * `is_secure` - if true then it will return RRSIG records as well /// * `is_secure` - if true then it will return RRSIG records as well
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
fn get_nsec_records( async fn get_nsec_records(
&self, &self,
name: &LowerName, name: &LowerName,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
fn is_nsec_rrset(rr_set: &RecordSet) -> bool { fn is_nsec_rrset(rr_set: &RecordSet) -> bool {
rr_set.record_type() == RecordType::NSEC rr_set.record_type() == RecordType::NSEC
} }
@@ -1041,7 +1038,7 @@ impl Authority for InMemoryAuthority {
.map(|rr_set| LookupRecords::new(lookup_options, rr_set.clone())); .map(|rr_set| LookupRecords::new(lookup_options, rr_set.clone()));
if let Some(no_data) = no_data { if let Some(no_data) = no_data {
return Box::pin(future::ready(Ok(no_data.into()))); return Ok(no_data.into());
} }
let get_closest_nsec = |name: &LowerName| -> Option<Arc<RecordSet>> { let get_closest_nsec = |name: &LowerName| -> Option<Arc<RecordSet>> {
@@ -1099,20 +1096,16 @@ impl Authority for InMemoryAuthority {
(None, None) => vec![], (None, None) => vec![],
}; };
Box::pin(future::ready(Ok(LookupRecords::many( Ok(LookupRecords::many(lookup_options, proofs).into())
lookup_options,
proofs,
)
.into())))
} }
#[cfg(not(feature = "dnssec"))] #[cfg(not(feature = "dnssec"))]
fn get_nsec_records( async fn get_nsec_records(
&self, &self,
_name: &LowerName, _name: &LowerName,
_lookup_options: LookupOptions, _lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
Box::pin(future::ok(AuthLookup::default())) Ok(AuthLookup::default())
} }
} }

View File

@@ -1,4 +1,4 @@
// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com> // Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
// //
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
@@ -7,24 +7,30 @@
//! All authority related types //! All authority related types
use std::future::Future; use std::{
use std::ops::{Deref, DerefMut}; ops::{Deref, DerefMut},
use std::path::{Path, PathBuf}; path::{Path, PathBuf},
use std::pin::Pin; sync::Arc,
use std::sync::Arc; };
use log::{error, info, warn}; use log::{error, info, warn};
use crate::authority::{ use crate::{
Authority, LookupError, LookupOptions, MessageRequest, UpdateResult, ZoneType, authority::{Authority, LookupError, LookupOptions, MessageRequest, UpdateResult, ZoneType},
client::{
op::LowerQuery,
rr::{LowerName, RrKey},
},
error::{PersistenceErrorKind, PersistenceResult},
proto::{
op::ResponseCode,
rr::{DNSClass, Name, RData, Record, RecordSet, RecordType},
},
store::{
in_memory::InMemoryAuthority,
sqlite::{Journal, SqliteConfig},
},
}; };
use crate::client::op::LowerQuery;
use crate::client::rr::{LowerName, RrKey};
use crate::error::{PersistenceErrorKind, PersistenceResult};
use crate::proto::op::ResponseCode;
use crate::proto::rr::{DNSClass, Name, RData, Record, RecordSet, RecordType};
use crate::store::in_memory::InMemoryAuthority;
use crate::store::sqlite::{Journal, SqliteConfig};
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use crate::{ use crate::{
authority::{DnssecAuthority, UpdateRequest}, authority::{DnssecAuthority, UpdateRequest},
@@ -258,9 +264,7 @@ impl SqliteAuthority {
/// NONE rrset empty RRset does not exist /// NONE rrset empty RRset does not exist
/// zone rrset rr RRset exists (value dependent) /// zone rrset rr RRset exists (value dependent)
/// ``` /// ```
pub fn verify_prerequisites(&self, pre_requisites: &[Record]) -> UpdateResult<()> { pub async fn verify_prerequisites(&self, pre_requisites: &[Record]) -> UpdateResult<()> {
use futures_executor::block_on;
// 3.2.5 - Pseudocode for Prerequisite Section Processing // 3.2.5 - Pseudocode for Prerequisite Section Processing
// //
// for rr in prerequisites // for rr in prerequisites
@@ -313,14 +317,15 @@ impl SqliteAuthority {
match require.rr_type() { match require.rr_type() {
// ANY ANY empty Name is in use // ANY ANY empty Name is in use
RecordType::ANY => { RecordType::ANY => {
/*TODO: this works because the future here is always complete*/ if self
if block_on(self.lookup( .lookup(
&required_name, &required_name,
RecordType::ANY, RecordType::ANY,
LookupOptions::default(), LookupOptions::default(),
)) )
.unwrap_or_default() .await
.was_empty() .unwrap_or_default()
.was_empty()
{ {
return Err(ResponseCode::NXDomain); return Err(ResponseCode::NXDomain);
} else { } else {
@@ -329,14 +334,11 @@ impl SqliteAuthority {
} }
// ANY rrset empty RRset exists (value independent) // ANY rrset empty RRset exists (value independent)
rrset => { rrset => {
/*TODO: this works because the future here is always complete*/ if self
if block_on(self.lookup( .lookup(&required_name, rrset, LookupOptions::default())
&required_name, .await
rrset, .unwrap_or_default()
LookupOptions::default(), .was_empty()
))
.unwrap_or_default()
.was_empty()
{ {
return Err(ResponseCode::NXRRSet); return Err(ResponseCode::NXRRSet);
} else { } else {
@@ -353,14 +355,15 @@ impl SqliteAuthority {
match require.rr_type() { match require.rr_type() {
// NONE ANY empty Name is not in use // NONE ANY empty Name is not in use
RecordType::ANY => { RecordType::ANY => {
/*TODO: this works because the future here is always complete*/ if !self
if !block_on(self.lookup( .lookup(
&required_name, &required_name,
RecordType::ANY, RecordType::ANY,
LookupOptions::default(), LookupOptions::default(),
)) )
.unwrap_or_default() .await
.was_empty() .unwrap_or_default()
.was_empty()
{ {
return Err(ResponseCode::YXDomain); return Err(ResponseCode::YXDomain);
} else { } else {
@@ -369,14 +372,11 @@ impl SqliteAuthority {
} }
// NONE rrset empty RRset does not exist // NONE rrset empty RRset does not exist
rrset => { rrset => {
/*TODO: this works because the future here is always complete*/ if !self
if !block_on(self.lookup( .lookup(&required_name, rrset, LookupOptions::default())
&required_name, .await
rrset, .unwrap_or_default()
LookupOptions::default(), .was_empty()
))
.unwrap_or_default()
.was_empty()
{ {
return Err(ResponseCode::YXRRSet); return Err(ResponseCode::YXRRSet);
} else { } else {
@@ -391,15 +391,12 @@ impl SqliteAuthority {
class if class == self.class() => class if class == self.class() =>
// zone rrset rr RRset exists (value dependent) // zone rrset rr RRset exists (value dependent)
{ {
/*TODO: this works because the future here is always complete*/ if !self
if !block_on(self.lookup( .lookup(&required_name, require.rr_type(), LookupOptions::default())
&required_name, .await
require.rr_type(), .unwrap_or_default()
LookupOptions::default(), .iter()
)) .any(|rr| rr == require)
.unwrap_or_default()
.iter()
.any(|rr| rr == require)
{ {
return Err(ResponseCode::NXRRSet); return Err(ResponseCode::NXRRSet);
} else { } else {
@@ -440,8 +437,7 @@ impl SqliteAuthority {
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
#[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))]
#[allow(clippy::blocks_in_if_conditions)] #[allow(clippy::blocks_in_if_conditions)]
pub fn authorize(&self, update_message: &MessageRequest) -> UpdateResult<()> { pub async fn authorize(&self, update_message: &MessageRequest) -> UpdateResult<()> {
use futures_executor::block_on;
use log::debug; use log::debug;
use crate::client::rr::rdata::DNSSECRData; use crate::client::rr::rdata::DNSSECRData;
@@ -468,51 +464,56 @@ impl SqliteAuthority {
// verify sig0, currently the only authorization that is accepted. // verify sig0, currently the only authorization that is accepted.
let sig0s: &[Record] = update_message.sig0(); let sig0s: &[Record] = update_message.sig0();
debug!("authorizing with: {:?}", sig0s); debug!("authorizing with: {:?}", sig0s);
if !sig0s.is_empty() if !sig0s.is_empty() {
&& sig0s let mut found_key = false;
.iter() for sig in sig0s.iter().filter_map(|sig0| {
.filter_map(|sig0| { if let RData::DNSSEC(DNSSECRData::SIG(ref sig)) = *sig0.rdata() {
if let RData::DNSSEC(DNSSECRData::SIG(ref sig)) = *sig0.rdata() { Some(sig)
Some(sig) } else {
} else { None
None }
} }) {
}) let name = LowerName::from(sig.signer_name());
.any(|sig| { let keys = self
let name = LowerName::from(sig.signer_name()); .lookup(&name, RecordType::KEY, LookupOptions::default())
// TODO: updates should be async as well. .await;
let keys =
block_on(self.lookup(&name, RecordType::KEY, LookupOptions::default()));
let keys = match keys { let keys = match keys {
Ok(keys) => keys, Ok(keys) => keys,
Err(_) => return false, Err(_) => continue, // error trying to lookup a key by that name, try the next one.
}; };
debug!("found keys {:?}", keys); debug!("found keys {:?}", keys);
// TODO: check key usage flags and restrictions // TODO: check key usage flags and restrictions
keys.iter() found_key = keys
.filter_map(|rr_set| { .iter()
if let RData::DNSSEC(DNSSECRData::KEY(ref key)) = *rr_set.rdata() { .filter_map(|rr_set| {
Some(key) if let RData::DNSSEC(DNSSECRData::KEY(ref key)) = *rr_set.rdata() {
} else { Some(key)
None } else {
} None
}) }
.any(|key| { })
key.verify_message(update_message, sig.sig(), sig) .any(|key| {
.map(|_| { key.verify_message(update_message, sig.sig(), sig)
info!("verified sig: {:?} with key: {:?}", sig, key); .map(|_| {
true info!("verified sig: {:?} with key: {:?}", sig, key);
}) true
.unwrap_or_else(|_| { })
debug!("did not verify sig: {:?} with key: {:?}", sig, key); .unwrap_or_else(|_| {
false debug!("did not verify sig: {:?} with key: {:?}", sig, key);
}) false
}) })
}) });
{
return Ok(()); if found_key {
break; // stop searching for matching keys, we found one
}
}
if found_key {
return Ok(());
}
} else { } else {
warn!( warn!(
"no sig0 matched registered records: id {}", "no sig0 matched registered records: id {}",
@@ -825,9 +826,9 @@ impl DerefMut for SqliteAuthority {
} }
} }
#[async_trait::async_trait]
impl Authority for SqliteAuthority { impl Authority for SqliteAuthority {
type Lookup = <InMemoryAuthority as Authority>::Lookup; type Lookup = <InMemoryAuthority as Authority>::Lookup;
type LookupFuture = <InMemoryAuthority as Authority>::LookupFuture;
/// What type is this zone /// What type is this zone
fn zone_type(&self) -> ZoneType { fn zone_type(&self) -> ZoneType {
@@ -897,10 +898,10 @@ impl Authority for SqliteAuthority {
/// true if any of additions, updates or deletes were made to the zone, false otherwise. Err is /// true if any of additions, updates or deletes were made to the zone, false otherwise. Err is
/// returned in the case of bad data, etc. /// returned in the case of bad data, etc.
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
fn update(&mut self, update: &MessageRequest) -> UpdateResult<bool> { async fn update(&mut self, update: &MessageRequest) -> UpdateResult<bool> {
// the spec says to authorize after prereqs, seems better to auth first. // the spec says to authorize after prereqs, seems better to auth first.
self.authorize(update)?; self.authorize(update).await?;
self.verify_prerequisites(update.prerequisites())?; self.verify_prerequisites(update.prerequisites()).await?;
self.pre_scan(update.updates())?; self.pre_scan(update.updates())?;
self.update_records(update.updates(), true) self.update_records(update.updates(), true)
@@ -908,7 +909,7 @@ impl Authority for SqliteAuthority {
/// Always fail when DNSSEC is disabled. /// Always fail when DNSSEC is disabled.
#[cfg(not(feature = "dnssec"))] #[cfg(not(feature = "dnssec"))]
fn update(&mut self, _update: &MessageRequest) -> UpdateResult<bool> { async fn update(&mut self, _update: &MessageRequest) -> UpdateResult<bool> {
Err(ResponseCode::NotImp) Err(ResponseCode::NotImp)
} }
@@ -931,21 +932,21 @@ impl Authority for SqliteAuthority {
/// # Return value /// # Return value
/// ///
/// None if there are no matching records, otherwise a `Vec` containing the found records. /// None if there are no matching records, otherwise a `Vec` containing the found records.
fn lookup( async fn lookup(
&self, &self,
name: &LowerName, name: &LowerName,
rtype: RecordType, rtype: RecordType,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
self.in_memory.lookup(name, rtype, lookup_options) self.in_memory.lookup(name, rtype, lookup_options).await
} }
fn search( async fn search(
&self, &self,
query: &LowerQuery, query: &LowerQuery,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
self.in_memory.search(query, lookup_options) self.in_memory.search(query, lookup_options).await
} }
/// Return the NSEC records based on the given name /// Return the NSEC records based on the given name
@@ -955,12 +956,12 @@ impl Authority for SqliteAuthority {
/// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than /// * `name` - given this name (i.e. the lookup name), return the NSEC record that is less than
/// this /// this
/// * `is_secure` - if true then it will return RRSIG records as well /// * `is_secure` - if true then it will return RRSIG records as well
fn get_nsec_records( async fn get_nsec_records(
&self, &self,
name: &LowerName, name: &LowerName,
lookup_options: LookupOptions, lookup_options: LookupOptions,
) -> Pin<Box<dyn Future<Output = Result<Self::Lookup, LookupError>> + Send>> { ) -> Result<Self::Lookup, LookupError> {
self.in_memory.get_nsec_records(name, lookup_options) self.in_memory.get_nsec_records(name, lookup_options).await
} }
} }

View File

@@ -391,7 +391,7 @@ pub fn test_update_errors<A: Authority<Lookup = AuthLookup>>(mut authority: A) {
let update = MessageRequest::from_bytes(&bytes).unwrap(); let update = MessageRequest::from_bytes(&bytes).unwrap();
// this is expected to fail, i.e. updates are not allowed // this is expected to fail, i.e. updates are not allowed
assert!(authority.update(&update).is_err()); assert!(block_on(authority.update(&update)).is_err());
} }
pub fn test_dots_in_name<A: Authority<Lookup = AuthLookup>>(authority: A) { pub fn test_dots_in_name<A: Authority<Lookup = AuthLookup>>(authority: A) {

View File

@@ -1,16 +1,19 @@
#![cfg(feature = "dnssec")] #![cfg(feature = "dnssec")]
use std::future::Future; use std::{
use std::net::{Ipv4Addr, Ipv6Addr}; future::Future,
use std::str::FromStr; net::{Ipv4Addr, Ipv6Addr},
str::FromStr,
};
use futures_executor::block_on; use futures_executor::block_on;
use trust_dns_client::op::update_message; use trust_dns_client::{
use trust_dns_client::op::{Message, Query, ResponseCode}; op::{update_message, Message, Query, ResponseCode},
use trust_dns_client::proto::rr::{DNSClass, Name, RData, Record, RecordSet, RecordType}; proto::rr::{DNSClass, Name, RData, Record, RecordSet, RecordType},
use trust_dns_client::rr::dnssec::{Algorithm, SigSigner, SupportedAlgorithms, Verifier}; rr::dnssec::{Algorithm, SigSigner, SupportedAlgorithms, Verifier},
use trust_dns_client::serialize::binary::{BinDecodable, BinEncodable, BinSerializable}; serialize::binary::{BinDecodable, BinEncodable, BinSerializable},
};
use trust_dns_server::authority::{ use trust_dns_server::authority::{
AuthLookup, Authority, DnssecAuthority, LookupError, LookupOptions, MessageRequest, AuthLookup, Authority, DnssecAuthority, LookupError, LookupOptions, MessageRequest,
UpdateResult, UpdateResult,
@@ -25,7 +28,7 @@ fn update_authority<A: Authority<Lookup = AuthLookup>>(
let message = message.to_bytes().unwrap(); let message = message.to_bytes().unwrap();
let request = MessageRequest::from_bytes(&message).unwrap(); let request = MessageRequest::from_bytes(&message).unwrap();
authority.update(&request) block_on(authority.update(&request))
} }
pub fn test_create<A: Authority<Lookup = AuthLookup>>(mut authority: A, keys: &[SigSigner]) { pub fn test_create<A: Authority<Lookup = AuthLookup>>(mut authority: A, keys: &[SigSigner]) {

View File

@@ -69,6 +69,7 @@ dns-over-tls = []
sqlite = ["trust-dns-server/sqlite"] sqlite = ["trust-dns-server/sqlite"]
[dependencies] [dependencies]
async-trait = "0.1.42"
chrono = "0.4" chrono = "0.4"
env_logger = "0.9" env_logger = "0.9"
lazy_static = "1.0" lazy_static = "1.0"

View File

@@ -96,8 +96,9 @@ impl TestResponseHandler {
} }
} }
#[async_trait::async_trait]
impl ResponseHandler for TestResponseHandler { impl ResponseHandler for TestResponseHandler {
fn send_response(&mut self, response: MessageResponse) -> io::Result<()> { async fn send_response(&mut self, response: MessageResponse<'_, '_>) -> io::Result<()> {
let buf = &mut self.buf.lock().unwrap(); let buf = &mut self.buf.lock().unwrap();
buf.clear(); buf.clear();
let mut encoder = BinEncoder::new(buf); let mut encoder = BinEncoder::new(buf);

View File

@@ -1,9 +1,10 @@
use std::net::*; use std::net::*;
use std::str::FromStr; use std::str::FromStr;
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use futures::executor::block_on; use futures::executor::block_on;
use futures::lock::Mutex;
use trust_dns_client::op::*; use trust_dns_client::op::*;
use trust_dns_client::rr::rdata::*; use trust_dns_client::rr::rdata::*;
use trust_dns_client::rr::*; use trust_dns_client::rr::*;
@@ -119,8 +120,8 @@ fn test_catalog_lookup() {
let test_origin = test.origin().clone(); let test_origin = test.origin().clone();
let mut catalog: Catalog = Catalog::new(); let mut catalog: Catalog = Catalog::new();
catalog.upsert(origin.clone(), Box::new(Arc::new(RwLock::new(example)))); catalog.upsert(origin.clone(), Box::new(Arc::new(Mutex::new(example))));
catalog.upsert(test_origin.clone(), Box::new(Arc::new(RwLock::new(test)))); catalog.upsert(test_origin.clone(), Box::new(Arc::new(Mutex::new(test))));
let mut question: Message = Message::new(); let mut question: Message = Message::new();
@@ -190,8 +191,8 @@ fn test_catalog_lookup_soa() {
let test_origin = test.origin().clone(); let test_origin = test.origin().clone();
let mut catalog: Catalog = Catalog::new(); let mut catalog: Catalog = Catalog::new();
catalog.upsert(origin.clone(), Box::new(Arc::new(RwLock::new(example)))); catalog.upsert(origin.clone(), Box::new(Arc::new(Mutex::new(example))));
catalog.upsert(test_origin, Box::new(Arc::new(RwLock::new(test)))); catalog.upsert(test_origin, Box::new(Arc::new(Mutex::new(test))));
let mut question: Message = Message::new(); let mut question: Message = Message::new();
@@ -254,7 +255,7 @@ fn test_catalog_nx_soa() {
let origin = example.origin().clone(); let origin = example.origin().clone();
let mut catalog: Catalog = Catalog::new(); let mut catalog: Catalog = Catalog::new();
catalog.upsert(origin, Box::new(Arc::new(RwLock::new(example)))); catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example))));
let mut question: Message = Message::new(); let mut question: Message = Message::new();
@@ -299,7 +300,7 @@ fn test_non_authoritive_nx_refused() {
let origin = example.origin().clone(); let origin = example.origin().clone();
let mut catalog: Catalog = Catalog::new(); let mut catalog: Catalog = Catalog::new();
catalog.upsert(origin, Box::new(Arc::new(RwLock::new(example)))); catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example))));
let mut question: Message = Message::new(); let mut question: Message = Message::new();
@@ -350,7 +351,7 @@ fn test_axfr() {
.clone(); .clone();
let mut catalog: Catalog = Catalog::new(); let mut catalog: Catalog = Catalog::new();
catalog.upsert(origin.clone(), Box::new(Arc::new(RwLock::new(test)))); catalog.upsert(origin.clone(), Box::new(Arc::new(Mutex::new(test))));
let mut query: Query = Query::new(); let mut query: Query = Query::new();
query.set_name(origin.clone().into()); query.set_name(origin.clone().into());
@@ -467,7 +468,7 @@ fn test_axfr_refused() {
let origin = test.origin().clone(); let origin = test.origin().clone();
let mut catalog: Catalog = Catalog::new(); let mut catalog: Catalog = Catalog::new();
catalog.upsert(origin.clone(), Box::new(Arc::new(RwLock::new(test)))); catalog.upsert(origin.clone(), Box::new(Arc::new(Mutex::new(test))));
let mut query: Query = Query::new(); let mut query: Query = Query::new();
query.set_name(origin.into()); query.set_name(origin.into());
@@ -503,7 +504,7 @@ fn test_cname_additionals() {
let origin = example.origin().clone(); let origin = example.origin().clone();
let mut catalog: Catalog = Catalog::new(); let mut catalog: Catalog = Catalog::new();
catalog.upsert(origin, Box::new(Arc::new(RwLock::new(example)))); catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example))));
let mut question: Message = Message::new(); let mut question: Message = Message::new();
@@ -547,7 +548,7 @@ fn test_multiple_cname_additionals() {
let origin = example.origin().clone(); let origin = example.origin().clone();
let mut catalog: Catalog = Catalog::new(); let mut catalog: Catalog = Catalog::new();
catalog.upsert(origin, Box::new(Arc::new(RwLock::new(example)))); catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example))));
let mut question: Message = Message::new(); let mut question: Message = Message::new();

View File

@@ -1,41 +1,43 @@
use std::net::*; use std::{
use std::str::FromStr; net::*,
use std::sync::{Arc, Mutex, RwLock}; str::FromStr,
sync::{Arc, Mutex as StdMutex},
};
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use chrono::Duration; use chrono::Duration;
use futures::{Future, FutureExt, TryFutureExt}; use futures::{lock::Mutex, Future, FutureExt, TryFutureExt};
use tokio::net::TcpStream as TokioTcpStream; use tokio::{
use tokio::net::UdpSocket as TokioUdpSocket; net::{TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket},
use tokio::runtime::Runtime; runtime::Runtime,
};
#[cfg(all(feature = "dnssec", feature = "sqlite"))] #[cfg(all(feature = "dnssec", feature = "sqlite"))]
use trust_dns_client::client::Signer; use trust_dns_client::client::Signer;
use trust_dns_client::op::{Message, MessageType, OpCode, Query, ResponseCode};
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use trust_dns_client::rr::dnssec::SigSigner; use trust_dns_client::rr::{dnssec::SigSigner, Record};
#[cfg(feature = "dnssec")]
use trust_dns_client::rr::Record;
use trust_dns_client::rr::{DNSClass, Name, RData, RecordSet, RecordType};
use trust_dns_client::tcp::TcpClientStream;
use trust_dns_client::udp::UdpClientStream;
use trust_dns_client::{ use trust_dns_client::{
client::{AsyncClient, ClientHandle}, client::{AsyncClient, ClientHandle},
rr::rdata::opt::EdnsOption, error::ClientErrorKind,
op::{Edns, Message, MessageType, OpCode, Query, ResponseCode},
rr::{
rdata::opt::{EdnsCode, EdnsOption},
DNSClass, Name, RData, RecordSet, RecordType,
},
tcp::TcpClientStream,
udp::UdpClientStream,
}; };
use trust_dns_client::{error::ClientErrorKind, op::Edns, rr::rdata::opt::EdnsCode};
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::xfer::FirstAnswer;
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use trust_dns_proto::xfer::{DnsExchangeBackground, DnsMultiplexer}; use trust_dns_proto::xfer::{DnsExchangeBackground, DnsMultiplexer};
use trust_dns_proto::DnsHandle;
#[cfg(all(feature = "dnssec", feature = "sqlite"))] #[cfg(all(feature = "dnssec", feature = "sqlite"))]
use trust_dns_proto::TokioTime; use trust_dns_proto::TokioTime;
use trust_dns_proto::{iocompat::AsyncIoTokioAsStd, xfer::FirstAnswer, DnsHandle};
use trust_dns_server::authority::{Authority, Catalog}; use trust_dns_server::authority::{Authority, Catalog};
use trust_dns_integration::authority::create_example; use trust_dns_integration::{
use trust_dns_integration::{NeverReturnsClientStream, TestClientStream}; authority::create_example, NeverReturnsClientStream, TestClientStream,
};
#[test] #[test]
fn test_query_nonet() { fn test_query_nonet() {
@@ -45,11 +47,11 @@ fn test_query_nonet() {
let mut catalog = Catalog::new(); let mut catalog = Catalog::new();
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let io_loop = Runtime::new().unwrap(); let io_loop = Runtime::new().unwrap();
let (stream, sender) = TestClientStream::new(Arc::new(Mutex::new(catalog))); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog)));
let client = AsyncClient::new(stream, sender, None); let client = AsyncClient::new(stream, sender, None);
let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect"); let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect");
trust_dns_proto::spawn_bg(&io_loop, bg); trust_dns_proto::spawn_bg(&io_loop, bg);
@@ -251,11 +253,11 @@ fn test_notify() {
let mut catalog = Catalog::new(); let mut catalog = Catalog::new();
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let io_loop = Runtime::new().unwrap(); let io_loop = Runtime::new().unwrap();
let (stream, sender) = TestClientStream::new(Arc::new(Mutex::new(catalog))); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog)));
let client = AsyncClient::new(stream, sender, None); let client = AsyncClient::new(stream, sender, None);
let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect"); let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect");
trust_dns_proto::spawn_bg(&io_loop, bg); trust_dns_proto::spawn_bg(&io_loop, bg);
@@ -316,11 +318,11 @@ async fn create_sig0_ready_client() -> (
let mut catalog = Catalog::new(); let mut catalog = Catalog::new();
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let signer = Arc::new(signer.into()); let signer = Arc::new(signer.into());
let (stream, sender) = TestClientStream::new(Arc::new(Mutex::new(catalog))); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog)));
let client = AsyncClient::new(stream, sender, Some(signer)) let client = AsyncClient::new(stream, sender, Some(signer))
.await .await
.expect("failed to get new AsyncClient"); .expect("failed to get new AsyncClient");

View File

@@ -2,12 +2,13 @@ use std::net::*;
use std::pin::Pin; use std::pin::Pin;
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use std::str::FromStr; use std::str::FromStr;
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex as StdMutex};
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use chrono::Duration; use chrono::Duration;
use futures::Future; use futures::Future;
use futures::lock::Mutex;
use trust_dns_client::client::Signer; use trust_dns_client::client::Signer;
#[cfg(feature = "dnssec")] #[cfg(feature = "dnssec")]
use trust_dns_client::client::SyncDnssecClient; use trust_dns_client::client::SyncDnssecClient;
@@ -30,13 +31,13 @@ use trust_dns_proto::xfer::{DnsMultiplexer, DnsMultiplexerConnect};
use trust_dns_server::authority::{Authority, Catalog}; use trust_dns_server::authority::{Authority, Catalog};
pub struct TestClientConnection { pub struct TestClientConnection {
catalog: Arc<Mutex<Catalog>>, catalog: Arc<StdMutex<Catalog>>,
} }
impl TestClientConnection { impl TestClientConnection {
pub fn new(catalog: Catalog) -> TestClientConnection { pub fn new(catalog: Catalog) -> TestClientConnection {
TestClientConnection { TestClientConnection {
catalog: Arc::new(Mutex::new(catalog)), catalog: Arc::new(StdMutex::new(catalog)),
} }
} }
} }
@@ -64,7 +65,7 @@ fn test_query_nonet() {
let mut catalog = Catalog::new(); let mut catalog = Catalog::new();
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let client = SyncClient::new(TestClientConnection::new(catalog)); let client = SyncClient::new(TestClientConnection::new(catalog));
@@ -475,7 +476,7 @@ fn create_sig0_ready_client(mut catalog: Catalog) -> (SyncClient<TestClientConne
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let client = SyncClient::with_signer(TestClientConnection::new(catalog), signer); let client = SyncClient::with_signer(TestClientConnection::new(catalog), signer);

View File

@@ -2,8 +2,9 @@
use std::net::*; use std::net::*;
use std::str::FromStr; use std::str::FromStr;
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex as StdMutex};
use futures::lock::Mutex;
use tokio::net::TcpStream as TokioTcpStream; use tokio::net::TcpStream as TokioTcpStream;
use tokio::net::UdpSocket as TokioUdpSocket; use tokio::net::UdpSocket as TokioUdpSocket;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
@@ -228,11 +229,11 @@ where
let mut catalog = Catalog::new(); let mut catalog = Catalog::new();
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let io_loop = Runtime::new().unwrap(); let io_loop = Runtime::new().unwrap();
let (stream, sender) = TestClientStream::new(Arc::new(Mutex::new(catalog))); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog)));
let client = AsyncClient::new(stream, sender, None); let client = AsyncClient::new(stream, sender, None);
let (client, bg) = io_loop let (client, bg) = io_loop

View File

@@ -1,25 +1,32 @@
use std::net::*; use std::{
use std::str::FromStr; net::*,
use std::sync::{Arc, Mutex, RwLock}; str::FromStr,
sync::{Arc, Mutex as StdMutex},
};
use futures::lock::Mutex;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use trust_dns_proto::op::{NoopMessageFinalizer, Query}; use trust_dns_proto::{
use trust_dns_proto::rr::{DNSClass, Name, RData, Record, RecordType}; op::{NoopMessageFinalizer, Query},
use trust_dns_proto::xfer::{DnsExchange, DnsMultiplexer}; rr::{DNSClass, Name, RData, Record, RecordType},
use trust_dns_proto::TokioTime; xfer::{DnsExchange, DnsMultiplexer},
use trust_dns_resolver::caching_client::CachingClient; TokioTime,
use trust_dns_resolver::config::LookupIpStrategy; };
use trust_dns_resolver::error::ResolveError; use trust_dns_resolver::{
use trust_dns_resolver::lookup::{Lookup, LookupFuture}; caching_client::CachingClient,
use trust_dns_resolver::lookup_ip::LookupIpFuture; config::LookupIpStrategy,
use trust_dns_resolver::Hosts; error::ResolveError,
use trust_dns_server::authority::{Authority, Catalog}; lookup::{Lookup, LookupFuture},
use trust_dns_server::store::in_memory::InMemoryAuthority; lookup_ip::LookupIpFuture,
Hosts,
};
use trust_dns_server::{
authority::{Authority, Catalog},
store::in_memory::InMemoryAuthority,
};
use trust_dns_integration::authority::create_example; use trust_dns_integration::{authority::create_example, mock_client::*, TestClientStream};
use trust_dns_integration::mock_client::*;
use trust_dns_integration::TestClientStream;
#[test] #[test]
fn test_lookup() { fn test_lookup() {
@@ -27,11 +34,11 @@ fn test_lookup() {
let mut catalog = Catalog::new(); let mut catalog = Catalog::new();
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let io_loop = Runtime::new().unwrap(); let io_loop = Runtime::new().unwrap();
let (stream, sender) = TestClientStream::new(Arc::new(Mutex::new(catalog))); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog)));
let dns_conn = DnsMultiplexer::new(stream, sender, NoopMessageFinalizer::new()); let dns_conn = DnsMultiplexer::new(stream, sender, NoopMessageFinalizer::new());
let client = DnsExchange::connect::<_, _, TokioTime>(dns_conn); let client = DnsExchange::connect::<_, _, TokioTime>(dns_conn);
@@ -58,11 +65,11 @@ fn test_lookup_hosts() {
let mut catalog = Catalog::new(); let mut catalog = Catalog::new();
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let io_loop = Runtime::new().unwrap(); let io_loop = Runtime::new().unwrap();
let (stream, sender) = TestClientStream::new(Arc::new(Mutex::new(catalog))); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog)));
let dns_conn = DnsMultiplexer::new(stream, sender, NoopMessageFinalizer::new()); let dns_conn = DnsMultiplexer::new(stream, sender, NoopMessageFinalizer::new());
let client = DnsExchange::connect::<_, _, TokioTime>(dns_conn); let client = DnsExchange::connect::<_, _, TokioTime>(dns_conn);
@@ -119,11 +126,11 @@ fn test_lookup_ipv4_like() {
let mut catalog = Catalog::new(); let mut catalog = Catalog::new();
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let io_loop = Runtime::new().unwrap(); let io_loop = Runtime::new().unwrap();
let (stream, sender) = TestClientStream::new(Arc::new(Mutex::new(catalog))); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog)));
let dns_conn = DnsMultiplexer::new(stream, sender, NoopMessageFinalizer::new()); let dns_conn = DnsMultiplexer::new(stream, sender, NoopMessageFinalizer::new());
let client = DnsExchange::connect::<_, _, TokioTime>(dns_conn); let client = DnsExchange::connect::<_, _, TokioTime>(dns_conn);
@@ -152,11 +159,11 @@ fn test_lookup_ipv4_like_fall_through() {
let mut catalog = Catalog::new(); let mut catalog = Catalog::new();
catalog.upsert( catalog.upsert(
authority.origin().clone(), authority.origin().clone(),
Box::new(Arc::new(RwLock::new(authority))), Box::new(Arc::new(Mutex::new(authority))),
); );
let io_loop = Runtime::new().unwrap(); let io_loop = Runtime::new().unwrap();
let (stream, sender) = TestClientStream::new(Arc::new(Mutex::new(catalog))); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog)));
let dns_conn = DnsMultiplexer::new(stream, sender, NoopMessageFinalizer::new()); let dns_conn = DnsMultiplexer::new(stream, sender, NoopMessageFinalizer::new());
let client = DnsExchange::connect::<_, _, TokioTime>(dns_conn); let client = DnsExchange::connect::<_, _, TokioTime>(dns_conn);

View File

@@ -1,10 +1,11 @@
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::str::FromStr; use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use futures::lock::Mutex;
use futures::{future, Future, FutureExt}; use futures::{future, Future, FutureExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
@@ -264,7 +265,7 @@ fn new_catalog() -> Catalog {
let origin = example.origin().clone(); let origin = example.origin().clone();
let mut catalog: Catalog = Catalog::new(); let mut catalog: Catalog = Catalog::new();
catalog.upsert(origin, Box::new(Arc::new(RwLock::new(example)))); catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example))));
catalog catalog
} }

View File

@@ -186,7 +186,10 @@ fn test_authorize() {
let bytes = message.to_bytes().unwrap(); let bytes = message.to_bytes().unwrap();
let message = MessageRequest::from_bytes(&bytes).unwrap(); let message = MessageRequest::from_bytes(&bytes).unwrap();
assert_eq!(authority.authorize(&message), Err(ResponseCode::Refused)); assert_eq!(
block_on(authority.authorize(&message)),
Err(ResponseCode::Refused)
);
// TODO: this will nee to be more complex as additional policies are added // TODO: this will nee to be more complex as additional policies are added
// authority.set_allow_update(true); // authority.set_allow_update(true);
@@ -203,151 +206,174 @@ fn test_prerequisites() {
// first check the initial negatives, ttl = 0, and the zone is the same // first check the initial negatives, ttl = 0, and the zone is the same
assert_eq!( assert_eq!(
authority.verify_prerequisites(&[Record::new() block_on(
.set_name(not_in_zone.clone()) authority.verify_prerequisites(&[Record::new()
.set_ttl(86400) .set_name(not_in_zone.clone())
.set_rr_type(RecordType::A) .set_ttl(86400)
.set_dns_class(DNSClass::IN) .set_rr_type(RecordType::A)
.set_rdata(RData::NULL(NULL::new())) .set_dns_class(DNSClass::IN)
.clone()],), .set_rdata(RData::NULL(NULL::new()))
.clone()],)
),
Err(ResponseCode::FormErr) Err(ResponseCode::FormErr)
); );
assert_eq!( assert_eq!(
authority.verify_prerequisites(&[Record::new() block_on(
.set_name(not_zone) authority.verify_prerequisites(&[Record::new()
.set_ttl(0) .set_name(not_zone)
.set_rr_type(RecordType::A) .set_ttl(0)
.set_dns_class(DNSClass::IN) .set_rr_type(RecordType::A)
.set_rdata(RData::NULL(NULL::new())) .set_dns_class(DNSClass::IN)
.clone()],), .set_rdata(RData::NULL(NULL::new()))
.clone()],)
),
Err(ResponseCode::NotZone) Err(ResponseCode::NotZone)
); );
// * ANY ANY empty Name is in use // * ANY ANY empty Name is in use
assert!(authority assert!(block_on(
.verify_prerequisites(&[Record::new() authority.verify_prerequisites(&[Record::new()
.set_name(authority.origin().clone().into()) .set_name(authority.origin().clone().into())
.set_ttl(0) .set_ttl(0)
.set_dns_class(DNSClass::ANY) .set_dns_class(DNSClass::ANY)
.set_rr_type(RecordType::ANY) .set_rr_type(RecordType::ANY)
.set_rdata(RData::NULL(NULL::new())) .set_rdata(RData::NULL(NULL::new()))
.clone()]) .clone()])
.is_ok()); )
.is_ok());
assert_eq!( assert_eq!(
authority.verify_prerequisites(&[Record::new() block_on(
.set_name(not_in_zone.clone()) authority.verify_prerequisites(&[Record::new()
.set_ttl(0) .set_name(not_in_zone.clone())
.set_dns_class(DNSClass::ANY) .set_ttl(0)
.set_rr_type(RecordType::ANY) .set_dns_class(DNSClass::ANY)
.set_rdata(RData::NULL(NULL::new())) .set_rr_type(RecordType::ANY)
.clone()],), .set_rdata(RData::NULL(NULL::new()))
.clone()],)
),
Err(ResponseCode::NXDomain) Err(ResponseCode::NXDomain)
); );
// * ANY rrset empty RRset exists (value independent) // * ANY rrset empty RRset exists (value independent)
assert!(authority assert!(block_on(
.verify_prerequisites(&[Record::new() authority.verify_prerequisites(&[Record::new()
.set_name(authority.origin().clone().into()) .set_name(authority.origin().clone().into())
.set_ttl(0) .set_ttl(0)
.set_dns_class(DNSClass::ANY) .set_dns_class(DNSClass::ANY)
.set_rr_type(RecordType::A) .set_rr_type(RecordType::A)
.set_rdata(RData::NULL(NULL::new())) .set_rdata(RData::NULL(NULL::new()))
.clone()]) .clone()])
.is_ok()); )
.is_ok());
assert_eq!( assert_eq!(
authority.verify_prerequisites(&[Record::new() block_on(
.set_name(not_in_zone.clone()) authority.verify_prerequisites(&[Record::new()
.set_ttl(0) .set_name(not_in_zone.clone())
.set_dns_class(DNSClass::ANY) .set_ttl(0)
.set_rr_type(RecordType::A) .set_dns_class(DNSClass::ANY)
.set_rdata(RData::NULL(NULL::new())) .set_rr_type(RecordType::A)
.clone()],), .set_rdata(RData::NULL(NULL::new()))
.clone()],)
),
Err(ResponseCode::NXRRSet) Err(ResponseCode::NXRRSet)
); );
// * NONE ANY empty Name is not in use // * NONE ANY empty Name is not in use
assert!(authority assert!(block_on(
.verify_prerequisites(&[Record::new() authority.verify_prerequisites(&[Record::new()
.set_name(not_in_zone.clone()) .set_name(not_in_zone.clone())
.set_ttl(0) .set_ttl(0)
.set_dns_class(DNSClass::NONE) .set_dns_class(DNSClass::NONE)
.set_rr_type(RecordType::ANY) .set_rr_type(RecordType::ANY)
.set_rdata(RData::NULL(NULL::new())) .set_rdata(RData::NULL(NULL::new()))
.clone()]) .clone()])
.is_ok()); )
.is_ok());
assert_eq!( assert_eq!(
authority.verify_prerequisites(&[Record::new() block_on(
.set_name(authority.origin().clone().into()) authority.verify_prerequisites(&[Record::new()
.set_ttl(0) .set_name(authority.origin().clone().into())
.set_dns_class(DNSClass::NONE) .set_ttl(0)
.set_rr_type(RecordType::ANY) .set_dns_class(DNSClass::NONE)
.set_rdata(RData::NULL(NULL::new())) .set_rr_type(RecordType::ANY)
.clone()],), .set_rdata(RData::NULL(NULL::new()))
.clone()],)
),
Err(ResponseCode::YXDomain) Err(ResponseCode::YXDomain)
); );
// * NONE rrset empty RRset does not exist // * NONE rrset empty RRset does not exist
assert!(authority assert!(block_on(
.verify_prerequisites(&[Record::new() authority.verify_prerequisites(&[Record::new()
.set_name(not_in_zone.clone()) .set_name(not_in_zone.clone())
.set_ttl(0) .set_ttl(0)
.set_dns_class(DNSClass::NONE) .set_dns_class(DNSClass::NONE)
.set_rr_type(RecordType::A) .set_rr_type(RecordType::A)
.set_rdata(RData::NULL(NULL::new())) .set_rdata(RData::NULL(NULL::new()))
.clone()]) .clone()])
.is_ok()); )
.is_ok());
assert_eq!( assert_eq!(
authority.verify_prerequisites(&[Record::new() block_on(
.set_name(authority.origin().clone().into()) authority.verify_prerequisites(&[Record::new()
.set_ttl(0) .set_name(authority.origin().clone().into())
.set_dns_class(DNSClass::NONE) .set_ttl(0)
.set_rr_type(RecordType::A) .set_dns_class(DNSClass::NONE)
.set_rdata(RData::NULL(NULL::new())) .set_rr_type(RecordType::A)
.clone()],), .set_rdata(RData::NULL(NULL::new()))
.clone()],)
),
Err(ResponseCode::YXRRSet) Err(ResponseCode::YXRRSet)
); );
// * zone rrset rr RRset exists (value dependent) // * zone rrset rr RRset exists (value dependent)
assert!(authority assert!(block_on(
.verify_prerequisites(&[Record::new() authority.verify_prerequisites(&[Record::new()
.set_name(authority.origin().clone().into()) .set_name(authority.origin().clone().into())
.set_ttl(0) .set_ttl(0)
.set_dns_class(DNSClass::IN) .set_dns_class(DNSClass::IN)
.set_rr_type(RecordType::A) .set_rr_type(RecordType::A)
.set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34))) .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34)))
.clone()]) .clone()])
.is_ok()); )
.is_ok());
// wrong class // wrong class
assert_eq!( assert_eq!(
authority.verify_prerequisites(&[Record::new() block_on(
.set_name(authority.origin().clone().into()) authority.verify_prerequisites(&[Record::new()
.set_ttl(0) .set_name(authority.origin().clone().into())
.set_dns_class(DNSClass::CH) .set_ttl(0)
.set_rr_type(RecordType::A) .set_dns_class(DNSClass::CH)
.set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34))) .set_rr_type(RecordType::A)
.clone()],), .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34)))
.clone()],)
),
Err(ResponseCode::FormErr) Err(ResponseCode::FormErr)
); );
// wrong Name // wrong Name
assert_eq!( assert_eq!(
authority.verify_prerequisites(&[Record::new() block_on(
.set_name(not_in_zone) authority.verify_prerequisites(&[Record::new()
.set_ttl(0) .set_name(not_in_zone)
.set_dns_class(DNSClass::IN) .set_ttl(0)
.set_rr_type(RecordType::A) .set_dns_class(DNSClass::IN)
.set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) .set_rr_type(RecordType::A)
.clone()],), .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24)))
.clone()],)
),
Err(ResponseCode::NXRRSet) Err(ResponseCode::NXRRSet)
); );
// wrong IP // wrong IP
assert_eq!( assert_eq!(
authority.verify_prerequisites(&[Record::new() block_on(
.set_name(authority.origin().clone().into()) authority.verify_prerequisites(&[Record::new()
.set_ttl(0) .set_name(authority.origin().clone().into())
.set_dns_class(DNSClass::IN) .set_ttl(0)
.set_rr_type(RecordType::A) .set_dns_class(DNSClass::IN)
.set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) .set_rr_type(RecordType::A)
.clone()],), .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24)))
.clone()],)
),
Err(ResponseCode::NXRRSet) Err(ResponseCode::NXRRSet)
); );
} }