diff --git a/CHANGELOG.md b/CHANGELOG.md index a72178d7..3af49acd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ All notes should be prepended with the location of the change, e.g. `(proto)` or ### Changed +- (server) `ResponseHandler` trait is now `async_trait` and requires all impls to be annotated with `#[async_trait]` #1550 +- (server) `Authority` impls are now required to be internally modifiable and `Send + Sync` #1550 +- (server) Most `Authority` methods changed to `async fn` rather than returning custom `Future` impls #1550 +- (server) `Authority` trait is now `async_trait` and requires all impls to be annotated with `#[async_trait]` #1550 - (proto) Header now stores ResponseCode instead of just u8 #1537 - (client) improved async client example documentation (@ErwanDL) #1539 - (resolver) on `REFUSED` (and other negative) response(s), fall back to other nameservers (@peterthejohnston) #1513 #1526 diff --git a/Cargo.lock b/Cargo.lock index 5bc0dbca..72d791a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,7 +1,5 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 - [[package]] name = "addr2line" version = "0.16.0" diff --git a/bin/src/named.rs b/bin/src/named.rs index 9a8fe4c1..7dac5b19 100644 --- a/bin/src/named.rs +++ b/bin/src/named.rs @@ -47,10 +47,9 @@ use std::{ }; use clap::{Arg, ArgMatches}; -use futures::lock::Mutex; use tokio::{ net::{TcpListener, UdpSocket}, - runtime::{self, Runtime}, + runtime, }; use trust_dns_client::rr::Name; @@ -74,14 +73,14 @@ use trust_dns_server::{logger, server::ServerFuture}; use {trust_dns_client::rr::rdata::key::KeyUsage, trust_dns_server::authority::DnssecAuthority}; #[cfg(feature = "dnssec")] -fn load_keys( +async fn load_keys( authority: &mut A, zone_name: Name, zone_config: &ZoneConfig, ) -> Result<(), String> where A: DnssecAuthority, - L: Send + Sized + 'static, + L: Send + Sync + Sized + 'static, { if zone_config.is_dnssec_enabled() { for key_config in zone_config.get_keys() { @@ -97,6 +96,7 @@ where })?; authority .add_zone_signing_key(zone_signer) + .await .expect("failed to add zone signing key to authority"); } if key_config.is_zone_update_auth() { @@ -110,19 +110,20 @@ where .expect("failed to get sig0 key"); authority .add_update_auth_key(zone_name.clone(), public_key) + .await .expect("failed to add update auth key to authority"); } } info!("signing zone: {}", zone_config.get_zone().unwrap()); - authority.secure_zone().expect("failed to sign zone"); + authority.secure_zone().await.expect("failed to sign zone"); } Ok(()) } #[cfg(not(feature = "dnssec"))] #[allow(clippy::unnecessary_wraps)] -fn load_keys( +async fn load_keys( _authority: &mut T, _zone_name: Name, _zone_config: &ZoneConfig, ) -> Result<(), String> { #[cfg_attr(not(feature = "dnssec"), allow(unused_mut, unused))] #[warn(clippy::wildcard_enum_match_arm)] // make sure all cases are handled despite of non_exhaustive -fn load_zone( +async fn load_zone( zone_dir: &Path, zone_config: &ZoneConfig, - runtime: &mut Runtime, ) -> Result, String> { debug!("loading zone with config: {:#?}", zone_config); @@ -166,11 +166,12 @@ fn load_zone( is_dnssec_enabled, Some(zone_dir), config, - )?; + ) + .await?; // load any keys for the Zone, if it is a dynamic update zone, then keys are required - load_keys(&mut authority, zone_name_for_signer, zone_config)?; - Box::new(Arc::new(Mutex::new(authority))) + load_keys(&mut authority, zone_name_for_signer, zone_config).await?; + Box::new(Arc::new(authority)) as Box }
Some(StoreConfig::File(ref config)) => { if zone_path.is_some() { @@ -186,15 +187,15 @@ fn load_zone( )?; // load any keys for the Zone, if it is a dynamic update zone, then keys are required - load_keys(&mut authority, zone_name_for_signer, zone_config)?; - Box::new(Arc::new(Mutex::new(authority))) + load_keys(&mut authority, zone_name_for_signer, zone_config).await?; + Box::new(Arc::new(authority)) as Box } #[cfg(feature = "resolver")] Some(StoreConfig::Forward(ref config)) => { let forwarder = ForwardAuthority::try_from_config(zone_name, zone_type, config); - let authority = runtime.block_on(forwarder)?; + let authority = forwarder.await?; - Box::new(Arc::new(Mutex::new(authority))) + Box::new(Arc::new(authority)) as Box } #[cfg(feature = "sqlite")] None if zone_config.is_update_allowed() => { @@ -221,11 +222,12 @@ fn load_zone( is_dnssec_enabled, Some(zone_dir), &config, - )?; + ) + .await?; // load any keys for the Zone, if it is a dynamic update zone, then keys are required - load_keys(&mut authority, zone_name_for_signer, zone_config)?; - Box::new(Arc::new(Mutex::new(authority))) + load_keys(&mut authority, zone_name_for_signer, zone_config).await?; + Box::new(Arc::new(authority)) as Box } None => { let config = FileConfig { @@ -241,8 +243,8 @@ fn load_zone( )?; // load any keys for the Zone, if it is a dynamic update zone, then keys are required - load_keys(&mut authority, zone_name_for_signer, zone_config)?; - Box::new(Arc::new(Mutex::new(authority))) + load_keys(&mut authority, zone_name_for_signer, zone_config).await?; + Box::new(Arc::new(authority)) as Box } Some(_) => { panic!("unrecognized authority type, check enabled features"); @@ -303,6 +305,7 @@ impl<'a> From> for Args { /// Main method for running the named server. /// /// `Note`: Tries to avoid panics, in favor of always starting. 
+#[allow(unused_mut)] fn main() { let args = app_from_crate!() .arg( @@ -397,7 +400,7 @@ fn main() { .get_zone() .unwrap_or_else(|_| panic!("bad zone name in {:?}", config_path)); - match load_zone(&zone_dir, zone, &mut runtime) { + match runtime.block_on(load_zone(&zone_dir, zone)) { Ok(authority) => catalog.upsert(zone_name.into(), authority), Err(error) => panic!("could not load zone {}: {}", zone_name, error), } @@ -527,7 +530,7 @@ fn config_tls( tls_cert_config: &TlsCertConfig, zone_dir: &Path, listen_addrs: &[IpAddr], - runtime: &mut Runtime, + runtime: &mut runtime::Runtime, ) { use futures::TryFutureExt; @@ -580,7 +583,7 @@ fn config_https( tls_cert_config: &TlsCertConfig, zone_dir: &Path, listen_addrs: &[IpAddr], - runtime: &mut Runtime, + runtime: &mut runtime::Runtime, ) { use futures::TryFutureExt; diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index 032f2a6b..56a9b7f0 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -66,6 +66,8 @@ tls = ["dns-over-openssl"] # WARNING: there is a bug in the mutual tls auth code at the moment see issue #100 # mtls = ["trust-dns-client/mtls"] +testing = [] + [lib] name = "trust_dns_server" path = "src/lib.rs" @@ -95,6 +97,9 @@ trust-dns-client= { version = "0.21.0-alpha.2", path = "../client" } trust-dns-proto = { version = "0.21.0-alpha.2", path = "../proto" } trust-dns-resolver = { version = "0.21.0-alpha.2", path = "../resolver", features = ["serde-config"], optional = true } +[dev-dependencies] +tokio = { version="1.0", features = ["macros", "rt"] } + [package.metadata.docs.rs] all-features = true default-target = "x86_64-unknown-linux-gnu" diff --git a/crates/server/src/authority/authority.rs b/crates/server/src/authority/authority.rs index a5fd1a54..5e4f7354 100644 --- a/crates/server/src/authority/authority.rs +++ b/crates/server/src/authority/authority.rs @@ -109,7 +109,7 @@ pub trait Authority: Send + Sync { fn is_axfr_allowed(&self) -> bool; /// Perform a dynamic update of a zone - async fn update(&mut self, update: &MessageRequest) -> UpdateResult; + async fn update(&self, update: &MessageRequest) -> UpdateResult; /// Get the origin of this zone, i.e. 
example.com is the origin for www.example.com fn origin(&self) -> &LowerName; @@ -191,13 +191,14 @@ pub trait Authority: Send + Sync { /// Extension to Authority to allow for DNSSEC features #[cfg(feature = "dnssec")] #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] +#[async_trait::async_trait] pub trait DnssecAuthority: Authority { /// Add a (Sig0) key that is authorized to perform updates against this authority - fn add_update_auth_key(&mut self, name: Name, key: KEY) -> DnsSecResult<()>; + async fn add_update_auth_key(&self, name: Name, key: KEY) -> DnsSecResult<()>; /// Add Signer - fn add_zone_signing_key(&mut self, signer: SigSigner) -> DnsSecResult<()>; + async fn add_zone_signing_key(&self, signer: SigSigner) -> DnsSecResult<()>; /// Sign the zone for DNSSEC - fn secure_zone(&mut self) -> DnsSecResult<()>; + async fn secure_zone(&self) -> DnsSecResult<()>; } diff --git a/crates/server/src/authority/authority_object.rs b/crates/server/src/authority/authority_object.rs index 21fdd6a9..cbb6243e 100644 --- a/crates/server/src/authority/authority_object.rs +++ b/crates/server/src/authority/authority_object.rs @@ -9,7 +9,6 @@ use std::sync::Arc; -use futures_util::lock::Mutex; use log::debug; use crate::{ @@ -27,16 +26,16 @@ pub trait AuthorityObject: Send + Sync { fn box_clone(&self) -> Box; /// What type is this zone - async fn zone_type(&self) -> ZoneType; + fn zone_type(&self) -> ZoneType; /// Return true if AXFR is allowed - async fn is_axfr_allowed(&self) -> bool; + fn is_axfr_allowed(&self) -> bool; /// Perform a dynamic update of a zone async fn update(&self, update: &MessageRequest) -> UpdateResult; /// Get the origin of this zone, i.e. example.com is the origin for www.example.com - async fn origin(&self) -> LowerName; + fn origin(&self) -> &LowerName; /// Looks up all Resource Records matching the giving `Name` and `RecordType`. /// @@ -81,7 +80,7 @@ pub trait AuthorityObject: Send + Sync { &self, lookup_options: LookupOptions, ) -> Result, LookupError> { - self.lookup(&self.origin().await, RecordType::NS, lookup_options) + self.lookup(self.origin(), RecordType::NS, lookup_options) .await } @@ -104,12 +103,8 @@ pub trait AuthorityObject: Send + Sync { /// should be used, see `soa_secure()`, which will optionally return RRSIGs. 
async fn soa(&self) -> Result, LookupError> { // SOA should be origin|SOA - self.lookup( - &self.origin().await, - RecordType::SOA, - LookupOptions::default(), - ) - .await + self.lookup(self.origin(), RecordType::SOA, LookupOptions::default()) + .await } /// Returns the SOA record for the zone @@ -117,13 +112,13 @@ pub trait AuthorityObject: Send + Sync { &self, lookup_options: LookupOptions, ) -> Result, LookupError> { - self.lookup(&self.origin().await, RecordType::SOA, lookup_options) + self.lookup(self.origin(), RecordType::SOA, lookup_options) .await } } #[async_trait::async_trait] -impl AuthorityObject for Arc> +impl AuthorityObject for Arc where A: Authority + Send + Sync + 'static, L: LookupObject + Send + Sync + 'static, @@ -133,23 +128,23 @@ where } /// What type is this zone - async fn zone_type(&self) -> ZoneType { - Authority::zone_type(&*self.lock().await) + fn zone_type(&self) -> ZoneType { + Authority::zone_type(self.as_ref()) } /// Return true if AXFR is allowed - async fn is_axfr_allowed(&self) -> bool { - Authority::is_axfr_allowed(&*self.lock().await) + fn is_axfr_allowed(&self) -> bool { + Authority::is_axfr_allowed(self.as_ref()) } /// Perform a dynamic update of a zone async fn update(&self, update: &MessageRequest) -> UpdateResult { - Authority::update(&mut *self.lock().await, update).await + Authority::update(self.as_ref(), update).await } /// Get the origin of this zone, i.e. example.com is the origin for www.example.com - async fn origin(&self) -> LowerName { - Authority::origin(&*self.lock().await).clone() + fn origin(&self) -> &LowerName { + Authority::origin(self.as_ref()) } /// Looks up all Resource Records matching the giving `Name` and `RecordType`. @@ -172,7 +167,7 @@ where rtype: RecordType, lookup_options: LookupOptions, ) -> Result, LookupError> { - let this = self.lock().await; + let this = self.as_ref(); let lookup = Authority::lookup(&*this, name, rtype, lookup_options).await; lookup.map(|l| Box::new(l) as Box) } @@ -193,7 +188,7 @@ where query: &LowerQuery, lookup_options: LookupOptions, ) -> Result, LookupError> { - let this = self.lock().await; + let this = self.as_ref(); debug!("performing {} on {}", query, this.origin()); let lookup = Authority::search(&*this, query, lookup_options).await; lookup.map(|l| Box::new(l) as Box) @@ -211,7 +206,7 @@ where name: &LowerName, lookup_options: LookupOptions, ) -> Result, LookupError> { - let lookup = Authority::get_nsec_records(&*self.lock().await, name, lookup_options).await; + let lookup = Authority::get_nsec_records(self.as_ref(), name, lookup_options).await; lookup.map(|l| Box::new(l) as Box) } } diff --git a/crates/server/src/authority/catalog.rs b/crates/server/src/authority/catalog.rs index 84fd9d76..2339211d 100644 --- a/crates/server/src/authority/catalog.rs +++ b/crates/server/src/authority/catalog.rs @@ -271,7 +271,7 @@ impl Catalog { let response_code = match result { Ok(authority) => { #[allow(deprecated)] - match authority.zone_type().await { + match authority.zone_type() { ZoneType::Secondary | ZoneType::Slave => { error!("secondary forwarding for update not yet implemented"); ResponseCode::NotImp @@ -397,7 +397,7 @@ async fn lookup( info!( "request: {} found authority: {}", request.id(), - authority.origin().await + authority.origin() ); let (response_header, sections) = build_response( @@ -466,13 +466,13 @@ async fn build_response( } let mut response_header = Header::response_from_request(request_header); - response_header.set_authoritative(authority.zone_type().await.is_authoritative()); 
+ response_header.set_authoritative(authority.zone_type().is_authoritative()); - debug!("performing {} on {}", query, authority.origin().await); + debug!("performing {} on {}", query, authority.origin()); let future = authority.search(query, lookup_options); #[allow(deprecated)] - let sections = match authority.zone_type().await { + let sections = match authority.zone_type() { ZoneType::Primary | ZoneType::Secondary | ZoneType::Master | ZoneType::Slave => { send_authoritative_response( future, diff --git a/crates/server/src/server/server_future.rs b/crates/server/src/server/server_future.rs index 2af8f139..02fca13e 100644 --- a/crates/server/src/server/server_future.rs +++ b/crates/server/src/server/server_future.rs @@ -207,6 +207,7 @@ impl ServerFuture { ) -> io::Result<()> { use crate::proto::openssl::{tls_server, TlsStream}; use openssl::ssl::Ssl; + use std::pin::Pin; use tokio_openssl::SslStream as TokioSslStream; let ((cert, chain), key) = certificate_and_key; diff --git a/crates/server/src/store/file/authority.rs b/crates/server/src/store/file/authority.rs index 48536d80..18ae316b 100644 --- a/crates/server/src/store/file/authority.rs +++ b/crates/server/src/store/file/authority.rs @@ -249,7 +249,7 @@ impl Authority for FileAuthority { } /// Perform a dynamic update of a zone - async fn update(&mut self, _update: &MessageRequest) -> UpdateResult { + async fn update(&self, _update: &MessageRequest) -> UpdateResult { use crate::proto::op::ResponseCode; Err(ResponseCode::NotImp) } @@ -337,20 +337,21 @@ impl Authority for FileAuthority { #[cfg(feature = "dnssec")] #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] +#[async_trait::async_trait] impl DnssecAuthority for FileAuthority { /// Add a (Sig0) key that is authorized to perform updates against this authority - fn add_update_auth_key(&mut self, name: Name, key: KEY) -> DnsSecResult<()> { - self.0.add_update_auth_key(name, key) + async fn add_update_auth_key(&self, name: Name, key: KEY) -> DnsSecResult<()> { + self.0.add_update_auth_key(name, key).await } /// Add Signer - fn add_zone_signing_key(&mut self, signer: SigSigner) -> DnsSecResult<()> { - self.0.add_zone_signing_key(signer) + async fn add_zone_signing_key(&self, signer: SigSigner) -> DnsSecResult<()> { + self.0.add_zone_signing_key(signer).await } /// Sign the zone for DNSSEC - fn secure_zone(&mut self) -> DnsSecResult<()> { - DnssecAuthority::secure_zone(&mut self.0) + async fn secure_zone(&self) -> DnsSecResult<()> { + DnssecAuthority::secure_zone(&self.0).await } } diff --git a/crates/server/src/store/forwarder/authority.rs b/crates/server/src/store/forwarder/authority.rs index fe48f40c..2bb26b07 100644 --- a/crates/server/src/store/forwarder/authority.rs +++ b/crates/server/src/store/forwarder/authority.rs @@ -84,7 +84,7 @@ impl Authority for ForwardAuthority { false } - async fn update(&mut self, _update: &MessageRequest) -> UpdateResult { + async fn update(&self, _update: &MessageRequest) -> UpdateResult { Err(ResponseCode::NotImp) } diff --git a/crates/server/src/store/in_memory/authority.rs b/crates/server/src/store/in_memory/authority.rs index e132752f..d50a6e14 100644 --- a/crates/server/src/store/in_memory/authority.rs +++ b/crates/server/src/store/in_memory/authority.rs @@ -7,11 +7,19 @@ //! 
All authority related types -use std::{borrow::Borrow, collections::BTreeMap, sync::Arc}; +use std::{ + borrow::Borrow, + collections::BTreeMap, + ops::{Deref, DerefMut}, + sync::Arc, +}; use cfg_if::cfg_if; -use futures_util::future::{self, TryFutureExt}; -use log::{debug, error}; +use futures_util::{ + future::{self, TryFutureExt}, + lock::{Mutex, MutexGuard}, +}; +use log::{debug, error, warn}; #[cfg(feature = "dnssec")] use crate::{ @@ -42,16 +50,9 @@ use crate::{ pub struct InMemoryAuthority { origin: LowerName, class: DNSClass, - records: BTreeMap>, zone_type: ZoneType, allow_axfr: bool, - // Private key mapped to the Record of the DNSKey - // TODO: these private_keys should be stored securely. Ideally, we have keys only stored per - // server instance, but that requires requesting updates from the parent zone, which may or - // may not support dynamic updates to register the new key... Trust-DNS will provide support - // for this, in some form, perhaps alternate root zones... - #[cfg(feature = "dnssec")] - secure_keys: Vec, + inner: Mutex, } impl InMemoryAuthority { @@ -77,6 +78,7 @@ impl InMemoryAuthority { allow_axfr: bool, ) -> Result { let mut this = Self::empty(origin.clone(), zone_type, allow_axfr); + let inner = this.inner.get_mut(); // SOA must be present let serial = records @@ -95,7 +97,7 @@ impl InMemoryAuthority { let rr_type = rrset.record_type(); for record in rrset.records_without_rrsigs() { - if !this.upsert(record.clone(), serial) { + if !inner.upsert(record.clone(), serial, this.class) { return Err(format!( "Failed to insert {} {} to zone: {}", name, rr_type, origin @@ -116,47 +118,231 @@ impl InMemoryAuthority { Self { origin: LowerName::new(&origin), class: DNSClass::IN, - records: BTreeMap::new(), zone_type, allow_axfr, - #[cfg(feature = "dnssec")] - secure_keys: Vec::new(), + inner: Mutex::new(InnerInMemory::default()), } } - /// Clears all records (including SOA, etc) - pub fn clear(&mut self) { - self.records.clear() - } - - /// Get the DNSClass of the zone + /// The DNSClass of this zone pub fn class(&self) -> DNSClass { self.class } - /// Enables AXFRs of all the zones records + /// Allow AXFR's (zone transfers) + #[cfg(any(test, feature = "testing"))] + #[cfg_attr(docsrs, doc(cfg(feature = "testing")))] pub fn set_allow_axfr(&mut self, allow_axfr: bool) { self.allow_axfr = allow_axfr; } + /// Clears all records (including SOA, etc) + pub fn clear(&mut self) { + self.inner.get_mut().records.clear() + } + /// Retrieve the Signer, which contains the private keys, for this zone - #[cfg(feature = "dnssec")] - pub fn secure_keys(&self) -> &[SigSigner] { - &self.secure_keys + #[cfg(all(feature = "dnssec", feature = "testing"))] + pub async fn secure_keys(&self) -> impl Deref + '_ { + MutexGuard::map(self.inner.lock().await, |i| i.secure_keys.as_mut_slice()) } /// Get all the records - pub fn records(&self) -> &BTreeMap> { - &self.records + pub async fn records(&self) -> impl Deref>> + '_ { + MutexGuard::map(self.inner.lock().await, |i| &mut i.records) } /// Get a mutable reference to the records - pub fn records_mut(&mut self) -> &mut BTreeMap> { - &mut self.records + pub async fn records_mut( + &self, + ) -> impl DerefMut>> + '_ { + MutexGuard::map(self.inner.lock().await, |i| &mut i.records) } - fn inner_soa(&self) -> Option<&SOA> { - let rr_key = RrKey::new(self.origin.clone(), RecordType::SOA); + /// Get a mutable reference to the records + pub fn records_get_mut(&mut self) -> &mut BTreeMap> { + &mut self.inner.get_mut().records + } + + /// Returns the 
minimum ttl (as used in the SOA record) + pub async fn minimum_ttl(&self) -> u32 { + self.inner.lock().await.minimum_ttl(self.origin()) + } + + /// get the current serial number for the zone. + pub async fn serial(&self) -> u32 { + self.inner.lock().await.serial(self.origin()) + } + + #[cfg(any(feature = "dnssec", feature = "sqlite"))] + #[allow(unused)] + pub(crate) async fn increment_soa_serial(&self) -> u32 { + self.inner + .lock() + .await + .increment_soa_serial(self.origin(), self.class) + } + + /// Inserts or updates a `Record` depending on it's existence in the authority. + /// + /// Guarantees that SOA, CNAME only has one record, will implicitly update if they already exist. + /// + /// # Arguments + /// + /// * `record` - The `Record` to be inserted or updated. + /// * `serial` - Current serial number to be recorded against updates. + /// + /// # Return value + /// + /// true if the value was inserted, false otherwise + pub async fn upsert(&self, record: Record, serial: u32) -> bool { + self.inner.lock().await.upsert(record, serial, self.class) + } + + /// Non-async version of upsert when behind a mutable reference. + pub fn upsert_mut(&mut self, record: Record, serial: u32) -> bool { + self.inner.get_mut().upsert(record, serial, self.class) + } + + /// Add a (Sig0) key that is authorized to perform updates against this authority + #[cfg(feature = "dnssec")] + fn inner_add_update_auth_key( + inner: &mut InnerInMemory, + + name: Name, + key: KEY, + origin: &LowerName, + dns_class: DNSClass, + ) -> DnsSecResult<()> { + let rdata = RData::DNSSEC(DNSSECRData::KEY(key)); + // TODO: what TTL? + let record = Record::from_rdata(name, 86400, rdata); + + let serial = inner.serial(origin); + if inner.upsert(record, serial, dns_class) { + Ok(()) + } else { + Err("failed to add auth key".into()) + } + } + + /// Non-async method of add_update_auth_key when behind a mutable reference + #[cfg(feature = "dnssec")] + #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] + pub fn add_update_auth_key_mut(&mut self, name: Name, key: KEY) -> DnsSecResult<()> { + let Self { + ref origin, + ref mut inner, + class, + .. + } = self; + + Self::inner_add_update_auth_key(inner.get_mut(), name, key, origin, *class) + } + + /// By adding a secure key, this will implicitly enable dnssec for the zone. + /// + /// # Arguments + /// + /// * `signer` - Signer with associated private key + #[cfg(feature = "dnssec")] + fn inner_add_zone_signing_key( + inner: &mut InnerInMemory, + signer: SigSigner, + origin: &LowerName, + dns_class: DNSClass, + ) -> DnsSecResult<()> { + // also add the key to the zone + let zone_ttl = inner.minimum_ttl(origin); + let dnskey = signer.key().to_dnskey(signer.algorithm())?; + let dnskey = Record::from_rdata( + origin.clone().into(), + zone_ttl, + RData::DNSSEC(DNSSECRData::DNSKEY(dnskey)), + ); + + // TODO: also generate the CDS and CDNSKEY + let serial = inner.serial(origin); + inner.upsert(dnskey, serial, dns_class); + inner.secure_keys.push(signer); + Ok(()) + } + + /// Non-async method of add_zone_signing_key when behind a mutable reference + #[cfg(feature = "dnssec")] + #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] + pub fn add_zone_signing_key_mut(&mut self, signer: SigSigner) -> DnsSecResult<()> { + let Self { + ref origin, + ref mut inner, + class, + .. 
+ } = self; + + Self::inner_add_zone_signing_key(inner.get_mut(), signer, origin, *class) + } + + /// (Re)generates the nsec records, increments the serial number and signs the zone + #[cfg(feature = "dnssec")] + #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] + pub fn secure_zone_mut(&mut self) -> DnsSecResult<()> { + let Self { + ref origin, + ref mut inner, + .. + } = self; + inner.get_mut().secure_zone_mut(origin, self.class) + } + + /// (Re)generates the nsec records, increments the serial number and signs the zone + #[cfg(not(feature = "dnssec"))] + #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] + pub fn secure_zone_mut(&mut self) -> Result<(), &str> { + Err("DNSSEC was not enabled during compilation.") + } +} + +struct InnerInMemory { + records: BTreeMap>, + // Private key mapped to the Record of the DNSKey + // TODO: these private_keys should be stored securely. Ideally, we have keys only stored per + // server instance, but that requires requesting updates from the parent zone, which may or + // may not support dynamic updates to register the new key... Trust-DNS will provide support + // for this, in some form, perhaps alternate root zones... + #[cfg(feature = "dnssec")] + secure_keys: Vec, +} + +impl Default for InnerInMemory { + fn default() -> Self { + Self { + records: BTreeMap::new(), + #[cfg(feature = "dnssec")] + secure_keys: Vec::new(), + } + } +} + +impl InnerInMemory { + /// Retrieve the Signer, which contains the private keys, for this zone + #[cfg(feature = "dnssec")] + fn secure_keys(&self) -> &[SigSigner] { + &self.secure_keys + } + + // /// Get all the records + // fn records(&self) -> &BTreeMap> { + // &self.records + // } + + // /// Get a mutable reference to the records + // fn records_mut(&mut self) -> &mut BTreeMap> { + // &mut self.records + // } + + fn inner_soa(&self, origin: &LowerName) -> Option<&SOA> { + // FIXME: can't there be an RrKeyRef? + let rr_key = RrKey::new(origin.clone(), RecordType::SOA); self.records .get(&rr_key) @@ -165,13 +351,13 @@ impl InMemoryAuthority { } /// Returns the minimum ttl (as used in the SOA record) - pub fn minimum_ttl(&self) -> u32 { - let soa = self.inner_soa(); + fn minimum_ttl(&self, origin: &LowerName) -> u32 { + let soa = self.inner_soa(origin); let soa = match soa { Some(soa) => soa, None => { - error!("could not lookup SOA for authority: {}", self.origin); + error!("could not lookup SOA for authority: {}", origin); return 0; } }; @@ -180,13 +366,13 @@ impl InMemoryAuthority { } /// get the current serial number for the zone. 
- pub fn serial(&self) -> u32 { - let soa = self.inner_soa(); + fn serial(&self, origin: &LowerName) -> u32 { + let soa = self.inner_soa(origin); let soa = match soa { Some(soa) => soa, None => { - error!("could not lookup SOA for authority: {}", self.origin); + error!("could not lookup SOA for authority: {}", origin); return 0; } }; @@ -330,9 +516,9 @@ impl InMemoryAuthority { } #[cfg(any(feature = "dnssec", feature = "sqlite"))] - pub(crate) fn increment_soa_serial(&mut self) -> u32 { + fn increment_soa_serial(&mut self, origin: &LowerName, dns_class: DNSClass) -> u32 { // we'll remove the SOA and then replace it - let rr_key = RrKey::new(self.origin.clone(), RecordType::SOA); + let rr_key = RrKey::new(origin.clone(), RecordType::SOA); let record = self .records .remove(&rr_key) @@ -342,7 +528,7 @@ impl InMemoryAuthority { let mut record = if let Some(record) = record { record } else { - error!("could not lookup SOA for authority: {}", self.origin); + error!("could not lookup SOA for authority: {}", origin); return 0; }; @@ -353,7 +539,7 @@ impl InMemoryAuthority { panic!("This was not an SOA record"); // valid panic, never should happen }; - self.upsert(record, serial); + self.upsert(record, serial, dns_class); serial } @@ -369,8 +555,15 @@ impl InMemoryAuthority { /// # Return value /// /// true if the value was inserted, false otherwise - pub fn upsert(&mut self, record: Record, serial: u32) -> bool { - assert_eq!(self.class, record.dns_class()); + fn upsert(&mut self, record: Record, serial: u32, dns_class: DNSClass) -> bool { + if dns_class != record.dns_class() { + warn!( + "mismatched dns_class on record insert, zone: {} record: {}", + dns_class, + record.dns_class() + ); + return false; + } #[cfg(feature = "dnssec")] fn is_nsec(upsert_type: RecordType, occupied_type: RecordType) -> bool { @@ -442,35 +635,29 @@ impl InMemoryAuthority { /// (Re)generates the nsec records, increments the serial number and signs the zone #[cfg(feature = "dnssec")] #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] - pub fn secure_zone(&mut self) -> DnsSecResult<()> { + fn secure_zone_mut(&mut self, origin: &LowerName, dns_class: DNSClass) -> DnsSecResult<()> { // TODO: only call nsec_zone after adds/deletes // needs to be called before incrementing the soa serial, to make sure IXFR works properly - self.nsec_zone(); + self.nsec_zone(origin, dns_class); // need to resign any records at the current serial number and bump the number. // first bump the serial number on the SOA, so that it is resigned with the new serial. - self.increment_soa_serial(); + self.increment_soa_serial(origin, dns_class); // TODO: should we auto sign here? or maybe up a level... - self.sign_zone() - } - - /// (Re)generates the nsec records, increments the serial number and signs the zone - #[cfg(not(feature = "dnssec"))] - pub fn secure_zone(&mut self) -> Result<(), &str> { - Err("DNSSEC was not enabled during compilation.") + self.sign_zone(origin, dns_class) } /// Dummy implementation for when DNSSEC is disabled. 
#[cfg(feature = "dnssec")] - fn nsec_zone(&mut self) { + fn nsec_zone(&mut self, origin: &LowerName, dns_class: DNSClass) { use crate::client::rr::rdata::NSEC; // only create nsec records for secure zones if self.secure_keys.is_empty() { return; } - debug!("generating nsec records: {}", self.origin); + debug!("generating nsec records: {}", origin); // first remove all existing nsec records let delete_keys: Vec = self @@ -485,8 +672,8 @@ impl InMemoryAuthority { } // now go through and generate the nsec records - let ttl = self.minimum_ttl(); - let serial = self.serial(); + let ttl = self.minimum_ttl(origin); + let serial = self.serial(origin); let mut records: Vec = vec![]; { @@ -514,7 +701,7 @@ impl InMemoryAuthority { if let Some((name, vec)) = nsec_info { // names aren't equal, create the NSEC record let mut record = Record::with(name.clone(), RecordType::NSEC, ttl); - let rdata = NSEC::new_cover_self(Authority::origin(self).clone().into(), vec); + let rdata = NSEC::new_cover_self(origin.clone().into(), vec); record.set_rdata(RData::DNSSEC(DNSSECRData::NSEC(rdata))); records.push(record); } @@ -522,7 +709,7 @@ impl InMemoryAuthority { // insert all the nsec records for record in records { - let upserted = self.upsert(record, serial); + let upserted = self.upsert(record, serial, dns_class); debug_assert!(upserted); } } @@ -631,25 +818,26 @@ impl InMemoryAuthority { /// Signs any records in the zone that have serial numbers greater than or equal to `serial` #[cfg(feature = "dnssec")] - fn sign_zone(&mut self) -> DnsSecResult<()> { - use log::warn; + fn sign_zone(&mut self, origin: &LowerName, dns_class: DNSClass) -> DnsSecResult<()> { + debug!("signing zone: {}", origin); - debug!("signing zone: {}", self.origin); - - let minimum_ttl = self.minimum_ttl(); + let minimum_ttl = self.minimum_ttl(origin); let secure_keys = &self.secure_keys; let records = &mut self.records; // TODO: should this be an error? if secure_keys.is_empty() { - warn!("attempt to sign_zone for dnssec, but no keys available!") + warn!( + "attempt to sign_zone {} for dnssec, but no keys available!", + origin + ) } // sign all record_sets, as of 0.12.1 this includes DNSKEY for rr_set_orig in records.values_mut() { // because the rrset is an Arc, it must be cloned before mutated let rr_set = Arc::make_mut(rr_set_orig); - Self::sign_rrset(rr_set, secure_keys, minimum_ttl, self.class)?; + Self::sign_rrset(rr_set, secure_keys, minimum_ttl, dns_class)?; } Ok(()) @@ -776,7 +964,7 @@ impl Authority for InMemoryAuthority { /// /// true if any of additions, updates or deletes were made to the zone, false otherwise. Err is /// returned in the case of bad data, etc. 
- async fn update(&mut self, _update: &MessageRequest) -> UpdateResult { + async fn update(&self, _update: &MessageRequest) -> UpdateResult { Err(ResponseCode::NotImp) } @@ -805,13 +993,15 @@ impl Authority for InMemoryAuthority { query_type: RecordType, lookup_options: LookupOptions, ) -> Result { + let inner = self.inner.lock().await; + // Collect the records from each rr_set let (result, additionals): (LookupResult, Option) = match query_type { RecordType::AXFR | RecordType::ANY => { let result = AnyRecords::new( lookup_options, - self.records.values().cloned().collect(), + inner.records.values().cloned().collect(), query_type, name.clone(), ); @@ -819,20 +1009,21 @@ impl Authority for InMemoryAuthority { } _ => { // perform the lookup - let answer = self.inner_lookup(name, query_type, lookup_options); + let answer = inner.inner_lookup(name, query_type, lookup_options); // evaluate any cnames for additional inclusion let additionals_root_chain_type: Option<(_, _)> = answer .as_ref() .and_then(|a| maybe_next_name(&*a, query_type)) .and_then(|(search_name, search_type)| { - self.additional_search( - query_type, - search_name, - search_type, - lookup_options, - ) - .map(|adds| (adds, search_type)) + inner + .additional_search( + query_type, + search_name, + search_type, + lookup_options, + ) + .map(|adds| (adds, search_type)) }); // if the chain started with an ANAME, take the A or AAAA record from the list @@ -891,10 +1082,10 @@ impl Authority for InMemoryAuthority { // ANAME's are constructed on demand, so need to be signed before return if lookup_options.is_dnssec() { - Self::sign_rrset( + InnerInMemory::sign_rrset( &mut new_answer, - self.secure_keys(), - self.minimum_ttl(), + inner.secure_keys(), + inner.minimum_ttl(self.origin()), self.class(), ) // rather than failing the request, we'll just warn @@ -935,7 +1126,7 @@ impl Authority for InMemoryAuthority { // TODO: can we get rid of this? let result = match result { Err(LookupError::ResponseCode(ResponseCode::NXDomain)) => { - if self + if inner .records .keys() .any(|key| key.name() == name || name.zone_of(key.name())) @@ -1026,13 +1217,14 @@ impl Authority for InMemoryAuthority { name: &LowerName, lookup_options: LookupOptions, ) -> Result { + let inner = self.inner.lock().await; fn is_nsec_rrset(rr_set: &RecordSet) -> bool { rr_set.record_type() == RecordType::NSEC } // TODO: need a BorrowdRrKey let rr_key = RrKey::new(name.clone(), RecordType::NSEC); - let no_data = self + let no_data = inner .records .get(&rr_key) .map(|rr_set| LookupRecords::new(lookup_options, rr_set.clone())); @@ -1042,7 +1234,8 @@ impl Authority for InMemoryAuthority { } let get_closest_nsec = |name: &LowerName| -> Option> { - self.records + inner + .records .values() .rev() .filter(|rr_set| is_nsec_rrset(rr_set)) @@ -1070,10 +1263,11 @@ impl Authority for InMemoryAuthority { // we need the wildcard proof, but make sure that it's still part of the zone. let wildcard = name.base_name(); - let wildcard = if self.origin().zone_of(&wildcard) { + let origin = self.origin(); + let wildcard = if origin.zone_of(&wildcard) { wildcard } else { - self.origin().clone() + origin.clone() }; // don't duplicate the record... 
@@ -1111,19 +1305,13 @@ impl Authority for InMemoryAuthority { #[cfg(feature = "dnssec")] #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] +#[async_trait::async_trait] impl DnssecAuthority for InMemoryAuthority { /// Add a (Sig0) key that is authorized to perform updates against this authority - fn add_update_auth_key(&mut self, name: Name, key: KEY) -> DnsSecResult<()> { - let rdata = RData::DNSSEC(DNSSECRData::KEY(key)); - // TODO: what TTL? - let record = Record::from_rdata(name, 86400, rdata); + async fn add_update_auth_key(&self, name: Name, key: KEY) -> DnsSecResult<()> { + let mut inner = self.inner.lock().await; - let serial = self.serial(); - if self.upsert(record, serial) { - Ok(()) - } else { - Err("failed to add auth key".into()) - } + Self::inner_add_update_auth_key(&mut inner, name, key, self.origin(), self.class) } /// By adding a secure key, this will implicitly enable dnssec for the zone. @@ -1131,34 +1319,16 @@ impl DnssecAuthority for InMemoryAuthority { /// # Arguments /// /// * `signer` - Signer with associated private key - fn add_zone_signing_key(&mut self, signer: SigSigner) -> DnsSecResult<()> { - // also add the key to the zone - let zone_ttl = self.minimum_ttl(); - let dnskey = signer.key().to_dnskey(signer.algorithm())?; - let dnskey = Record::from_rdata( - self.origin.clone().into(), - zone_ttl, - RData::DNSSEC(DNSSECRData::DNSKEY(dnskey)), - ); + async fn add_zone_signing_key(&self, signer: SigSigner) -> DnsSecResult<()> { + let mut inner = self.inner.lock().await; - // TODO: also generate the CDS and CDNSKEY - let serial = self.serial(); - self.upsert(dnskey, serial); - self.secure_keys.push(signer); - Ok(()) + Self::inner_add_zone_signing_key(&mut inner, signer, self.origin(), self.class) } /// Sign the zone for DNSSEC - fn secure_zone(&mut self) -> DnsSecResult<()> { - // TODO: only call nsec_zone after adds/deletes - // needs to be called before incrementing the soa serial, to make sure IXFR works properly - self.nsec_zone(); + async fn secure_zone(&self) -> DnsSecResult<()> { + let mut inner = self.inner.lock().await; - // need to resign any records at the current serial number and bump the number. - // first bump the serial number on the SOA, so that it is resigned with the new serial. - self.increment_soa_serial(); - - // TODO: should we auto sign here? or maybe up a level... 
- self.sign_zone() + inner.secure_zone_mut(self.origin(), self.class) } } diff --git a/crates/server/src/store/sqlite/authority.rs b/crates/server/src/store/sqlite/authority.rs index 3365d1ea..f959ac23 100644 --- a/crates/server/src/store/sqlite/authority.rs +++ b/crates/server/src/store/sqlite/authority.rs @@ -13,6 +13,7 @@ use std::{ sync::Arc, }; +use futures_util::lock::Mutex; use log::{error, info, warn}; use crate::{ @@ -45,7 +46,7 @@ use crate::{ #[allow(dead_code)] pub struct SqliteAuthority { in_memory: InMemoryAuthority, - journal: Option, + journal: Mutex>, allow_update: bool, is_dnssec_enabled: bool, } @@ -66,14 +67,14 @@ impl SqliteAuthority { pub fn new(in_memory: InMemoryAuthority, allow_update: bool, is_dnssec_enabled: bool) -> Self { Self { in_memory, - journal: None, + journal: Mutex::new(None), allow_update, is_dnssec_enabled, } } /// load the authority from the configuration - pub fn try_from_config( + pub async fn try_from_config( origin: Name, zone_type: ZoneType, allow_axfr: bool, @@ -99,11 +100,13 @@ impl SqliteAuthority { let in_memory = InMemoryAuthority::empty(zone_name.clone(), zone_type, allow_axfr); let mut authority = SqliteAuthority::new(in_memory, config.allow_update, enable_dnssec); + authority .recover_with_journal(&journal) + .await .map_err(|e| format!("error recovering from journal: {}", e))?; - authority.set_journal(journal); + authority.set_journal(journal).await; info!("recovered zone: {}", zone_name); Ok(authority) @@ -131,11 +134,12 @@ impl SqliteAuthority { let journal = Journal::from_file(&journal_path) .map_err(|e| format!("error creating journal {:?}: {}", journal_path, e))?; - authority.set_journal(journal); + authority.set_journal(journal).await; // preserve to the new journal, i.e. we just loaded the zone from disk, start the journal authority .persist_to_journal() + .await .map_err(|e| format!("error persisting to journal {:?}: {}", journal_path, e))?; info!("zone file loaded: {}", zone_name); @@ -153,9 +157,9 @@ impl SqliteAuthority { /// # Arguments /// /// * `journal` - the journal from which to load the persisted zone. - pub fn recover_with_journal(&mut self, journal: &Journal) -> PersistenceResult<()> { + pub async fn recover_with_journal(&mut self, journal: &Journal) -> PersistenceResult<()> { assert!( - self.in_memory.records().is_empty(), + self.in_memory.records_get_mut().is_empty(), "records should be empty during a recovery" ); @@ -166,7 +170,7 @@ impl SqliteAuthority { // authority. if record.rr_type() == RecordType::AXFR { self.in_memory.clear(); - } else if let Err(error) = self.update_records(&[record], false) { + } else if let Err(error) = self.update_records(&[record], false).await { return Err(PersistenceErrorKind::Recovery(error.to_str()).into()); } } @@ -178,16 +182,16 @@ impl SqliteAuthority { /// Journal. /// /// Returns an error if there was an issue writing to the persistence layer. - pub fn persist_to_journal(&self) -> PersistenceResult<()> { - if let Some(journal) = self.journal.as_ref() { - let serial = self.serial(); + pub async fn persist_to_journal(&self) -> PersistenceResult<()> { + if let Some(journal) = self.journal.lock().await.as_ref() { + let serial = self.in_memory.serial().await; info!("persisting zone to journal at SOA.serial: {}", serial); // TODO: THIS NEEDS TO BE IN A TRANSACTION!!! 
journal.insert_record(serial, Record::new().set_rr_type(RecordType::AXFR))?; - for rr_set in self.in_memory.records().values() { + for rr_set in self.in_memory.records().await.values() { // TODO: should we preserve rr_sets or not? for record in rr_set.records_without_rrsigs() { journal.insert_record(serial, record)?; @@ -201,13 +205,15 @@ impl SqliteAuthority { } /// Associate a backing Journal with this Authority for Updatable zones - pub fn set_journal(&mut self, journal: Journal) { - self.journal = Some(journal); + pub async fn set_journal(&mut self, journal: Journal) { + *self.journal.lock().await = Some(journal); } /// Returns the associated Journal - pub fn journal(&self) -> Option<&Journal> { - self.journal.as_ref() + #[cfg(any(test, feature = "testing"))] + #[cfg_attr(docsrs, doc(cfg(feature = "testing")))] + pub async fn journal(&self) -> impl Deref> + '_ { + self.journal.lock().await } /// Enables the zone for dynamic DNS updates @@ -215,6 +221,13 @@ impl SqliteAuthority { self.allow_update = allow_update; } + /// Get serial + #[cfg(any(test, feature = "testing"))] + #[cfg_attr(docsrs, doc(cfg(feature = "testing")))] + pub async fn serial(&self) -> u32 { + self.in_memory.serial().await + } + /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997 /// /// ```text @@ -306,8 +319,9 @@ impl SqliteAuthority { return Err(ResponseCode::FormErr); } - if !self.origin().zone_of(&require.name().into()) { - warn!("{} is not a zone_of {}", require.name(), self.origin()); + let origin = self.origin(); + if !origin.zone_of(&require.name().into()) { + warn!("{} is not a zone_of {}", require.name(), origin); return Err(ResponseCode::NotZone); } @@ -388,7 +402,7 @@ impl SqliteAuthority { return Err(ResponseCode::FormErr); } } - class if class == self.class() => + class if class == self.in_memory.class() => // zone rrset rr RRset exists (value dependent) { if !self @@ -553,7 +567,7 @@ impl SqliteAuthority { /// type, else signal FORMERR to the requestor. /// ``` #[allow(clippy::unused_unit)] - pub fn pre_scan(&self, records: &[Record]) -> UpdateResult<()> { + pub async fn pre_scan(&self, records: &[Record]) -> UpdateResult<()> { // 3.4.1.3 - Pseudocode For Update Section Prescan // // [rr] for rr in updates @@ -577,7 +591,7 @@ impl SqliteAuthority { } let class: DNSClass = rr.dns_class(); - if class == self.class() { + if class == self.in_memory.class() { match rr.rr_type() { RecordType::ANY | RecordType::AXFR | RecordType::IXFR => { return Err(ResponseCode::FormErr); @@ -642,17 +656,17 @@ impl SqliteAuthority { /// * `records` - set of record instructions for update following above rules /// * `auto_signing_and_increment` - if true, the zone will sign and increment the SOA, this /// should be disabled during recovery. - pub fn update_records( - &mut self, + pub async fn update_records( + &self, records: &[Record], auto_signing_and_increment: bool, ) -> UpdateResult { let mut updated = false; - let serial: u32 = self.serial(); + let serial: u32 = self.in_memory.serial().await; // the persistence act as a write-ahead log. The WAL will also be used for recovery of a zone // subsequent to a failure of the server. 
- if let Some(ref journal) = self.journal { + if let Some(ref journal) = *self.journal.lock().await { if let Err(error) = journal.insert_records(serial, records) { error!("could not persist update records: {}", error); return Err(ResponseCode::ServFail); @@ -703,7 +717,7 @@ impl SqliteAuthority { let rr_key = RrKey::new(rr_name.clone(), rr.rr_type()); match rr.dns_class() { - class if class == self.class() => { + class if class == self.in_memory.class() => { // RFC 2136 - 3.4.2.2. Any Update RR whose CLASS is the same as ZCLASS is added to // the zone. In case of duplicate RDATAs (which for SOA RRs is always // the case, and for WKS RRs is the case if the ADDRESS and PROTOCOL @@ -717,7 +731,7 @@ impl SqliteAuthority { // zone rrset rr Add to an RRset info!("upserting record: {:?}", rr); - updated = self.upsert(rr.clone(), serial) || updated; + updated = self.in_memory.upsert(rr.clone(), serial).await || updated; } DNSClass::ANY => { // This is a delete of entire RRSETs, either many or one. In either case, the spec is clear: @@ -738,19 +752,22 @@ impl SqliteAuthority { "deleting all records at name (not SOA or NS at origin): {:?}", rr_name ); + let origin = self.origin(); let to_delete = self + .in_memory .records() + .await .keys() .filter(|k| { !((k.record_type == RecordType::SOA || k.record_type == RecordType::NS) - && k.name != *self.origin()) + && k.name != *origin) }) .filter(|k| k.name == rr_name) .cloned() .collect::>(); for delete in to_delete { - self.records_mut().remove(&delete); + self.in_memory.records_mut().await.remove(&delete); updated = true; } } @@ -762,7 +779,7 @@ impl SqliteAuthority { // ANY rrset empty Delete an RRset if let RData::NULL(..) = *rr.rdata() { - let deleted = self.records_mut().remove(&rr_key); + let deleted = self.in_memory.records_mut().await.remove(&rr_key); info!("deleted rrset: {:?}", deleted); updated = updated || deleted.is_some(); } else { @@ -775,7 +792,7 @@ impl SqliteAuthority { DNSClass::NONE => { info!("deleting specific record: {:?}", rr); // NONE rrset rr Delete an RR from an RRset - if let Some(rrset) = self.records_mut().get_mut(&rr_key) { + if let Some(rrset) = self.in_memory.records_mut().await.get_mut(&rr_key) { // b/c this is an Arc, we need to clone, then remove, and replace the node. let mut rrset_clone: RecordSet = RecordSet::clone(&*rrset); let deleted = rrset_clone.remove(rr, serial); @@ -797,14 +814,21 @@ impl SqliteAuthority { // update the serial... if updated && auto_signing_and_increment { if self.is_dnssec_enabled { - self.secure_zone().map_err(|e| { - error!("failure securing zone: {}", e); - ResponseCode::ServFail - })? + cfg_if::cfg_if! { + if #[cfg(feature = "dnssec")] { + self.secure_zone().await.map_err(|e| { + error!("failure securing zone: {}", e); + ResponseCode::ServFail + })? + } else { + error!("failure securing zone, dnssec feature not enabled"); + return Err(ResponseCode::ServFail) + } + } } else { // the secure_zone() function increments the SOA during it's operation, if we're not // dnssec, then we need to do it here... - self.increment_soa_serial(); + self.in_memory.increment_soa_serial().await; } } @@ -898,18 +922,19 @@ impl Authority for SqliteAuthority { /// true if any of additions, updates or deletes were made to the zone, false otherwise. Err is /// returned in the case of bad data, etc. 
#[cfg(feature = "dnssec")] - async fn update(&mut self, update: &MessageRequest) -> UpdateResult { + async fn update(&self, update: &MessageRequest) -> UpdateResult { + //let this = &mut self.in_memory.lock().await; // the spec says to authorize after prereqs, seems better to auth first. self.authorize(update).await?; self.verify_prerequisites(update.prerequisites()).await?; - self.pre_scan(update.updates())?; + self.pre_scan(update.updates()).await?; - self.update_records(update.updates(), true) + self.update_records(update.updates(), true).await } /// Always fail when DNSSEC is disabled. #[cfg(not(feature = "dnssec"))] - async fn update(&mut self, _update: &MessageRequest) -> UpdateResult { + async fn update(&self, _update: &MessageRequest) -> UpdateResult { Err(ResponseCode::NotImp) } @@ -967,9 +992,10 @@ impl Authority for SqliteAuthority { #[cfg(feature = "dnssec")] #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))] +#[async_trait::async_trait] impl DnssecAuthority for SqliteAuthority { - fn add_update_auth_key(&mut self, name: Name, key: KEY) -> DnsSecResult<()> { - self.in_memory.add_update_auth_key(name, key) + async fn add_update_auth_key(&self, name: Name, key: KEY) -> DnsSecResult<()> { + self.in_memory.add_update_auth_key(name, key).await } /// By adding a secure key, this will implicitly enable dnssec for the zone. @@ -977,12 +1003,26 @@ impl DnssecAuthority for SqliteAuthority { /// # Arguments /// /// * `signer` - Signer with associated private key - fn add_zone_signing_key(&mut self, signer: SigSigner) -> DnsSecResult<()> { - self.in_memory.add_zone_signing_key(signer) + async fn add_zone_signing_key(&self, signer: SigSigner) -> DnsSecResult<()> { + self.in_memory.add_zone_signing_key(signer).await } /// (Re)generates the nsec records, increments the serial number and signs the zone - fn secure_zone(&mut self) -> DnsSecResult<()> { - DnssecAuthority::secure_zone(&mut self.in_memory) + async fn secure_zone(&self) -> DnsSecResult<()> { + self.in_memory.secure_zone().await + } +} + +#[cfg(test)] +mod tests { + use crate::store::sqlite::SqliteAuthority; + + #[test] + fn test_is_send_sync() { + fn send_sync() -> bool { + true + } + + assert!(send_sync::()); } } diff --git a/crates/server/tests/authority_battery/dnssec.rs b/crates/server/tests/authority_battery/dnssec.rs index d7aa4799..2eef24be 100644 --- a/crates/server/tests/authority_battery/dnssec.rs +++ b/crates/server/tests/authority_battery/dnssec.rs @@ -351,10 +351,8 @@ pub fn add_signers(authority: &mut A) -> Vec { .try_into_signer(signer_name.clone()) .expect("failed to read key_config"); keys.push(signer.to_dnskey().expect("failed to create DNSKEY")); - authority - .add_zone_signing_key(signer) - .expect("failed to add signer to zone"); - authority.secure_zone().expect("failed to sign zone"); + block_on(authority.add_zone_signing_key(signer)).expect("failed to add signer to zone"); + block_on(authority.secure_zone()).expect("failed to sign zone"); } // // TODO: why are ecdsa tests failing in this context? 
@@ -408,10 +406,8 @@ pub fn add_signers(authority: &mut A) -> Vec { .try_into_signer(signer_name) .expect("failed to read key_config"); keys.push(signer.to_dnskey().expect("failed to create DNSKEY")); - authority - .add_zone_signing_key(signer) - .expect("failed to add signer to zone"); - authority.secure_zone().expect("failed to sign zone"); + block_on(authority.add_zone_signing_key(signer)).expect("failed to add signer to zone"); + block_on(authority.secure_zone()).expect("failed to sign zone"); } keys diff --git a/crates/server/tests/authority_battery/dynamic_update.rs b/crates/server/tests/authority_battery/dynamic_update.rs index 438eea92..6ba44d10 100644 --- a/crates/server/tests/authority_battery/dynamic_update.rs +++ b/crates/server/tests/authority_battery/dynamic_update.rs @@ -690,8 +690,7 @@ pub fn add_auth(authority: &mut A) -> Vec { .to_sig0key_with_usage(Algorithm::RSASHA512, KeyUsage::Host) .expect("failed to get sig0 key"); - authority - .add_update_auth_key(update_name.clone(), public_key) + block_on(authority.add_update_auth_key(update_name.clone(), public_key)) .expect("failed to add signer to zone"); keys.push(signer); } @@ -751,8 +750,7 @@ pub fn add_auth(authority: &mut A) -> Vec { .to_sig0key_with_usage(Algorithm::ED25519, KeyUsage::Host) .expect("failed to get sig0 key"); - authority - .add_update_auth_key(update_name, public_key) + block_on(authority.add_update_auth_key(update_name, public_key)) .expect("failed to add signer to zone"); keys.push(signer); } diff --git a/crates/server/tests/store_file_tests.rs b/crates/server/tests/store_file_tests.rs index 1065eba9..27c6e5c8 100644 --- a/crates/server/tests/store_file_tests.rs +++ b/crates/server/tests/store_file_tests.rs @@ -34,7 +34,7 @@ fn test_all_lines_are_loaded() { .to_string(), }; - let authority = FileAuthority::try_from_config( + let mut authority = FileAuthority::try_from_config( Name::from_str("example.com.").unwrap(), ZoneType::Primary, false, @@ -46,5 +46,5 @@ fn test_all_lines_are_loaded() { record_type: RecordType::A, name: LowerName::from(Name::from_ascii("ensure.nonewline.").unwrap()), }; - assert!(authority.records().get(&rrkey).is_some()) + assert!(authority.records_get_mut().get(&rrkey).is_some()) } diff --git a/crates/server/tests/store_sqlite_tests.rs b/crates/server/tests/store_sqlite_tests.rs index 4ff3bcff..1e1e27dc 100644 --- a/crates/server/tests/store_sqlite_tests.rs +++ b/crates/server/tests/store_sqlite_tests.rs @@ -4,9 +4,13 @@ use std::fs; use std::path::PathBuf; use std::str::FromStr; +use futures_executor::block_on; + use trust_dns_client::rr::Name; -use trust_dns_server::authority::ZoneType; -use trust_dns_server::store::sqlite::{SqliteAuthority, SqliteConfig}; +use trust_dns_server::{ + authority::ZoneType, + store::sqlite::{SqliteAuthority, SqliteConfig}, +}; #[macro_use] mod authority_battery; @@ -27,14 +31,14 @@ fn sqlite(master_file_path: &str, module: &str, test_name: &str) -> SqliteAuthor allow_update: true, }; - SqliteAuthority::try_from_config( + block_on(SqliteAuthority::try_from_config( Name::from_str("example.com.").unwrap(), ZoneType::Primary, false, true, None, &config, - ) + )) .expect("failed to load file") } @@ -55,14 +59,14 @@ fn sqlite_update(master_file_path: &str, module: &str, test_name: &str) -> Sqlit allow_update: true, }; - SqliteAuthority::try_from_config( + block_on(SqliteAuthority::try_from_config( Name::from_str("example.com.").unwrap(), ZoneType::Primary, false, true, None, &config, - ) + )) .expect("failed to load file") } diff --git 
a/tests/integration-tests/Cargo.toml b/tests/integration-tests/Cargo.toml index 4e27c048..f7bd2ee7 100644 --- a/tests/integration-tests/Cargo.toml +++ b/tests/integration-tests/Cargo.toml @@ -84,8 +84,9 @@ trust-dns-client= { version = "0.21.0-alpha.2", path = "../../crates/client" } trust-dns-proto = { version = "0.21.0-alpha.2", path = "../../crates/proto", features = ["testing"] } trust-dns-resolver = { version = "0.21.0-alpha.2", path = "../../crates/resolver" } # TODO: fixup tests to not require openssl -trust-dns-server = { version = "0.21.0-alpha.2", path = "../../crates/server" } +trust-dns-server = { version = "0.21.0-alpha.2", path = "../../crates/server", features = ["testing"] } webpki-roots = { version = "0.21", optional = true } [dev-dependencies] futures = { version = "0.3.5", features = ["thread-pool"] } +tokio = { version="1.0", features = ["macros", "rt"] } diff --git a/tests/integration-tests/src/authority.rs b/tests/integration-tests/src/authority.rs index 55c6c57c..314da7b4 100644 --- a/tests/integration-tests/src/authority.rs +++ b/tests/integration-tests/src/authority.rs @@ -15,7 +15,7 @@ pub fn create_example() -> InMemoryAuthority { let mut records = InMemoryAuthority::empty(origin.clone(), ZoneType::Primary, false); // example.com. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2015082403 7200 3600 1209600 3600 - records.upsert( + records.upsert_mut( Record::new() .set_name(origin.clone()) .set_ttl(3600) @@ -34,7 +34,7 @@ pub fn create_example() -> InMemoryAuthority { 0, ); - records.upsert( + records.upsert_mut( Record::new() .set_name(origin.clone()) .set_ttl(86400) @@ -44,7 +44,7 @@ pub fn create_example() -> InMemoryAuthority { .clone(), 0, ); - records.upsert( + records.upsert_mut( Record::new() .set_name(origin.clone()) .set_ttl(86400) @@ -58,7 +58,7 @@ pub fn create_example() -> InMemoryAuthority { // example.com. 60 IN TXT "v=spf1 -all" //records.upsert(origin.clone(), Record::new().name(origin.clone()).ttl(60).rr_type(RecordType::TXT).dns_class(DNSClass::IN).rdata(RData::TXT{ txt_data: vec!["v=spf1 -all".to_string()] }).clone()); // example.com. 60 IN TXT "$Id: example.com 4415 2015-08-24 20:12:23Z davids $" - records.upsert( + records.upsert_mut( Record::new() .set_name(origin.clone()) .set_ttl(60) @@ -74,7 +74,7 @@ pub fn create_example() -> InMemoryAuthority { ); // example.com. 86400 IN A 93.184.216.34 - records.upsert( + records.upsert_mut( Record::new() .set_name(origin.clone()) .set_ttl(86400) @@ -86,7 +86,7 @@ pub fn create_example() -> InMemoryAuthority { ); // example.com. 86400 IN AAAA 2606:2800:220:1:248:1893:25c8:1946 - records.upsert( + records.upsert_mut( Record::new() .set_name(origin) .set_ttl(86400) @@ -117,7 +117,7 @@ pub fn create_example() -> InMemoryAuthority { let www_name: Name = Name::parse("www.example.com.", None).unwrap(); // www.example.com. 86400 IN TXT "v=spf1 -all" - records.upsert( + records.upsert_mut( Record::new() .set_name(www_name.clone()) .set_ttl(86400) @@ -129,7 +129,7 @@ pub fn create_example() -> InMemoryAuthority { ); // www.example.com. 86400 IN A 93.184.216.34 - records.upsert( + records.upsert_mut( Record::new() .set_name(www_name.clone()) .set_ttl(86400) @@ -141,7 +141,7 @@ pub fn create_example() -> InMemoryAuthority { ); // www.example.com. 
86400 IN AAAA 2606:2800:220:1:248:1893:25c8:1946 - records.upsert( + records.upsert_mut( Record::new() .set_name(www_name.clone()) .set_ttl(86400) @@ -155,7 +155,7 @@ pub fn create_example() -> InMemoryAuthority { ); // alias 86400 IN www - records.upsert( + records.upsert_mut( Record::new() .set_name(Name::from_str("alias.example.com.").unwrap()) .set_ttl(86400) @@ -167,7 +167,7 @@ pub fn create_example() -> InMemoryAuthority { ); // alias2 86400 IN www, multiple cname chains - records.upsert( + records.upsert_mut( Record::new() .set_name(Name::from_str("alias2.example.com.").unwrap()) .set_ttl(86400) @@ -206,8 +206,8 @@ pub fn create_secure_example() -> InMemoryAuthority { Duration::weeks(1), ); - authority.add_zone_signing_key(signer); - authority.secure_zone(); + authority.add_zone_signing_key_mut(signer); + authority.secure_zone_mut(); authority } diff --git a/tests/integration-tests/tests/catalog_tests.rs b/tests/integration-tests/tests/catalog_tests.rs index 33bf8afd..93258fba 100644 --- a/tests/integration-tests/tests/catalog_tests.rs +++ b/tests/integration-tests/tests/catalog_tests.rs @@ -1,20 +1,17 @@ -use std::net::*; -use std::str::FromStr; -use std::sync::Arc; +use std::{net::*, str::FromStr, sync::Arc}; -use futures::executor::block_on; +use trust_dns_client::{ + op::*, + rr::{rdata::*, *}, + serialize::binary::{BinDecodable, BinEncodable}, +}; -use futures::lock::Mutex; -use trust_dns_client::op::*; -use trust_dns_client::rr::rdata::*; -use trust_dns_client::rr::*; -use trust_dns_client::serialize::binary::{BinDecodable, BinEncodable}; +use trust_dns_server::{ + authority::{Authority, Catalog, MessageRequest, ZoneType}, + store::in_memory::InMemoryAuthority, +}; -use trust_dns_server::authority::{Authority, Catalog, MessageRequest, ZoneType}; -use trust_dns_server::store::in_memory::InMemoryAuthority; - -use trust_dns_integration::authority::create_example; -use trust_dns_integration::*; +use trust_dns_integration::{authority::create_example, *}; #[allow(clippy::unreadable_literal)] pub fn create_test() -> InMemoryAuthority { @@ -22,7 +19,7 @@ pub fn create_test() -> InMemoryAuthority { let mut records = InMemoryAuthority::empty(origin.clone(), ZoneType::Primary, false); - records.upsert( + records.upsert_mut( Record::new() .set_name(origin.clone()) .set_ttl(3600) @@ -41,7 +38,7 @@ pub fn create_test() -> InMemoryAuthority { 0, ); - records.upsert( + records.upsert_mut( Record::new() .set_name(origin.clone()) .set_ttl(86400) @@ -51,7 +48,7 @@ pub fn create_test() -> InMemoryAuthority { .clone(), 0, ); - records.upsert( + records.upsert_mut( Record::new() .set_name(origin.clone()) .set_ttl(86400) @@ -62,7 +59,7 @@ pub fn create_test() -> InMemoryAuthority { 0, ); - records.upsert( + records.upsert_mut( Record::new() .set_name(origin.clone()) .set_ttl(86400) @@ -72,7 +69,7 @@ pub fn create_test() -> InMemoryAuthority { .clone(), 0, ); - records.upsert( + records.upsert_mut( Record::new() .set_name(origin) .set_ttl(86400) @@ -86,7 +83,7 @@ pub fn create_test() -> InMemoryAuthority { ); let www_name: Name = Name::parse("www.test.com.", None).unwrap(); - records.upsert( + records.upsert_mut( Record::new() .set_name(www_name.clone()) .set_ttl(86400) @@ -96,7 +93,7 @@ pub fn create_test() -> InMemoryAuthority { .clone(), 0, ); - records.upsert( + records.upsert_mut( Record::new() .set_name(www_name) .set_ttl(86400) @@ -112,16 +109,16 @@ pub fn create_test() -> InMemoryAuthority { records } -#[test] -fn test_catalog_lookup() { +#[tokio::test] +async fn test_catalog_lookup() { let 
example = create_example(); let test = create_test(); let origin = example.origin().clone(); let test_origin = test.origin().clone(); let mut catalog: Catalog = Catalog::new(); - catalog.upsert(origin.clone(), Box::new(Arc::new(Mutex::new(example)))); - catalog.upsert(test_origin.clone(), Box::new(Arc::new(Mutex::new(test)))); + catalog.upsert(origin.clone(), Box::new(Arc::new(example))); + catalog.upsert(test_origin.clone(), Box::new(Arc::new(test))); let mut question: Message = Message::new(); @@ -135,8 +132,10 @@ fn test_catalog_lookup() { let question_req = MessageRequest::from_bytes(&question_bytes).unwrap(); let response_handler = TestResponseHandler::new(); - block_on(catalog.lookup(question_req, None, response_handler.clone())); - let result = block_on(response_handler.into_message()); + catalog + .lookup(question_req, None, response_handler.clone()) + .await; + let result = response_handler.into_message().await; assert_eq!(result.response_code(), ResponseCode::NoError); assert_eq!(result.message_type(), MessageType::Response); @@ -166,8 +165,10 @@ fn test_catalog_lookup() { let question_req = MessageRequest::from_bytes(&question_bytes).unwrap(); let response_handler = TestResponseHandler::new(); - block_on(catalog.lookup(question_req, None, response_handler.clone())); - let result = block_on(response_handler.into_message()); + catalog + .lookup(question_req, None, response_handler.clone()) + .await; + let result = response_handler.into_message().await; assert_eq!(result.response_code(), ResponseCode::NoError); assert_eq!(result.message_type(), MessageType::Response); @@ -183,16 +184,16 @@ fn test_catalog_lookup() { ); } -#[test] -fn test_catalog_lookup_soa() { +#[tokio::test] +async fn test_catalog_lookup_soa() { let example = create_example(); let test = create_test(); let origin = example.origin().clone(); let test_origin = test.origin().clone(); let mut catalog: Catalog = Catalog::new(); - catalog.upsert(origin.clone(), Box::new(Arc::new(Mutex::new(example)))); - catalog.upsert(test_origin, Box::new(Arc::new(Mutex::new(test)))); + catalog.upsert(origin.clone(), Box::new(Arc::new(example))); + catalog.upsert(test_origin, Box::new(Arc::new(test))); let mut question: Message = Message::new(); @@ -207,8 +208,10 @@ fn test_catalog_lookup_soa() { let question_req = MessageRequest::from_bytes(&question_bytes).unwrap(); let response_handler = TestResponseHandler::new(); - block_on(catalog.lookup(question_req, None, response_handler.clone())); - let result = block_on(response_handler.into_message()); + catalog + .lookup(question_req, None, response_handler.clone()) + .await; + let result = response_handler.into_message().await; assert_eq!(result.response_code(), ResponseCode::NoError); assert_eq!(result.message_type(), MessageType::Response); @@ -248,14 +251,14 @@ fn test_catalog_lookup_soa() { ); } -#[test] +#[tokio::test] #[allow(clippy::unreadable_literal)] -fn test_catalog_nx_soa() { +async fn test_catalog_nx_soa() { let example = create_example(); let origin = example.origin().clone(); let mut catalog: Catalog = Catalog::new(); - catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example)))); + catalog.upsert(origin, Box::new(Arc::new(example))); let mut question: Message = Message::new(); @@ -269,8 +272,10 @@ fn test_catalog_nx_soa() { let question_req = MessageRequest::from_bytes(&question_bytes).unwrap(); let response_handler = TestResponseHandler::new(); - block_on(catalog.lookup(question_req, None, response_handler.clone())); - let result = 
block_on(response_handler.into_message()); + catalog + .lookup(question_req, None, response_handler.clone()) + .await; + let result = response_handler.into_message().await; assert_eq!(result.response_code(), ResponseCode::NXDomain); assert_eq!(result.message_type(), MessageType::Response); @@ -294,13 +299,13 @@ fn test_catalog_nx_soa() { ); } -#[test] -fn test_non_authoritive_nx_refused() { +#[tokio::test] +async fn test_non_authoritive_nx_refused() { let example = create_example(); let origin = example.origin().clone(); let mut catalog: Catalog = Catalog::new(); - catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example)))); + catalog.upsert(origin, Box::new(Arc::new(example))); let mut question: Message = Message::new(); @@ -315,8 +320,10 @@ fn test_non_authoritive_nx_refused() { let question_req = MessageRequest::from_bytes(&question_bytes).unwrap(); let response_handler = TestResponseHandler::new(); - block_on(catalog.lookup(question_req, None, response_handler.clone())); - let result = block_on(response_handler.into_message()); + catalog + .lookup(question_req, None, response_handler.clone()) + .await; + let result = response_handler.into_message().await; assert_eq!(result.response_code(), ResponseCode::Refused); assert_eq!(result.message_type(), MessageType::Response); @@ -327,9 +334,9 @@ fn test_non_authoritive_nx_refused() { assert_eq!(result.additionals().len(), 0); } -#[test] +#[tokio::test] #[allow(clippy::unreadable_literal)] -fn test_axfr() { +async fn test_axfr() { let mut test = create_test(); test.set_allow_axfr(true); @@ -351,7 +358,7 @@ fn test_axfr() { .clone(); let mut catalog: Catalog = Catalog::new(); - catalog.upsert(origin.clone(), Box::new(Arc::new(Mutex::new(test)))); + catalog.upsert(origin.clone(), Box::new(Arc::new(test))); let mut query: Query = Query::new(); query.set_name(origin.clone().into()); @@ -365,8 +372,10 @@ fn test_axfr() { let question_req = MessageRequest::from_bytes(&question_bytes).unwrap(); let response_handler = TestResponseHandler::new(); - block_on(catalog.lookup(question_req, None, response_handler.clone())); - let result = block_on(response_handler.into_message()); + catalog + .lookup(question_req, None, response_handler.clone()) + .await; + let result = response_handler.into_message().await; let mut answers: Vec = result.answers().to_vec(); @@ -460,15 +469,15 @@ fn test_axfr() { assert_eq!(expected_set, answers); } -#[test] -fn test_axfr_refused() { +#[tokio::test] +async fn test_axfr_refused() { let mut test = create_test(); test.set_allow_axfr(false); let origin = test.origin().clone(); let mut catalog: Catalog = Catalog::new(); - catalog.upsert(origin.clone(), Box::new(Arc::new(Mutex::new(test)))); + catalog.upsert(origin.clone(), Box::new(Arc::new(test))); let mut query: Query = Query::new(); query.set_name(origin.into()); @@ -482,8 +491,10 @@ fn test_axfr_refused() { let question_req = MessageRequest::from_bytes(&question_bytes).unwrap(); let response_handler = TestResponseHandler::new(); - block_on(catalog.lookup(question_req, None, response_handler.clone())); - let result = block_on(response_handler.into_message()); + catalog + .lookup(question_req, None, response_handler.clone()) + .await; + let result = response_handler.into_message().await; assert_eq!(result.response_code(), ResponseCode::Refused); assert!(result.answers().is_empty()); @@ -498,13 +509,13 @@ fn test_axfr_refused() { // } // TODO: these should be moved to the battery tests -#[test] -fn test_cname_additionals() { +#[tokio::test] +async fn 
test_cname_additionals() { let example = create_example(); let origin = example.origin().clone(); let mut catalog: Catalog = Catalog::new(); - catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example)))); + catalog.upsert(origin, Box::new(Arc::new(example))); let mut question: Message = Message::new(); @@ -519,8 +530,10 @@ fn test_cname_additionals() { let question_req = MessageRequest::from_bytes(&question_bytes).unwrap(); let response_handler = TestResponseHandler::new(); - block_on(catalog.lookup(question_req, None, response_handler.clone())); - let result = block_on(response_handler.into_message()); + catalog + .lookup(question_req, None, response_handler.clone()) + .await; + let result = response_handler.into_message().await; assert_eq!(result.message_type(), MessageType::Response); assert_eq!(result.response_code(), ResponseCode::NoError); @@ -542,13 +555,13 @@ fn test_cname_additionals() { ); } -#[test] -fn test_multiple_cname_additionals() { +#[tokio::test] +async fn test_multiple_cname_additionals() { let example = create_example(); let origin = example.origin().clone(); let mut catalog: Catalog = Catalog::new(); - catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example)))); + catalog.upsert(origin, Box::new(Arc::new(example))); let mut question: Message = Message::new(); @@ -563,8 +576,10 @@ fn test_multiple_cname_additionals() { let question_req = MessageRequest::from_bytes(&question_bytes).unwrap(); let response_handler = TestResponseHandler::new(); - block_on(catalog.lookup(question_req, None, response_handler.clone())); - let result = block_on(response_handler.into_message()); + catalog + .lookup(question_req, None, response_handler.clone()) + .await; + let result = response_handler.into_message().await; assert_eq!(result.message_type(), MessageType::Response); assert_eq!(result.response_code(), ResponseCode::NoError); diff --git a/tests/integration-tests/tests/client_future_tests.rs b/tests/integration-tests/tests/client_future_tests.rs index f1a8358d..60cea6c7 100644 --- a/tests/integration-tests/tests/client_future_tests.rs +++ b/tests/integration-tests/tests/client_future_tests.rs @@ -6,7 +6,7 @@ use std::{ #[cfg(feature = "dnssec")] use chrono::Duration; -use futures::{lock::Mutex, Future, FutureExt, TryFutureExt}; +use futures::{Future, FutureExt, TryFutureExt}; use tokio::{ net::{TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket}, runtime::Runtime, @@ -45,10 +45,7 @@ fn test_query_nonet() { let authority = create_example(); let mut catalog = Catalog::new(); - catalog.upsert( - authority.origin().clone(), - Box::new(Arc::new(Mutex::new(authority))), - ); + catalog.upsert(authority.origin().clone(), Box::new(Arc::new(authority))); let io_loop = Runtime::new().unwrap(); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog))); @@ -249,14 +246,11 @@ fn test_query_edns(client: &mut AsyncClient) -> impl Future { #[test] fn test_notify() { + let io_loop = Runtime::new().unwrap(); let authority = create_example(); let mut catalog = Catalog::new(); - catalog.upsert( - authority.origin().clone(), - Box::new(Arc::new(Mutex::new(authority))), - ); + catalog.upsert(authority.origin().clone(), Box::new(Arc::new(authority))); - let io_loop = Runtime::new().unwrap(); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog))); let client = AsyncClient::new(stream, sender, None); let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect"); @@ -312,14 +306,11 @@ async fn create_sig0_ready_client() -> ( 
Duration::minutes(5).num_seconds() as u32, ); auth_key.set_rdata(RData::DNSSEC(DNSSECRData::KEY(sig0_key))); - authority.upsert(auth_key, 0); + authority.upsert_mut(auth_key, 0); // setup the catalog let mut catalog = Catalog::new(); - catalog.upsert( - authority.origin().clone(), - Box::new(Arc::new(Mutex::new(authority))), - ); + catalog.upsert(authority.origin().clone(), Box::new(Arc::new(authority))); let signer = Arc::new(signer.into()); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog))); diff --git a/tests/integration-tests/tests/client_tests.rs b/tests/integration-tests/tests/client_tests.rs index dc69ed70..928a18de 100644 --- a/tests/integration-tests/tests/client_tests.rs +++ b/tests/integration-tests/tests/client_tests.rs @@ -8,7 +8,6 @@ use std::sync::{Arc, Mutex as StdMutex}; use chrono::Duration; use futures::Future; -use futures::lock::Mutex; use trust_dns_client::client::Signer; #[cfg(feature = "dnssec")] use trust_dns_client::client::SyncDnssecClient; @@ -63,10 +62,7 @@ impl ClientConnection for TestClientConnection { fn test_query_nonet() { let authority = create_example(); let mut catalog = Catalog::new(); - catalog.upsert( - authority.origin().clone(), - Box::new(Arc::new(Mutex::new(authority))), - ); + catalog.upsert(authority.origin().clone(), Box::new(Arc::new(authority))); let client = SyncClient::new(TestClientConnection::new(catalog)); @@ -472,12 +468,9 @@ fn create_sig0_ready_client(mut catalog: Catalog) -> (SyncClient InMemoryAuthority { let mut authority = create_example(); - authority.upsert( + authority.upsert_mut( Record::new() .set_name(Name::from_str("1.2.3.4.example.com.").unwrap()) .set_ttl(86400) @@ -124,10 +117,7 @@ fn create_ip_like_example() -> InMemoryAuthority { fn test_lookup_ipv4_like() { let authority = create_ip_like_example(); let mut catalog = Catalog::new(); - catalog.upsert( - authority.origin().clone(), - Box::new(Arc::new(Mutex::new(authority))), - ); + catalog.upsert(authority.origin().clone(), Box::new(Arc::new(authority))); let io_loop = Runtime::new().unwrap(); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog))); @@ -157,10 +147,7 @@ fn test_lookup_ipv4_like() { fn test_lookup_ipv4_like_fall_through() { let authority = create_ip_like_example(); let mut catalog = Catalog::new(); - catalog.upsert( - authority.origin().clone(), - Box::new(Arc::new(Mutex::new(authority))), - ); + catalog.upsert(authority.origin().clone(), Box::new(Arc::new(authority))); let io_loop = Runtime::new().unwrap(); let (stream, sender) = TestClientStream::new(Arc::new(StdMutex::new(catalog))); diff --git a/tests/integration-tests/tests/server_future_tests.rs b/tests/integration-tests/tests/server_future_tests.rs index 82dfa1cb..49f53cc4 100644 --- a/tests/integration-tests/tests/server_future_tests.rs +++ b/tests/integration-tests/tests/server_future_tests.rs @@ -5,7 +5,6 @@ use std::sync::Arc; use std::thread; use std::time::Duration; -use futures::lock::Mutex; use futures::{future, Future, FutureExt}; use tokio::net::TcpListener; use tokio::net::UdpSocket; @@ -265,7 +264,7 @@ fn new_catalog() -> Catalog { let origin = example.origin().clone(); let mut catalog: Catalog = Catalog::new(); - catalog.upsert(origin, Box::new(Arc::new(Mutex::new(example)))); + catalog.upsert(origin, Box::new(Arc::new(example))); catalog } diff --git a/tests/integration-tests/tests/sqlite_authority_tests.rs b/tests/integration-tests/tests/sqlite_authority_tests.rs index c5831926..44dfdd01 100644 --- 
a/tests/integration-tests/tests/sqlite_authority_tests.rs +++ b/tests/integration-tests/tests/sqlite_authority_tests.rs @@ -3,8 +3,6 @@ use std::net::*; use std::str::FromStr; -use futures::executor::block_on; - use rusqlite::*; use trust_dns_client::op::*; @@ -28,8 +26,8 @@ fn create_secure_example() -> SqliteAuthority { SqliteAuthority::new(authority, true, true) } -#[test] -fn test_search() { +#[tokio::test] +async fn test_search() { let example = create_example(); let origin = example.origin().clone(); @@ -37,7 +35,10 @@ fn test_search() { query.set_name(origin.into()); let query = LowerQuery::from(query); - let result = block_on(example.search(&query, LookupOptions::default())).unwrap(); + let result = example + .search(&query, LookupOptions::default()) + .await + .unwrap(); if !result.is_empty() { let record = result.iter().next().unwrap(); assert_eq!(record.rr_type(), RecordType::A); @@ -49,8 +50,8 @@ fn test_search() { } /// this is a litte more interesting b/c it requires a recursive lookup for the origin -#[test] -fn test_search_www() { +#[tokio::test] +async fn test_search_www() { let example = create_example(); let www_name = Name::parse("www.example.com.", None).unwrap(); @@ -58,7 +59,10 @@ fn test_search_www() { query.set_name(www_name); let query = LowerQuery::from(query); - let result = block_on(example.search(&query, LookupOptions::default())).unwrap(); + let result = example + .search(&query, LookupOptions::default()) + .await + .unwrap(); if !result.is_empty() { let record = result.iter().next().unwrap(); assert_eq!(record.rr_type(), RecordType::A); @@ -69,12 +73,14 @@ fn test_search_www() { } } -#[test] -fn test_authority() { +#[tokio::test] +async fn test_authority() { let authority = create_example(); assert_eq!( - block_on(authority.soa()) + authority + .soa() + .await .unwrap() .iter() .next() @@ -83,15 +89,15 @@ fn test_authority() { DNSClass::IN ); - assert!(!block_on(authority.lookup( - authority.origin(), - RecordType::NS, - LookupOptions::default() - )) - .unwrap() - .was_empty()); + assert!(!authority + .lookup(authority.origin(), RecordType::NS, LookupOptions::default()) + .await + .unwrap() + .was_empty()); - let mut lookup: Vec<_> = block_on(authority.ns(LookupOptions::default())) + let mut lookup: Vec<_> = authority + .ns(LookupOptions::default()) + .await .unwrap() .iter() .cloned() @@ -119,23 +125,27 @@ fn test_authority() { .clone() ); - assert!(!block_on(authority.lookup( - authority.origin(), - RecordType::TXT, - LookupOptions::default() - )) - .unwrap() - .was_empty()); + assert!(!authority + .lookup( + authority.origin(), + RecordType::TXT, + LookupOptions::default() + ) + .await + .unwrap() + .was_empty()); - let mut lookup: Vec<_> = block_on(authority.lookup( - authority.origin(), - RecordType::TXT, - LookupOptions::default(), - )) - .unwrap() - .iter() - .cloned() - .collect(); + let mut lookup: Vec<_> = authority + .lookup( + authority.origin(), + RecordType::TXT, + LookupOptions::default(), + ) + .await + .unwrap() + .iter() + .cloned() + .collect(); lookup.sort(); assert_eq!( @@ -154,7 +164,9 @@ fn test_authority() { ); assert_eq!( - *block_on(authority.lookup(authority.origin(), RecordType::A, LookupOptions::default())) + *authority + .lookup(authority.origin(), RecordType::A, LookupOptions::default()) + .await .unwrap() .iter() .next() @@ -170,8 +182,8 @@ fn test_authority() { } #[cfg(feature = "dnssec")] -#[test] -fn test_authorize() { +#[tokio::test] +async fn test_authorize() { use trust_dns_client::serialize::binary::{BinDecodable, 
BinEncodable}; use trust_dns_server::authority::MessageRequest; @@ -187,7 +199,7 @@ fn test_authorize() { let message = MessageRequest::from_bytes(&bytes).unwrap(); assert_eq!( - block_on(authority.authorize(&message)), + authority.authorize(&message).await, Err(ResponseCode::Refused) ); @@ -196,8 +208,8 @@ fn test_authorize() { // assert!(authority.authorize(&message).is_ok()); } -#[test] -fn test_prerequisites() { +#[tokio::test] +async fn test_prerequisites() { let not_zone = Name::from_str("not.a.domain.com").unwrap(); let not_in_zone = Name::from_str("not.example.com").unwrap(); @@ -206,224 +218,232 @@ fn test_prerequisites() { // first check the initial negatives, ttl = 0, and the zone is the same assert_eq!( - block_on( - authority.verify_prerequisites(&[Record::new() + authority + .verify_prerequisites(&[Record::new() .set_name(not_in_zone.clone()) .set_ttl(86400) .set_rr_type(RecordType::A) .set_dns_class(DNSClass::IN) .set_rdata(RData::NULL(NULL::new())) .clone()],) - ), + .await, Err(ResponseCode::FormErr) ); assert_eq!( - block_on( - authority.verify_prerequisites(&[Record::new() + authority + .verify_prerequisites(&[Record::new() .set_name(not_zone) .set_ttl(0) .set_rr_type(RecordType::A) .set_dns_class(DNSClass::IN) .set_rdata(RData::NULL(NULL::new())) .clone()],) - ), + .await, Err(ResponseCode::NotZone) ); // * ANY ANY empty Name is in use - assert!(block_on( - authority.verify_prerequisites(&[Record::new() + assert!(authority + .verify_prerequisites(&[Record::new() .set_name(authority.origin().clone().into()) .set_ttl(0) .set_dns_class(DNSClass::ANY) .set_rr_type(RecordType::ANY) .set_rdata(RData::NULL(NULL::new())) .clone()]) - ) - .is_ok()); + .await + .is_ok()); assert_eq!( - block_on( - authority.verify_prerequisites(&[Record::new() + authority + .verify_prerequisites(&[Record::new() .set_name(not_in_zone.clone()) .set_ttl(0) .set_dns_class(DNSClass::ANY) .set_rr_type(RecordType::ANY) .set_rdata(RData::NULL(NULL::new())) .clone()],) - ), + .await, Err(ResponseCode::NXDomain) ); // * ANY rrset empty RRset exists (value independent) - assert!(block_on( - authority.verify_prerequisites(&[Record::new() + assert!(authority + .verify_prerequisites(&[Record::new() .set_name(authority.origin().clone().into()) .set_ttl(0) .set_dns_class(DNSClass::ANY) .set_rr_type(RecordType::A) .set_rdata(RData::NULL(NULL::new())) .clone()]) - ) - .is_ok()); + .await + .is_ok()); assert_eq!( - block_on( - authority.verify_prerequisites(&[Record::new() + authority + .verify_prerequisites(&[Record::new() .set_name(not_in_zone.clone()) .set_ttl(0) .set_dns_class(DNSClass::ANY) .set_rr_type(RecordType::A) .set_rdata(RData::NULL(NULL::new())) .clone()],) - ), + .await, Err(ResponseCode::NXRRSet) ); // * NONE ANY empty Name is not in use - assert!(block_on( - authority.verify_prerequisites(&[Record::new() + assert!(authority + .verify_prerequisites(&[Record::new() .set_name(not_in_zone.clone()) .set_ttl(0) .set_dns_class(DNSClass::NONE) .set_rr_type(RecordType::ANY) .set_rdata(RData::NULL(NULL::new())) .clone()]) - ) - .is_ok()); + .await + .is_ok()); assert_eq!( - block_on( - authority.verify_prerequisites(&[Record::new() + authority + .verify_prerequisites(&[Record::new() .set_name(authority.origin().clone().into()) .set_ttl(0) .set_dns_class(DNSClass::NONE) .set_rr_type(RecordType::ANY) .set_rdata(RData::NULL(NULL::new())) .clone()],) - ), + .await, Err(ResponseCode::YXDomain) ); // * NONE rrset empty RRset does not exist - assert!(block_on( - authority.verify_prerequisites(&[Record::new() + 
assert!(authority + .verify_prerequisites(&[Record::new() .set_name(not_in_zone.clone()) .set_ttl(0) .set_dns_class(DNSClass::NONE) .set_rr_type(RecordType::A) .set_rdata(RData::NULL(NULL::new())) .clone()]) - ) - .is_ok()); + .await + .is_ok()); assert_eq!( - block_on( - authority.verify_prerequisites(&[Record::new() + authority + .verify_prerequisites(&[Record::new() .set_name(authority.origin().clone().into()) .set_ttl(0) .set_dns_class(DNSClass::NONE) .set_rr_type(RecordType::A) .set_rdata(RData::NULL(NULL::new())) .clone()],) - ), + .await, Err(ResponseCode::YXRRSet) ); // * zone rrset rr RRset exists (value dependent) - assert!(block_on( - authority.verify_prerequisites(&[Record::new() + assert!(authority + .verify_prerequisites(&[Record::new() .set_name(authority.origin().clone().into()) .set_ttl(0) .set_dns_class(DNSClass::IN) .set_rr_type(RecordType::A) .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34))) .clone()]) - ) - .is_ok()); + .await + .is_ok()); // wrong class assert_eq!( - block_on( - authority.verify_prerequisites(&[Record::new() + authority + .verify_prerequisites(&[Record::new() .set_name(authority.origin().clone().into()) .set_ttl(0) .set_dns_class(DNSClass::CH) .set_rr_type(RecordType::A) .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34))) .clone()],) - ), + .await, Err(ResponseCode::FormErr) ); // wrong Name assert_eq!( - block_on( - authority.verify_prerequisites(&[Record::new() + authority + .verify_prerequisites(&[Record::new() .set_name(not_in_zone) .set_ttl(0) .set_dns_class(DNSClass::IN) .set_rr_type(RecordType::A) .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) .clone()],) - ), + .await, Err(ResponseCode::NXRRSet) ); // wrong IP assert_eq!( - block_on( - authority.verify_prerequisites(&[Record::new() + authority + .verify_prerequisites(&[Record::new() .set_name(authority.origin().clone().into()) .set_ttl(0) .set_dns_class(DNSClass::IN) .set_rr_type(RecordType::A) .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) .clone()],) - ), + .await, Err(ResponseCode::NXRRSet) ); } -#[test] -fn test_pre_scan() { +#[tokio::test] +async fn test_pre_scan() { let up_name = Name::from_str("www.example.com").unwrap(); let not_zone = Name::from_str("not.zone.com").unwrap(); let authority = create_example(); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(not_zone) - .set_ttl(86400) - .set_rr_type(RecordType::A) - .set_dns_class(DNSClass::IN) - .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(not_zone) + .set_ttl(86400) + .set_rr_type(RecordType::A) + .set_dns_class(DNSClass::IN) + .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) + .clone()],) + .await, Err(ResponseCode::NotZone) ); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(86400) - .set_rr_type(RecordType::ANY) - .set_dns_class(DNSClass::IN) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(86400) + .set_rr_type(RecordType::ANY) + .set_dns_class(DNSClass::IN) + .set_rdata(RData::NULL(NULL::new())) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(86400) - .set_rr_type(RecordType::AXFR) - .set_dns_class(DNSClass::IN) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(86400) + .set_rr_type(RecordType::AXFR) + 
.set_dns_class(DNSClass::IN) + .set_rdata(RData::NULL(NULL::new())) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(86400) - .set_rr_type(RecordType::IXFR) - .set_dns_class(DNSClass::IN) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(86400) + .set_rr_type(RecordType::IXFR) + .set_dns_class(DNSClass::IN) + .set_rdata(RData::NULL(NULL::new())) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert!(authority @@ -434,6 +454,7 @@ fn test_pre_scan() { .set_dns_class(DNSClass::IN) .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) .clone()]) + .await .is_ok()); assert!(authority .pre_scan(&[Record::new() @@ -443,46 +464,55 @@ fn test_pre_scan() { .set_dns_class(DNSClass::IN) .set_rdata(RData::NULL(NULL::new())) .clone()]) + .await .is_ok()); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(86400) - .set_rr_type(RecordType::A) - .set_dns_class(DNSClass::ANY) - .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(86400) + .set_rr_type(RecordType::A) + .set_dns_class(DNSClass::ANY) + .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(0) - .set_rr_type(RecordType::A) - .set_dns_class(DNSClass::ANY) - .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(0) + .set_rr_type(RecordType::A) + .set_dns_class(DNSClass::ANY) + .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(0) - .set_rr_type(RecordType::AXFR) - .set_dns_class(DNSClass::ANY) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(0) + .set_rr_type(RecordType::AXFR) + .set_dns_class(DNSClass::ANY) + .set_rdata(RData::NULL(NULL::new())) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(0) - .set_rr_type(RecordType::IXFR) - .set_dns_class(DNSClass::ANY) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(0) + .set_rr_type(RecordType::IXFR) + .set_dns_class(DNSClass::ANY) + .set_rdata(RData::NULL(NULL::new())) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert!(authority @@ -493,6 +523,7 @@ fn test_pre_scan() { .set_dns_class(DNSClass::ANY) .set_rdata(RData::NULL(NULL::new())) .clone()]) + .await .is_ok()); assert!(authority .pre_scan(&[Record::new() @@ -502,46 +533,55 @@ fn test_pre_scan() { .set_dns_class(DNSClass::ANY) .set_rdata(RData::NULL(NULL::new())) .clone()]) + .await .is_ok()); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(86400) - .set_rr_type(RecordType::A) - .set_dns_class(DNSClass::NONE) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(86400) + .set_rr_type(RecordType::A) + .set_dns_class(DNSClass::NONE) + .set_rdata(RData::NULL(NULL::new())) + 
.clone()],) + .await, Err(ResponseCode::FormErr) ); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(0) - .set_rr_type(RecordType::ANY) - .set_dns_class(DNSClass::NONE) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(0) + .set_rr_type(RecordType::ANY) + .set_dns_class(DNSClass::NONE) + .set_rdata(RData::NULL(NULL::new())) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(0) - .set_rr_type(RecordType::AXFR) - .set_dns_class(DNSClass::NONE) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(0) + .set_rr_type(RecordType::AXFR) + .set_dns_class(DNSClass::NONE) + .set_rdata(RData::NULL(NULL::new())) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name.clone()) - .set_ttl(0) - .set_rr_type(RecordType::IXFR) - .set_dns_class(DNSClass::NONE) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name.clone()) + .set_ttl(0) + .set_rr_type(RecordType::IXFR) + .set_dns_class(DNSClass::NONE) + .set_rdata(RData::NULL(NULL::new())) + .clone()],) + .await, Err(ResponseCode::FormErr) ); assert!(authority @@ -552,6 +592,7 @@ fn test_pre_scan() { .set_dns_class(DNSClass::NONE) .set_rdata(RData::NULL(NULL::new())) .clone()]) + .await .is_ok()); assert!(authority .pre_scan(&[Record::new() @@ -561,26 +602,29 @@ fn test_pre_scan() { .set_dns_class(DNSClass::NONE) .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 24))) .clone()]) + .await .is_ok()); assert_eq!( - authority.pre_scan(&[Record::new() - .set_name(up_name) - .set_ttl(86400) - .set_rr_type(RecordType::A) - .set_dns_class(DNSClass::CH) - .set_rdata(RData::NULL(NULL::new())) - .clone()],), + authority + .pre_scan(&[Record::new() + .set_name(up_name) + .set_ttl(86400) + .set_rr_type(RecordType::A) + .set_dns_class(DNSClass::CH) + .set_rdata(RData::NULL(NULL::new())) + .clone()],) + .await, Err(ResponseCode::FormErr) ); } -#[test] -fn test_update() { +#[tokio::test] +async fn test_update() { let new_name = Name::from_str("new.example.com").unwrap(); let www_name = Name::from_str("www.example.com").unwrap(); let mut authority = create_example(); - let serial = authority.serial(); + let serial = authority.serial().await; authority.set_allow_update(true); @@ -614,27 +658,31 @@ fn test_update() { { // assert that the correct set of records is there. 
- let mut www_rrset: Vec = block_on(authority.lookup( - &www_name.clone().into(), - RecordType::ANY, - LookupOptions::default(), - )) - .unwrap() - .iter() - .cloned() - .collect(); + let mut www_rrset: Vec = authority + .lookup( + &www_name.clone().into(), + RecordType::ANY, + LookupOptions::default(), + ) + .await + .unwrap() + .iter() + .cloned() + .collect(); www_rrset.sort(); assert_eq!(www_rrset, original_vec); // assert new record doesn't exist - assert!(block_on(authority.lookup( - &new_name.clone().into(), - RecordType::ANY, - LookupOptions::default() - )) - .unwrap() - .was_empty()); + assert!(authority + .lookup( + &new_name.clone().into(), + RecordType::ANY, + LookupOptions::default() + ) + .await + .unwrap() + .was_empty()); } // @@ -648,19 +696,22 @@ fn test_update() { .clone()]; assert!(authority .update_records(add_record, true,) + .await .expect("update failed",)); assert_eq!( - block_on(authority.lookup( - &new_name.clone().into(), - RecordType::ANY, - LookupOptions::default() - )) - .unwrap() - .iter() - .collect::>(), + authority + .lookup( + &new_name.clone().into(), + RecordType::ANY, + LookupOptions::default() + ) + .await + .unwrap() + .iter() + .collect::>(), add_record.iter().collect::>() ); - assert_eq!(serial + 1, authority.serial()); + assert_eq!(serial + 1, authority.serial().await); let add_www_record = &[Record::new() .set_name(www_name.clone()) @@ -671,19 +722,22 @@ fn test_update() { .clone()]; assert!(authority .update_records(add_www_record, true,) + .await .expect("update failed",)); - assert_eq!(serial + 2, authority.serial()); + assert_eq!(serial + 2, authority.serial().await); { - let mut www_rrset: Vec<_> = block_on(authority.lookup( - &www_name.clone().into(), - RecordType::ANY, - LookupOptions::default(), - )) - .unwrap() - .iter() - .cloned() - .collect(); + let mut www_rrset: Vec<_> = authority + .lookup( + &www_name.clone().into(), + RecordType::ANY, + LookupOptions::default(), + ) + .await + .unwrap() + .iter() + .cloned() + .collect(); www_rrset.sort(); let mut plus_10 = original_vec.clone(); @@ -703,12 +757,14 @@ fn test_update() { .clone()]; assert!(authority .update_records(del_record, true,) + .await .expect("update failed",)); - assert_eq!(serial + 3, authority.serial()); + assert_eq!(serial + 3, authority.serial().await); { - let lookup = - block_on(authority.lookup(&new_name.into(), RecordType::ANY, LookupOptions::default())) - .unwrap(); + let lookup = authority + .lookup(&new_name.into(), RecordType::ANY, LookupOptions::default()) + .await + .unwrap(); println!("after delete of specific record: {:?}", lookup); assert!(lookup.was_empty()); @@ -724,18 +780,21 @@ fn test_update() { .clone()]; assert!(authority .update_records(del_record, true,) + .await .expect("update failed",)); - assert_eq!(serial + 4, authority.serial()); + assert_eq!(serial + 4, authority.serial().await); { - let mut www_rrset: Vec<_> = block_on(authority.lookup( - &www_name.clone().into(), - RecordType::ANY, - LookupOptions::default(), - )) - .unwrap() - .iter() - .cloned() - .collect(); + let mut www_rrset: Vec<_> = authority + .lookup( + &www_name.clone().into(), + RecordType::ANY, + LookupOptions::default(), + ) + .await + .unwrap() + .iter() + .cloned() + .collect(); www_rrset.sort(); assert_eq!(www_rrset, original_vec); @@ -752,8 +811,9 @@ fn test_update() { .clone()]; assert!(authority .update_records(del_record, true,) + .await .expect("update failed",)); - assert_eq!(serial + 5, authority.serial()); + assert_eq!(serial + 5, authority.serial().await); let 
mut removed_a_vec: Vec<_> = vec![ Record::new() .set_name(www_name.clone()) @@ -775,15 +835,17 @@ fn test_update() { removed_a_vec.sort(); { - let mut www_rrset: Vec = block_on(authority.lookup( - &www_name.clone().into(), - RecordType::ANY, - LookupOptions::default(), - )) - .unwrap() - .iter() - .cloned() - .collect(); + let mut www_rrset: Vec = authority + .lookup( + &www_name.clone().into(), + RecordType::ANY, + LookupOptions::default(), + ) + .await + .unwrap() + .iter() + .cloned() + .collect(); www_rrset.sort(); assert_eq!(www_rrset, removed_a_vec); @@ -802,42 +864,45 @@ fn test_update() { assert!(authority .update_records(del_record, true,) + .await .expect("update failed",)); - assert!(block_on(authority.lookup( - &www_name.into(), - RecordType::ANY, - LookupOptions::default() - )) - .unwrap() - .was_empty()); + assert!(authority + .lookup(&www_name.into(), RecordType::ANY, LookupOptions::default()) + .await + .unwrap() + .was_empty()); - assert_eq!(serial + 6, authority.serial()); + assert_eq!(serial + 6, authority.serial().await); } #[cfg(feature = "dnssec")] -#[test] -fn test_zone_signing() { +#[tokio::test] +async fn test_zone_signing() { let authority = create_secure_example(); - let results = block_on(authority.lookup( - authority.origin(), - RecordType::AXFR, - LookupOptions::for_dnssec(true, SupportedAlgorithms::all()), - )) - .unwrap(); + let results = authority + .lookup( + authority.origin(), + RecordType::AXFR, + LookupOptions::for_dnssec(true, SupportedAlgorithms::all()), + ) + .await + .unwrap(); assert!( results.iter().any(|r| r.rr_type() == RecordType::DNSKEY), "must contain a DNSKEY" ); - let results = block_on(authority.lookup( - authority.origin(), - RecordType::AXFR, - LookupOptions::for_dnssec(true, SupportedAlgorithms::all()), - )) - .unwrap(); + let results = authority + .lookup( + authority.origin(), + RecordType::AXFR, + LookupOptions::for_dnssec(true, SupportedAlgorithms::all()), + ) + .await + .unwrap(); for record in &results { if record.rr_type() == RecordType::RRSIG { @@ -847,12 +912,14 @@ fn test_zone_signing() { continue; } - let inner_results = block_on(authority.lookup( - authority.origin(), - RecordType::AXFR, - LookupOptions::for_dnssec(true, SupportedAlgorithms::all()), - )) - .unwrap(); + let inner_results = authority + .lookup( + authority.origin(), + RecordType::AXFR, + LookupOptions::for_dnssec(true, SupportedAlgorithms::all()), + ) + .await + .unwrap(); // validate all records have associated RRSIGs after signing assert!( @@ -872,33 +939,35 @@ fn test_zone_signing() { } #[cfg(feature = "dnssec")] -#[test] -fn test_get_nsec() { +#[tokio::test] +async fn test_get_nsec() { let name = Name::from_str("zzz.example.com").unwrap(); let authority = create_secure_example(); let lower_name = LowerName::from(name.clone()); - let results = block_on(authority.get_nsec_records( - &lower_name, - LookupOptions::for_dnssec(true, SupportedAlgorithms::all()), - )) - .unwrap(); + let results = authority + .get_nsec_records( + &lower_name, + LookupOptions::for_dnssec(true, SupportedAlgorithms::all()), + ) + .await + .unwrap(); for record in &results { assert!(*record.name() < name); } } -#[test] -fn test_journal() { +#[tokio::test] +async fn test_journal() { // test that this message can be inserted let conn = Connection::open_in_memory().expect("could not create in memory DB"); let mut journal = Journal::new(conn).unwrap(); journal.schema_up().unwrap(); let mut authority = create_example(); - authority.set_journal(journal); - 
authority.persist_to_journal().unwrap(); + authority.set_journal(journal).await; + authority.persist_to_journal().await.unwrap(); let new_name = Name::from_str("new.example.com").unwrap(); let delete_name = Name::from_str("www.example.com").unwrap(); @@ -913,24 +982,28 @@ fn test_journal() { .clone(); authority .update_records(&[new_record.clone(), delete_record], true) + .await .unwrap(); // assert that the correct set of records is there. - let new_rrset: Vec = block_on(authority.lookup( - &new_name.clone().into(), - RecordType::A, - LookupOptions::default(), - )) - .unwrap() - .iter() - .cloned() - .collect(); + let new_rrset: Vec = authority + .lookup( + &new_name.clone().into(), + RecordType::A, + LookupOptions::default(), + ) + .await + .unwrap() + .iter() + .cloned() + .collect(); assert!(new_rrset.iter().all(|r| *r == new_record)); let lower_delete_name = LowerName::from(delete_name); - let delete_rrset = - block_on(authority.lookup(&lower_delete_name, RecordType::A, LookupOptions::default())) - .unwrap(); + let delete_rrset = authority + .lookup(&lower_delete_name, RecordType::A, LookupOptions::default()) + .await + .unwrap(); assert!(delete_rrset.was_empty()); // that record should have been recorded... let's reload the journal and see if we get it. @@ -939,40 +1012,49 @@ fn test_journal() { let mut recovered_authority = SqliteAuthority::new(in_memory, false, false); recovered_authority - .recover_with_journal(authority.journal().expect("journal not Some")) + .recover_with_journal( + authority + .journal() + .await + .as_ref() + .expect("journal not Some"), + ) + .await .expect("recovery"); // assert that the correct set of records is there. - let new_rrset: Vec = block_on(recovered_authority.lookup( - &new_name.into(), - RecordType::A, - LookupOptions::default(), - )) - .unwrap() - .iter() - .cloned() - .collect(); + let new_rrset: Vec = recovered_authority + .lookup(&new_name.into(), RecordType::A, LookupOptions::default()) + .await + .unwrap() + .iter() + .cloned() + .collect(); assert!(new_rrset.iter().all(|r| *r == new_record)); - let delete_rrset = - block_on(authority.lookup(&lower_delete_name, RecordType::A, LookupOptions::default())) - .unwrap(); + let delete_rrset = authority + .lookup(&lower_delete_name, RecordType::A, LookupOptions::default()) + .await + .unwrap(); assert!(delete_rrset.was_empty()); } -#[test] +#[tokio::test] #[allow(clippy::blocks_in_if_conditions)] -fn test_recovery() { +async fn test_recovery() { // test that this message can be inserted let conn = Connection::open_in_memory().expect("could not create in memory DB"); let mut journal = Journal::new(conn).unwrap(); journal.schema_up().unwrap(); let mut authority = create_example(); - authority.set_journal(journal); - authority.persist_to_journal().unwrap(); + authority.set_journal(journal).await; + authority.persist_to_journal().await.unwrap(); - let journal = authority.journal().unwrap(); + let journal = authority.journal().await; + let journal = journal + .as_ref() + .expect("test should have associated journal"); let in_memory = InMemoryAuthority::empty(authority.origin().clone().into(), ZoneType::Primary, false); @@ -980,38 +1062,39 @@ fn test_recovery() { recovered_authority .recover_with_journal(journal) + .await .expect("recovery"); assert_eq!( - recovered_authority.records().len(), - authority.records().len() + recovered_authority.records().await.len(), + authority.records().await.len() ); - assert!(block_on(recovered_authority.soa()) + assert!(recovered_authority + .soa() + .await .unwrap() 
.iter() - .zip(block_on(authority.soa()).unwrap().iter()) + .zip(authority.soa().await.unwrap().iter()) .all(|(r1, r2)| r1 == r2)); - assert!(recovered_authority - .records() - .iter() - .all(|(rr_key, rr_set)| { - let other_rr_set = authority - .records() - .get(rr_key) - .unwrap_or_else(|| panic!("key doesn't exist: {:?}", rr_key)); - rr_set - .records_without_rrsigs() - .zip(other_rr_set.records_without_rrsigs()) - .all(|(record, other_record)| { - record.ttl() == other_record.ttl() && record.rdata() == other_record.rdata() - }) - },)); + let recovered_records = recovered_authority.records().await; + let records = authority.records().await; - assert!(authority.records().iter().all(|(rr_key, rr_set)| { - let other_rr_set = recovered_authority - .records() + assert!(recovered_records.iter().all(|(rr_key, rr_set)| { + let other_rr_set = records + .get(rr_key) + .unwrap_or_else(|| panic!("key doesn't exist: {:?}", rr_key)); + rr_set + .records_without_rrsigs() + .zip(other_rr_set.records_without_rrsigs()) + .all(|(record, other_record)| { + record.ttl() == other_record.ttl() && record.rdata() == other_record.rdata() + }) + },)); + + assert!(records.iter().all(|(rr_key, rr_set)| { + let other_rr_set = recovered_records .get(rr_key) .unwrap_or_else(|| panic!("key doesn't exist: {:?}", rr_key)); rr_set @@ -1023,8 +1106,8 @@ fn test_recovery() { })); } -#[test] -fn test_axfr() { +#[tokio::test] +async fn test_axfr() { let mut authority = create_example(); authority.set_allow_axfr(true); @@ -1036,14 +1119,17 @@ fn test_axfr() { Name::from_str("example.com.").unwrap(), RecordType::AXFR, )); - let result = block_on(authority.search(&query, LookupOptions::default())).unwrap(); + let result = authority + .search(&query, LookupOptions::default()) + .await + .unwrap(); // just update this if the count goes up in the authority assert_eq!(result.iter().count(), 12); } -#[test] -fn test_refused_axfr() { +#[tokio::test] +async fn test_refused_axfr() { let mut authority = create_example(); authority.set_allow_axfr(false); @@ -1051,7 +1137,7 @@ fn test_refused_axfr() { Name::from_str("example.com.").unwrap(), RecordType::AXFR, )); - let result = block_on(authority.search(&query, LookupOptions::default())); + let result = authority.search(&query, LookupOptions::default()).await; // just update this if the count goes up in the authority assert!(result.unwrap_err().is_refused());
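The add_signers/add_auth and store_sqlite_tests hunks above all land on the same pattern for test helpers that stay synchronous: each now-async `Authority`/`DnssecAuthority` call is driven to completion with `futures_executor::block_on`. A minimal sketch of that shape; the helper name `secure_zone_blocking` and its bare `DnssecAuthority` bound are illustrative assumptions, not part of this change set:

```rust
use futures_executor::block_on;
use trust_dns_server::authority::DnssecAuthority;

// Hypothetical synchronous helper: `secure_zone` (like `add_zone_signing_key` and
// `add_update_auth_key` in the hunks above) is an `async fn` after this change,
// so a non-async test path resolves it eagerly with block_on.
fn secure_zone_blocking<A: DnssecAuthority>(authority: &mut A) {
    block_on(authority.secure_zone()).expect("failed to sign zone");
}
```

The alternative, used by the catalog and sqlite test modules, is to convert the test itself to `#[tokio::test]` and chain `.await`, as sketched further below.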
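For fixture construction, `InMemoryAuthority` keeps a synchronous path: the `create_example()`/`create_test()` hunks switch to the `*_mut` variants (`upsert_mut`, `add_zone_signing_key_mut`, `secure_zone_mut`), which work on `&mut self` while the zone is still exclusively owned by the test. A trimmed sketch of that construction; the single record and the `tiny_zone` name are illustrative only:

```rust
use std::net::Ipv4Addr;

use trust_dns_client::rr::{DNSClass, Name, RData, Record, RecordType};
use trust_dns_server::{authority::ZoneType, store::in_memory::InMemoryAuthority};

// Builds a one-record zone without touching an async executor; upsert_mut takes
// &mut self, so it stays usable from plain synchronous test setup code.
fn tiny_zone() -> InMemoryAuthority {
    let origin: Name = Name::parse("example.com.", None).unwrap();
    let mut records = InMemoryAuthority::empty(origin.clone(), ZoneType::Primary, false);

    // example.com. 86400 IN A 93.184.216.34
    records.upsert_mut(
        Record::new()
            .set_name(origin)
            .set_ttl(86400)
            .set_rr_type(RecordType::A)
            .set_dns_class(DNSClass::IN)
            .set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34)))
            .clone(),
        0, // starting serial, as in the fixtures above
    );

    records
}
```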
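The catalog tests above become `#[tokio::test]` functions (the Cargo.toml hunk adds `tokio` with the `macros` and `rt` features to dev-dependencies for exactly this), the authority goes into the `Catalog` as a plain `Box::new(Arc::new(..))` with no `futures::lock::Mutex`, and `Catalog::lookup` is awaited. A condensed sketch of the `test_catalog_lookup` shape; the explicit import paths are assumptions based on the glob imports used in that file:

```rust
use std::sync::Arc;

use trust_dns_client::{
    op::{Message, MessageType, Query, ResponseCode},
    rr::{Name, RecordType},
    serialize::binary::{BinDecodable, BinEncodable},
};
use trust_dns_server::authority::{Authority, Catalog, MessageRequest};

use trust_dns_integration::{authority::create_example, *};

#[tokio::test]
async fn lookup_www_example() {
    let example = create_example();
    let origin = example.origin().clone();

    // no Mutex: the Authority is internally synchronized and Send + Sync
    let mut catalog: Catalog = Catalog::new();
    catalog.upsert(origin, Box::new(Arc::new(example)));

    let mut question: Message = Message::new();
    question.add_query(Query::query(
        Name::parse("www.example.com.", None).unwrap(),
        RecordType::A,
    ));
    let question_bytes = question.to_bytes().unwrap();
    let question_req = MessageRequest::from_bytes(&question_bytes).unwrap();

    let response_handler = TestResponseHandler::new();
    catalog
        .lookup(question_req, None, response_handler.clone())
        .await;

    let result = response_handler.into_message().await;
    assert_eq!(result.message_type(), MessageType::Response);
    assert_eq!(result.response_code(), ResponseCode::NoError);
}
```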
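On the query side, `soa`, `lookup`, `ns`, `search`, `pre_scan`, `verify_prerequisites`, `update_records`, and `serial` all become `async fn`, so the sqlite authority tests above drop `futures::executor::block_on` and chain `.await` inside `#[tokio::test]` functions. A small sketch of the lookup shape, run against the in-memory fixture for brevity; the `LookupOptions` import path is an assumption:

```rust
use trust_dns_client::rr::{DNSClass, RecordType};
use trust_dns_server::authority::{Authority, LookupOptions};

use trust_dns_integration::authority::create_example;

#[tokio::test]
async fn soa_and_ns_lookup() {
    let authority = create_example();

    // previously: block_on(authority.soa())
    assert_eq!(
        authority
            .soa()
            .await
            .unwrap()
            .iter()
            .next()
            .expect("SOA record should exist")
            .dns_class(),
        DNSClass::IN
    );

    // previously: block_on(authority.lookup(..))
    assert!(!authority
        .lookup(authority.origin(), RecordType::NS, LookupOptions::default())
        .await
        .unwrap()
        .was_empty());
}
```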
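The same conversion applies to `search`, which takes a `LowerQuery` and is awaited, as in `test_search` and `test_axfr` above. A sketch, assuming `LowerQuery` and `Query` come in through `trust_dns_client::op` as in those tests' glob imports, again using the in-memory fixture:

```rust
use std::str::FromStr;

use trust_dns_client::{
    op::*,
    rr::{Name, RecordType},
};
use trust_dns_server::authority::{Authority, LookupOptions};

use trust_dns_integration::authority::create_example;

#[tokio::test]
async fn search_origin_a_record() {
    let example = create_example();

    let query = LowerQuery::from(Query::query(
        Name::from_str("example.com.").unwrap(),
        RecordType::A,
    ));

    // previously: block_on(example.search(&query, LookupOptions::default()))
    let result = example
        .search(&query, LookupOptions::default())
        .await
        .unwrap();
    assert!(!result.is_empty());
}
```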
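Journal wiring on `SqliteAuthority` is async as well: `set_journal` and `persist_to_journal` are awaited, and `journal()` now returns an awaited handle that is borrowed with `as_ref()` before use, as in `test_journal` and `test_recovery` above. A sketch; the `Journal` import path and the `attach_journal` helper are assumptions:

```rust
use rusqlite::Connection;

use trust_dns_server::store::sqlite::{Journal, SqliteAuthority};

// Attaches an in-memory journal and persists the zone into it; every journal
// operation is awaited under the new async Authority surface.
async fn attach_journal(authority: &mut SqliteAuthority) {
    let conn = Connection::open_in_memory().expect("could not create in memory DB");
    let mut journal = Journal::new(conn).unwrap();
    journal.schema_up().unwrap();

    authority.set_journal(journal).await;
    authority.persist_to_journal().await.unwrap();

    // reading it back: journal() is async, and the handle is borrowed via as_ref()
    let journal = authority.journal().await;
    assert!(journal.as_ref().is_some());
}
```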