wire up Access to check request IP addr before servicing the IP

This commit is contained in:
Benjamin Fry 2024-01-06 14:41:54 -08:00
parent 4b6c7022c4
commit 4f4f3172bf
6 changed files with 155 additions and 47 deletions

View File

@ -240,6 +240,10 @@ pub enum ProtoErrorKind {
#[error("lock poisoned error")]
Poisoned,
/// A request was Refused due to some access check
#[error("request refused")]
RequestRefused,
/// A ring error
#[error("ring error: {0}")]
Ring(#[from] Unspecified),
@ -624,6 +628,7 @@ impl Clone for ProtoErrorKind {
response_code,
trusted,
},
RequestRefused => RequestRefused,
RrsigsNotPresent {
ref name,
ref record_type,

View File

@ -1,5 +1,6 @@
use std::net::IpAddr;
use hickory_proto::error::{ProtoError, ProtoErrorKind};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use prefix_trie::PrefixSet;
@ -36,18 +37,24 @@ impl Access {
/// # Return
///
/// Ok if access is granted, Err otherwise
pub(crate) fn allow(&self, ip: IpAddr) -> Result<(), ()> {
pub(crate) fn allow(&self, ip: IpAddr) -> Result<(), ProtoError> {
match ip {
IpAddr::V4(v4) => {
let v4 = Ipv4Net::from(v4);
self.allow_ipv4.as_ref().map_or(Ok(()), |allow_ipv4| {
allow_ipv4.get_lpm(&v4).map(|_| ()).ok_or(())
allow_ipv4
.get_lpm(&v4)
.map(|_| ())
.ok_or(ProtoErrorKind::RequestRefused.into())
})
}
IpAddr::V6(v6) => {
let v6 = Ipv6Net::from(v6);
self.allow_ipv6.as_ref().map_or(Ok(()), |allow_ipv6| {
allow_ipv6.get_lpm(&v6).map(|_| ()).ok_or(())
allow_ipv6
.get_lpm(&v6)
.map(|_| ())
.ok_or(ProtoErrorKind::RequestRefused.into())
})
}
}

View File

@ -16,6 +16,7 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use crate::{
access::Access,
authority::MessageResponse,
proto::h2::h2_server,
server::{
@ -25,6 +26,7 @@ use crate::{
};
pub(crate) async fn h2_handler<T, I>(
access: Arc<Access>,
handler: Arc<T>,
io: I,
src_addr: SocketAddr,
@ -68,11 +70,12 @@ pub(crate) async fn h2_handler<T, I>(
debug!("Received request: {:#?}", request);
let dns_hostname = dns_hostname.clone();
let handler = handler.clone();
let access = access.clone();
let responder = HttpsResponseHandle(Arc::new(Mutex::new(respond)));
tokio::spawn(async move {
match h2_server::message_from(dns_hostname, request).await {
Ok(bytes) => handle_request(bytes, src_addr, handler, responder).await,
Ok(bytes) => handle_request(bytes, src_addr, access, handler, responder).await,
Err(err) => warn!("error while handling request from {}: {}", src_addr, err),
};
});
@ -84,12 +87,21 @@ pub(crate) async fn h2_handler<T, I>(
async fn handle_request<T>(
bytes: BytesMut,
src_addr: SocketAddr,
access: Arc<Access>,
handler: Arc<T>,
responder: HttpsResponseHandle,
) where
T: RequestHandler,
{
server_future::handle_request(&bytes, src_addr, Protocol::Https, handler, responder).await
server_future::handle_request(
&bytes,
src_addr,
Protocol::Https,
access,
handler,
responder,
)
.await
}
#[derive(Clone)]

View File

@ -18,6 +18,7 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use crate::{
access::Access,
authority::MessageResponse,
server::{
request_handler::RequestHandler, response_handler::ResponseHandler, server_future,
@ -26,6 +27,7 @@ use crate::{
};
pub(crate) async fn h3_handler<T>(
access: Arc<Access>,
handler: Arc<T>,
mut connection: H3Connection,
src_addr: SocketAddr,
@ -71,10 +73,13 @@ where
request.remaining()
);
let handler = handler.clone();
let access = access.clone();
let stream = Arc::new(Mutex::new(stream));
let responder = H3ResponseHandle(stream.clone());
tokio::spawn(handle_request(request, src_addr, handler, responder));
tokio::spawn(handle_request(
request, src_addr, access, handler, responder,
));
max_requests -= 1;
if max_requests == 0 {
@ -91,12 +96,13 @@ where
async fn handle_request<T>(
bytes: Bytes,
src_addr: SocketAddr,
access: Arc<Access>,
handler: Arc<T>,
responder: H3ResponseHandle,
) where
T: RequestHandler,
{
server_future::handle_request(&bytes, src_addr, Protocol::H3, handler, responder).await
server_future::handle_request(&bytes, src_addr, Protocol::H3, access, handler, responder).await
}
#[derive(Clone)]

View File

@ -18,6 +18,7 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use crate::{
access::Access,
authority::MessageResponse,
proto::quic::QuicStreams,
server::{
@ -27,6 +28,7 @@ use crate::{
};
pub(crate) async fn quic_handler<T>(
access: Arc<Access>,
handler: Arc<T>,
mut quic_streams: QuicStreams,
src_addr: SocketAddr,
@ -65,10 +67,11 @@ where
request.len()
);
let handler = handler.clone();
let access = access.clone();
let stream = Arc::new(Mutex::new(request_stream));
let responder = QuicResponseHandle(stream.clone());
handle_request(request, src_addr, handler, responder).await;
handle_request(request, src_addr, access, handler, responder).await;
max_requests -= 1;
if max_requests == 0 {
@ -86,12 +89,14 @@ where
async fn handle_request<T>(
bytes: BytesMut,
src_addr: SocketAddr,
access: Arc<Access>,
handler: Arc<T>,
responder: QuicResponseHandle,
) where
T: RequestHandler,
{
server_future::handle_request(&bytes, src_addr, Protocol::Quic, handler, responder).await
server_future::handle_request(&bytes, src_addr, Protocol::Quic, access, handler, responder)
.await
}
#[derive(Clone)]

View File

@ -22,6 +22,7 @@ use tracing::{debug, info, warn};
#[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
use crate::proto::openssl::tls_server::*;
use crate::{
access::Access,
authority::{MessageRequest, MessageResponseBuilder},
proto::{
error::ProtoError,
@ -42,6 +43,7 @@ pub struct ServerFuture<T: RequestHandler> {
handler: Arc<T>,
join_set: JoinSet<Result<(), ProtoError>>,
shutdown_token: CancellationToken,
access: Arc<Access>,
}
impl<T: RequestHandler> ServerFuture<T> {
@ -51,6 +53,7 @@ impl<T: RequestHandler> ServerFuture<T> {
handler: Arc::new(handler),
join_set: JoinSet::new(),
shutdown_token: CancellationToken::new(),
access: Arc::new(Access::default()),
}
}
@ -64,6 +67,7 @@ impl<T: RequestHandler> ServerFuture<T> {
UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into());
let shutdown = self.shutdown_token.clone();
let handler = self.handler.clone();
let access = self.access.clone();
// this spawns a ForEach future which handles all the requests into a Handler.
self.join_set.spawn({
@ -100,10 +104,12 @@ impl<T: RequestHandler> ServerFuture<T> {
}
let handler = handler.clone();
let access = access.clone();
let stream_handle = stream_handle.with_remote_addr(src_addr);
inner_join_set.spawn(async move {
handle_raw_request(message, Protocol::Udp, handler, stream_handle).await;
handle_raw_request(message, Protocol::Udp, access, handler, stream_handle)
.await;
});
reap_tasks(&mut inner_join_set);
@ -141,6 +147,7 @@ impl<T: RequestHandler> ServerFuture<T> {
debug!("register tcp: {:?}", listener);
let handler = self.handler.clone();
let access = self.access.clone();
// for each incoming request...
let shutdown = self.shutdown_token.clone();
@ -172,6 +179,7 @@ impl<T: RequestHandler> ServerFuture<T> {
}
let handler = handler.clone();
let access = access.clone();
// and spawn to the io_loop
inner_join_set.spawn(async move {
@ -198,6 +206,7 @@ impl<T: RequestHandler> ServerFuture<T> {
handle_raw_request(
message,
Protocol::Tcp,
access.clone(),
handler.clone(),
stream_handle.clone(),
)
@ -265,6 +274,7 @@ impl<T: RequestHandler> ServerFuture<T> {
let ((cert, chain), key) = certificate_and_key;
let handler = self.handler.clone();
let acces = self.access.clone();
debug!("registered tcp: {:?}", listener);
let tls_acceptor = Box::pin(tls_server::new_acceptor(cert, chain, key)?);
@ -343,6 +353,7 @@ impl<T: RequestHandler> ServerFuture<T> {
self::handle_raw_request(
message,
Protocol::Tls,
access.clone(),
handler.clone(),
stream_handle.clone(),
)
@ -415,6 +426,7 @@ impl<T: RequestHandler> ServerFuture<T> {
use tokio_rustls::TlsAcceptor;
let handler = self.handler.clone();
let access = self.access.clone();
debug!("registered tcp: {:?}", listener);
@ -450,6 +462,7 @@ impl<T: RequestHandler> ServerFuture<T> {
}
let handler = handler.clone();
let access = access.clone();
let tls_acceptor = tls_acceptor.clone();
// kick out to a different task immediately, let them do the TLS handshake
@ -486,6 +499,7 @@ impl<T: RequestHandler> ServerFuture<T> {
handle_raw_request(
message,
Protocol::Tls,
access.clone(),
handler.clone(),
stream_handle.clone(),
)
@ -600,6 +614,7 @@ impl<T: RequestHandler> ServerFuture<T> {
let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
let handler = self.handler.clone();
let access = self.access.clone();
debug!("registered https: {listener:?}");
let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
@ -638,6 +653,7 @@ impl<T: RequestHandler> ServerFuture<T> {
}
let handler = handler.clone();
let access = access.clone();
let tls_acceptor = tls_acceptor.clone();
let dns_hostname = dns_hostname.clone();
@ -658,6 +674,7 @@ impl<T: RequestHandler> ServerFuture<T> {
debug!("accepted HTTPS request from: {src_addr}");
h2_handler(
access,
handler,
tls_stream,
src_addr,
@ -705,6 +722,7 @@ impl<T: RequestHandler> ServerFuture<T> {
let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
let handler = self.handler.clone();
let access = self.access.clone();
debug!("registered quic: {:?}", socket);
let mut server =
@ -743,15 +761,22 @@ impl<T: RequestHandler> ServerFuture<T> {
}
let handler = handler.clone();
let access = access.clone();
let dns_hostname = dns_hostname.clone();
inner_join_set.spawn(async move {
debug!("starting quic stream request from: {src_addr}");
// TODO: need to consider timeout of total connect...
let result =
quic_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
.await;
let result = quic_handler(
access,
handler,
streams,
src_addr,
dns_hostname,
shutdown.clone(),
)
.await;
if let Err(e) = result {
warn!("quic stream processing failed from {src_addr}: {e}")
@ -796,6 +821,7 @@ impl<T: RequestHandler> ServerFuture<T> {
let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
let handler = self.handler.clone();
let access = self.access.clone();
debug!("registered h3: {:?}", socket);
let mut server =
@ -834,15 +860,22 @@ impl<T: RequestHandler> ServerFuture<T> {
}
let handler = handler.clone();
let access = access.clone();
let dns_hostname = dns_hostname.clone();
inner_join_set.spawn(async move {
debug!("starting h3 stream request from: {src_addr}");
// TODO: need to consider timeout of total connect...
let result =
h3_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
.await;
let result = h3_handler(
access,
handler,
streams,
src_addr,
dns_hostname,
shutdown.clone(),
)
.await;
if let Err(e) = result {
warn!("h3 stream processing failed from {src_addr}: {e}")
@ -912,6 +945,7 @@ fn reap_tasks(join_set: &mut JoinSet<()>) {
pub(crate) async fn handle_raw_request<T: RequestHandler>(
message: SerialMessage,
protocol: Protocol,
access: Arc<Access>,
request_handler: Arc<T>,
response_handler: BufDnsStreamHandle,
) {
@ -922,6 +956,7 @@ pub(crate) async fn handle_raw_request<T: RequestHandler>(
message.bytes(),
src_addr,
protocol,
access,
request_handler,
response_handler,
)
@ -992,6 +1027,7 @@ pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
message_bytes: &[u8],
src_addr: SocketAddr,
protocol: Protocol,
access: Arc<Access>,
request_handler: Arc<T>,
response_handler: R,
) {
@ -1045,49 +1081,86 @@ pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
request_handler.handle_request(&request, reporter).await;
};
// method to return an error to the client
let error_response_handler = |protocol: Protocol,
src_addr: SocketAddr,
header: Header,
query: LowerQuery,
response_code: ResponseCode,
error: Box<ProtoError>,
response_handler: R| async move {
// debug for more info on why the message parsing failed
debug!(
"request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:{response_code}:{error}",
id = header.id(),
proto = protocol,
addr = src_addr.ip(),
port = src_addr.port(),
message_type = header.message_type(),
op = header.op_code(),
response_code = response_code,
error = error,
);
// The reporter will handle making sure to log the result of the request
let mut reporter = ReportingResponseHandler {
request_header: header,
query,
protocol,
src_addr,
handler: response_handler,
};
let response = MessageResponseBuilder::new(None);
let result = reporter
.send_response(response.error_msg(&header, response_code))
.await;
if let Err(e) = result {
warn!("failed to return FormError to client: {}", e);
}
};
// Attempt to decode the message
match MessageRequest::read(&mut decoder) {
Ok(message) => {
match (
MessageRequest::read(&mut decoder),
access.allow(src_addr.ip()),
) {
(Ok(message), Ok(())) => {
inner_handle_request(message, response_handler).await;
}
Err(ProtoError { kind, .. }) if kind.as_form_error().is_some() => {
(Ok(message), Err(error)) => {
// The message will be refused from an non-allowed network
error_response_handler(
protocol,
src_addr,
*message.header(),
message.query().clone(),
ResponseCode::Refused,
Box::new(error),
response_handler,
)
.await;
}
(Err(ProtoError { kind, .. }), _) if kind.as_form_error().is_some() => {
// We failed to parse the request due to some issue in the message, but the header is available, so we can respond
let (header, error) = kind
.into_form_error()
.expect("as form_error already confirmed this is a FormError");
let query = LowerQuery::query(Query::default());
// debug for more info on why the message parsing failed
debug!(
"request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:FormError:{error}",
id = header.id(),
proto = protocol,
addr = src_addr.ip(),
port = src_addr.port(),
message_type= header.message_type(),
op = header.op_code(),
error = error,
);
// The reporter will handle making sure to log the result of the request
let mut reporter = ReportingResponseHandler {
request_header: header,
query,
error_response_handler(
protocol,
src_addr,
handler: response_handler,
};
let response = MessageResponseBuilder::new(None);
let result = reporter
.send_response(response.error_msg(&header, ResponseCode::FormErr))
.await;
if let Err(e) = result {
warn!("failed to return FormError to client: {}", e);
}
header,
query,
ResponseCode::FormErr,
error,
response_handler,
)
.await;
}
Err(e) => warn!("failed to read message: {}", e),
(Err(e), _) => warn!("failed to read message: {}", e),
}
}