return boolean in AccessControl rather than Result

This commit is contained in:
Benjamin Fry 2024-02-02 19:00:58 -08:00
parent 414abf7087
commit f1b4207154
2 changed files with 38 additions and 40 deletions

View File

@ -1,6 +1,5 @@
use std::net::IpAddr;
use hickory_proto::error::{ProtoError, ProtoErrorKind};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use prefix_trie::{Prefix, PrefixSet};
@ -55,7 +54,8 @@ impl AccessControl {
/// # Return
///
/// Ok if access is granted, Err otherwise
pub(crate) fn allow(&self, ip: IpAddr) -> Result<(), ProtoError> {
#[must_use]
pub(crate) fn allow(&self, ip: IpAddr) -> bool {
match ip {
IpAddr::V4(v4) => {
let v4 = Ipv4Net::from(v4);
@ -89,12 +89,14 @@ impl<I: Prefix> InnerAccessControl<I> {
/// # Return
///
/// Ok if access is granted, Err otherwise
fn allow(&self, ip: &I) -> Result<(), ProtoError> {
// If the IP is denied, there might be an override, otherwise we default to the result of the deny
// Allows are the in the context of deny, so if there are any networks in the deny, then allow is only applied
// if the network is denied. If there were no denies, then allow is applied and only those networks specified
// are allowed
let allowed = match (self.deny.get_lpm(ip), self.allow.get_lpm(ip)) {
#[must_use]
fn allow(&self, ip: &I) -> bool {
// If there are no allows or denies specified, we will always default to allow.
// Allows without denies always translate to deny all except those in the allow list.
// Denies without allows only deny those in the specified deny list.
// If there are both allow and deny lists, then the deny list takes precedent with the allow list
// overriding the deny if it is more specific.
match (self.deny.get_lpm(ip), self.allow.get_lpm(ip)) {
(Some(denied), Some(allowed)) => allowed.prefix_len() > denied.prefix_len(),
(Some(_denied), None) => false,
(None, Some(_allowed)) => true,
@ -106,12 +108,6 @@ impl<I: Prefix> InnerAccessControl<I> {
(false, true) => false, // there are only allow entries, but this isn't one
(false, false) => true, // there are no entries
},
};
if allowed {
Ok(())
} else {
Err(ProtoErrorKind::RequestRefused.into())
}
}
}
@ -123,8 +119,8 @@ mod tests {
#[test]
fn test_none() {
let access = AccessControl::default();
assert!(access.allow("192.168.1.1".parse().unwrap()).is_ok());
assert!(access.allow("fd00::1".parse().unwrap()).is_ok());
assert!(access.allow("192.168.1.1".parse().unwrap()));
assert!(access.allow("fd00::1".parse().unwrap()));
}
#[test]
@ -132,10 +128,10 @@ mod tests {
let mut access = AccessControl::default();
access.insert_allow(&["192.168.1.0/24".parse().unwrap()]);
assert!(access.allow("192.168.1.1".parse().unwrap()).is_ok());
assert!(access.allow("192.168.1.255".parse().unwrap()).is_ok());
assert!(access.allow("192.168.2.1".parse().unwrap()).is_err());
assert!(access.allow("192.168.0.0".parse().unwrap()).is_err());
assert!(access.allow("192.168.1.1".parse().unwrap()));
assert!(access.allow("192.168.1.255".parse().unwrap()));
assert!(!access.allow("192.168.2.1".parse().unwrap()));
assert!(!access.allow("192.168.0.0".parse().unwrap()));
}
#[test]
@ -143,10 +139,10 @@ mod tests {
let mut access = AccessControl::default();
access.insert_allow(&["fd00::/120".parse().unwrap()]);
assert!(access.allow("fd00::1".parse().unwrap()).is_ok());
assert!(access.allow("fd00::00ff".parse().unwrap()).is_ok());
assert!(access.allow("fd00::ffff".parse().unwrap()).is_err());
assert!(access.allow("fd00::1:1".parse().unwrap()).is_err());
assert!(access.allow("fd00::1".parse().unwrap()));
assert!(access.allow("fd00::00ff".parse().unwrap()));
assert!(!access.allow("fd00::ffff".parse().unwrap()));
assert!(!access.allow("fd00::1:1".parse().unwrap()));
}
#[test]
@ -154,10 +150,10 @@ mod tests {
let mut access = AccessControl::default();
access.insert_deny(&["192.168.1.0/24".parse().unwrap()]);
assert!(access.allow("192.168.1.1".parse().unwrap()).is_err());
assert!(access.allow("192.168.1.255".parse().unwrap()).is_err());
assert!(access.allow("192.168.2.1".parse().unwrap()).is_ok());
assert!(access.allow("192.168.0.0".parse().unwrap()).is_ok());
assert!(!access.allow("192.168.1.1".parse().unwrap()));
assert!(!access.allow("192.168.1.255".parse().unwrap()));
assert!(access.allow("192.168.2.1".parse().unwrap()));
assert!(access.allow("192.168.0.0".parse().unwrap()));
}
#[test]
@ -165,10 +161,10 @@ mod tests {
let mut access = AccessControl::default();
access.insert_deny(&["fd00::/120".parse().unwrap()]);
assert!(access.allow("fd00::1".parse().unwrap()).is_err());
assert!(access.allow("fd00::00ff".parse().unwrap()).is_err());
assert!(access.allow("fd00::ffff".parse().unwrap()).is_ok());
assert!(access.allow("fd00::1:1".parse().unwrap()).is_ok());
assert!(!access.allow("fd00::1".parse().unwrap()));
assert!(!access.allow("fd00::00ff".parse().unwrap()));
assert!(access.allow("fd00::ffff".parse().unwrap()));
assert!(access.allow("fd00::1:1".parse().unwrap()));
}
#[test]
@ -177,12 +173,12 @@ mod tests {
access.insert_deny(&["192.168.0.0/16".parse().unwrap()]);
access.insert_allow(&["192.168.1.0/24".parse().unwrap()]);
assert!(access.allow("192.168.1.1".parse().unwrap()).is_ok());
assert!(access.allow("192.168.1.255".parse().unwrap()).is_ok());
assert!(access.allow("192.168.2.1".parse().unwrap()).is_err());
assert!(access.allow("192.168.0.0".parse().unwrap()).is_err());
assert!(access.allow("192.168.1.1".parse().unwrap()));
assert!(access.allow("192.168.1.255".parse().unwrap()));
assert!(!access.allow("192.168.2.1".parse().unwrap()));
assert!(!access.allow("192.168.0.0".parse().unwrap()));
// but all other networks should be allowed
assert!(access.allow("10.0.0.1".parse().unwrap()).is_ok());
assert!(access.allow("10.0.0.1".parse().unwrap()));
}
}

View File

@ -12,7 +12,7 @@ use std::{
};
use futures_util::{FutureExt, StreamExt};
use hickory_proto::{op::MessageType, rr::Record};
use hickory_proto::{error::ProtoErrorKind, op::MessageType, rr::Record};
use ipnet::IpNet;
#[cfg(feature = "dns-over-rustls")]
use rustls::{Certificate, PrivateKey, ServerConfig};
@ -1135,10 +1135,12 @@ pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
MessageRequest::read(&mut decoder),
access.allow(src_addr.ip()),
) {
(Ok(message), Ok(())) => {
(Ok(message), true) => {
inner_handle_request(message, response_handler).await;
}
(Ok(message), Err(error)) => {
(Ok(message), false) => {
let error = ProtoErrorKind::RequestRefused.into();
// The message will be refused from an non-allowed network
error_response_handler(
protocol,