wire up Access to check request IP addr before servicing the IP
This commit is contained in:
parent
4b6c7022c4
commit
4f4f3172bf
@ -240,6 +240,10 @@ pub enum ProtoErrorKind {
|
||||
#[error("lock poisoned error")]
|
||||
Poisoned,
|
||||
|
||||
/// A request was Refused due to some access check
|
||||
#[error("request refused")]
|
||||
RequestRefused,
|
||||
|
||||
/// A ring error
|
||||
#[error("ring error: {0}")]
|
||||
Ring(#[from] Unspecified),
|
||||
@ -624,6 +628,7 @@ impl Clone for ProtoErrorKind {
|
||||
response_code,
|
||||
trusted,
|
||||
},
|
||||
RequestRefused => RequestRefused,
|
||||
RrsigsNotPresent {
|
||||
ref name,
|
||||
ref record_type,
|
||||
|
@ -1,5 +1,6 @@
|
||||
use std::net::IpAddr;
|
||||
|
||||
use hickory_proto::error::{ProtoError, ProtoErrorKind};
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use prefix_trie::PrefixSet;
|
||||
|
||||
@ -36,18 +37,24 @@ impl Access {
|
||||
/// # Return
|
||||
///
|
||||
/// Ok if access is granted, Err otherwise
|
||||
pub(crate) fn allow(&self, ip: IpAddr) -> Result<(), ()> {
|
||||
pub(crate) fn allow(&self, ip: IpAddr) -> Result<(), ProtoError> {
|
||||
match ip {
|
||||
IpAddr::V4(v4) => {
|
||||
let v4 = Ipv4Net::from(v4);
|
||||
self.allow_ipv4.as_ref().map_or(Ok(()), |allow_ipv4| {
|
||||
allow_ipv4.get_lpm(&v4).map(|_| ()).ok_or(())
|
||||
allow_ipv4
|
||||
.get_lpm(&v4)
|
||||
.map(|_| ())
|
||||
.ok_or(ProtoErrorKind::RequestRefused.into())
|
||||
})
|
||||
}
|
||||
IpAddr::V6(v6) => {
|
||||
let v6 = Ipv6Net::from(v6);
|
||||
self.allow_ipv6.as_ref().map_or(Ok(()), |allow_ipv6| {
|
||||
allow_ipv6.get_lpm(&v6).map(|_| ()).ok_or(())
|
||||
allow_ipv6
|
||||
.get_lpm(&v6)
|
||||
.map(|_| ())
|
||||
.ok_or(ProtoErrorKind::RequestRefused.into())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::{
|
||||
access::Access,
|
||||
authority::MessageResponse,
|
||||
proto::h2::h2_server,
|
||||
server::{
|
||||
@ -25,6 +26,7 @@ use crate::{
|
||||
};
|
||||
|
||||
pub(crate) async fn h2_handler<T, I>(
|
||||
access: Arc<Access>,
|
||||
handler: Arc<T>,
|
||||
io: I,
|
||||
src_addr: SocketAddr,
|
||||
@ -68,11 +70,12 @@ pub(crate) async fn h2_handler<T, I>(
|
||||
debug!("Received request: {:#?}", request);
|
||||
let dns_hostname = dns_hostname.clone();
|
||||
let handler = handler.clone();
|
||||
let access = access.clone();
|
||||
let responder = HttpsResponseHandle(Arc::new(Mutex::new(respond)));
|
||||
|
||||
tokio::spawn(async move {
|
||||
match h2_server::message_from(dns_hostname, request).await {
|
||||
Ok(bytes) => handle_request(bytes, src_addr, handler, responder).await,
|
||||
Ok(bytes) => handle_request(bytes, src_addr, access, handler, responder).await,
|
||||
Err(err) => warn!("error while handling request from {}: {}", src_addr, err),
|
||||
};
|
||||
});
|
||||
@ -84,12 +87,21 @@ pub(crate) async fn h2_handler<T, I>(
|
||||
async fn handle_request<T>(
|
||||
bytes: BytesMut,
|
||||
src_addr: SocketAddr,
|
||||
access: Arc<Access>,
|
||||
handler: Arc<T>,
|
||||
responder: HttpsResponseHandle,
|
||||
) where
|
||||
T: RequestHandler,
|
||||
{
|
||||
server_future::handle_request(&bytes, src_addr, Protocol::Https, handler, responder).await
|
||||
server_future::handle_request(
|
||||
&bytes,
|
||||
src_addr,
|
||||
Protocol::Https,
|
||||
access,
|
||||
handler,
|
||||
responder,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -18,6 +18,7 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::{
|
||||
access::Access,
|
||||
authority::MessageResponse,
|
||||
server::{
|
||||
request_handler::RequestHandler, response_handler::ResponseHandler, server_future,
|
||||
@ -26,6 +27,7 @@ use crate::{
|
||||
};
|
||||
|
||||
pub(crate) async fn h3_handler<T>(
|
||||
access: Arc<Access>,
|
||||
handler: Arc<T>,
|
||||
mut connection: H3Connection,
|
||||
src_addr: SocketAddr,
|
||||
@ -71,10 +73,13 @@ where
|
||||
request.remaining()
|
||||
);
|
||||
let handler = handler.clone();
|
||||
let access = access.clone();
|
||||
let stream = Arc::new(Mutex::new(stream));
|
||||
let responder = H3ResponseHandle(stream.clone());
|
||||
|
||||
tokio::spawn(handle_request(request, src_addr, handler, responder));
|
||||
tokio::spawn(handle_request(
|
||||
request, src_addr, access, handler, responder,
|
||||
));
|
||||
|
||||
max_requests -= 1;
|
||||
if max_requests == 0 {
|
||||
@ -91,12 +96,13 @@ where
|
||||
async fn handle_request<T>(
|
||||
bytes: Bytes,
|
||||
src_addr: SocketAddr,
|
||||
access: Arc<Access>,
|
||||
handler: Arc<T>,
|
||||
responder: H3ResponseHandle,
|
||||
) where
|
||||
T: RequestHandler,
|
||||
{
|
||||
server_future::handle_request(&bytes, src_addr, Protocol::H3, handler, responder).await
|
||||
server_future::handle_request(&bytes, src_addr, Protocol::H3, access, handler, responder).await
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -18,6 +18,7 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::{
|
||||
access::Access,
|
||||
authority::MessageResponse,
|
||||
proto::quic::QuicStreams,
|
||||
server::{
|
||||
@ -27,6 +28,7 @@ use crate::{
|
||||
};
|
||||
|
||||
pub(crate) async fn quic_handler<T>(
|
||||
access: Arc<Access>,
|
||||
handler: Arc<T>,
|
||||
mut quic_streams: QuicStreams,
|
||||
src_addr: SocketAddr,
|
||||
@ -65,10 +67,11 @@ where
|
||||
request.len()
|
||||
);
|
||||
let handler = handler.clone();
|
||||
let access = access.clone();
|
||||
let stream = Arc::new(Mutex::new(request_stream));
|
||||
let responder = QuicResponseHandle(stream.clone());
|
||||
|
||||
handle_request(request, src_addr, handler, responder).await;
|
||||
handle_request(request, src_addr, access, handler, responder).await;
|
||||
|
||||
max_requests -= 1;
|
||||
if max_requests == 0 {
|
||||
@ -86,12 +89,14 @@ where
|
||||
async fn handle_request<T>(
|
||||
bytes: BytesMut,
|
||||
src_addr: SocketAddr,
|
||||
access: Arc<Access>,
|
||||
handler: Arc<T>,
|
||||
responder: QuicResponseHandle,
|
||||
) where
|
||||
T: RequestHandler,
|
||||
{
|
||||
server_future::handle_request(&bytes, src_addr, Protocol::Quic, handler, responder).await
|
||||
server_future::handle_request(&bytes, src_addr, Protocol::Quic, access, handler, responder)
|
||||
.await
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -22,6 +22,7 @@ use tracing::{debug, info, warn};
|
||||
#[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
|
||||
use crate::proto::openssl::tls_server::*;
|
||||
use crate::{
|
||||
access::Access,
|
||||
authority::{MessageRequest, MessageResponseBuilder},
|
||||
proto::{
|
||||
error::ProtoError,
|
||||
@ -42,6 +43,7 @@ pub struct ServerFuture<T: RequestHandler> {
|
||||
handler: Arc<T>,
|
||||
join_set: JoinSet<Result<(), ProtoError>>,
|
||||
shutdown_token: CancellationToken,
|
||||
access: Arc<Access>,
|
||||
}
|
||||
|
||||
impl<T: RequestHandler> ServerFuture<T> {
|
||||
@ -51,6 +53,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
handler: Arc::new(handler),
|
||||
join_set: JoinSet::new(),
|
||||
shutdown_token: CancellationToken::new(),
|
||||
access: Arc::new(Access::default()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,6 +67,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into());
|
||||
let shutdown = self.shutdown_token.clone();
|
||||
let handler = self.handler.clone();
|
||||
let access = self.access.clone();
|
||||
|
||||
// this spawns a ForEach future which handles all the requests into a Handler.
|
||||
self.join_set.spawn({
|
||||
@ -100,10 +104,12 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
}
|
||||
|
||||
let handler = handler.clone();
|
||||
let access = access.clone();
|
||||
let stream_handle = stream_handle.with_remote_addr(src_addr);
|
||||
|
||||
inner_join_set.spawn(async move {
|
||||
handle_raw_request(message, Protocol::Udp, handler, stream_handle).await;
|
||||
handle_raw_request(message, Protocol::Udp, access, handler, stream_handle)
|
||||
.await;
|
||||
});
|
||||
|
||||
reap_tasks(&mut inner_join_set);
|
||||
@ -141,6 +147,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
debug!("register tcp: {:?}", listener);
|
||||
|
||||
let handler = self.handler.clone();
|
||||
let access = self.access.clone();
|
||||
|
||||
// for each incoming request...
|
||||
let shutdown = self.shutdown_token.clone();
|
||||
@ -172,6 +179,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
}
|
||||
|
||||
let handler = handler.clone();
|
||||
let access = access.clone();
|
||||
|
||||
// and spawn to the io_loop
|
||||
inner_join_set.spawn(async move {
|
||||
@ -198,6 +206,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
handle_raw_request(
|
||||
message,
|
||||
Protocol::Tcp,
|
||||
access.clone(),
|
||||
handler.clone(),
|
||||
stream_handle.clone(),
|
||||
)
|
||||
@ -265,6 +274,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
let ((cert, chain), key) = certificate_and_key;
|
||||
|
||||
let handler = self.handler.clone();
|
||||
let acces = self.access.clone();
|
||||
debug!("registered tcp: {:?}", listener);
|
||||
|
||||
let tls_acceptor = Box::pin(tls_server::new_acceptor(cert, chain, key)?);
|
||||
@ -343,6 +353,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
self::handle_raw_request(
|
||||
message,
|
||||
Protocol::Tls,
|
||||
access.clone(),
|
||||
handler.clone(),
|
||||
stream_handle.clone(),
|
||||
)
|
||||
@ -415,6 +426,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
let handler = self.handler.clone();
|
||||
let access = self.access.clone();
|
||||
|
||||
debug!("registered tcp: {:?}", listener);
|
||||
|
||||
@ -450,6 +462,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
}
|
||||
|
||||
let handler = handler.clone();
|
||||
let access = access.clone();
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
|
||||
// kick out to a different task immediately, let them do the TLS handshake
|
||||
@ -486,6 +499,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
handle_raw_request(
|
||||
message,
|
||||
Protocol::Tls,
|
||||
access.clone(),
|
||||
handler.clone(),
|
||||
stream_handle.clone(),
|
||||
)
|
||||
@ -600,6 +614,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
|
||||
|
||||
let handler = self.handler.clone();
|
||||
let access = self.access.clone();
|
||||
debug!("registered https: {listener:?}");
|
||||
|
||||
let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
|
||||
@ -638,6 +653,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
}
|
||||
|
||||
let handler = handler.clone();
|
||||
let access = access.clone();
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let dns_hostname = dns_hostname.clone();
|
||||
|
||||
@ -658,6 +674,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
debug!("accepted HTTPS request from: {src_addr}");
|
||||
|
||||
h2_handler(
|
||||
access,
|
||||
handler,
|
||||
tls_stream,
|
||||
src_addr,
|
||||
@ -705,6 +722,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
|
||||
|
||||
let handler = self.handler.clone();
|
||||
let access = self.access.clone();
|
||||
|
||||
debug!("registered quic: {:?}", socket);
|
||||
let mut server =
|
||||
@ -743,15 +761,22 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
}
|
||||
|
||||
let handler = handler.clone();
|
||||
let access = access.clone();
|
||||
let dns_hostname = dns_hostname.clone();
|
||||
|
||||
inner_join_set.spawn(async move {
|
||||
debug!("starting quic stream request from: {src_addr}");
|
||||
|
||||
// TODO: need to consider timeout of total connect...
|
||||
let result =
|
||||
quic_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
|
||||
.await;
|
||||
let result = quic_handler(
|
||||
access,
|
||||
handler,
|
||||
streams,
|
||||
src_addr,
|
||||
dns_hostname,
|
||||
shutdown.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
warn!("quic stream processing failed from {src_addr}: {e}")
|
||||
@ -796,6 +821,7 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
|
||||
|
||||
let handler = self.handler.clone();
|
||||
let access = self.access.clone();
|
||||
|
||||
debug!("registered h3: {:?}", socket);
|
||||
let mut server =
|
||||
@ -834,15 +860,22 @@ impl<T: RequestHandler> ServerFuture<T> {
|
||||
}
|
||||
|
||||
let handler = handler.clone();
|
||||
let access = access.clone();
|
||||
let dns_hostname = dns_hostname.clone();
|
||||
|
||||
inner_join_set.spawn(async move {
|
||||
debug!("starting h3 stream request from: {src_addr}");
|
||||
|
||||
// TODO: need to consider timeout of total connect...
|
||||
let result =
|
||||
h3_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
|
||||
.await;
|
||||
let result = h3_handler(
|
||||
access,
|
||||
handler,
|
||||
streams,
|
||||
src_addr,
|
||||
dns_hostname,
|
||||
shutdown.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
warn!("h3 stream processing failed from {src_addr}: {e}")
|
||||
@ -912,6 +945,7 @@ fn reap_tasks(join_set: &mut JoinSet<()>) {
|
||||
pub(crate) async fn handle_raw_request<T: RequestHandler>(
|
||||
message: SerialMessage,
|
||||
protocol: Protocol,
|
||||
access: Arc<Access>,
|
||||
request_handler: Arc<T>,
|
||||
response_handler: BufDnsStreamHandle,
|
||||
) {
|
||||
@ -922,6 +956,7 @@ pub(crate) async fn handle_raw_request<T: RequestHandler>(
|
||||
message.bytes(),
|
||||
src_addr,
|
||||
protocol,
|
||||
access,
|
||||
request_handler,
|
||||
response_handler,
|
||||
)
|
||||
@ -992,6 +1027,7 @@ pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
|
||||
message_bytes: &[u8],
|
||||
src_addr: SocketAddr,
|
||||
protocol: Protocol,
|
||||
access: Arc<Access>,
|
||||
request_handler: Arc<T>,
|
||||
response_handler: R,
|
||||
) {
|
||||
@ -1045,49 +1081,86 @@ pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
|
||||
request_handler.handle_request(&request, reporter).await;
|
||||
};
|
||||
|
||||
// method to return an error to the client
|
||||
let error_response_handler = |protocol: Protocol,
|
||||
src_addr: SocketAddr,
|
||||
header: Header,
|
||||
query: LowerQuery,
|
||||
response_code: ResponseCode,
|
||||
error: Box<ProtoError>,
|
||||
response_handler: R| async move {
|
||||
// debug for more info on why the message parsing failed
|
||||
debug!(
|
||||
"request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:{response_code}:{error}",
|
||||
id = header.id(),
|
||||
proto = protocol,
|
||||
addr = src_addr.ip(),
|
||||
port = src_addr.port(),
|
||||
message_type = header.message_type(),
|
||||
op = header.op_code(),
|
||||
response_code = response_code,
|
||||
error = error,
|
||||
);
|
||||
|
||||
// The reporter will handle making sure to log the result of the request
|
||||
let mut reporter = ReportingResponseHandler {
|
||||
request_header: header,
|
||||
query,
|
||||
protocol,
|
||||
src_addr,
|
||||
handler: response_handler,
|
||||
};
|
||||
|
||||
let response = MessageResponseBuilder::new(None);
|
||||
let result = reporter
|
||||
.send_response(response.error_msg(&header, response_code))
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
warn!("failed to return FormError to client: {}", e);
|
||||
}
|
||||
};
|
||||
|
||||
// Attempt to decode the message
|
||||
match MessageRequest::read(&mut decoder) {
|
||||
Ok(message) => {
|
||||
match (
|
||||
MessageRequest::read(&mut decoder),
|
||||
access.allow(src_addr.ip()),
|
||||
) {
|
||||
(Ok(message), Ok(())) => {
|
||||
inner_handle_request(message, response_handler).await;
|
||||
}
|
||||
Err(ProtoError { kind, .. }) if kind.as_form_error().is_some() => {
|
||||
(Ok(message), Err(error)) => {
|
||||
// The message will be refused from an non-allowed network
|
||||
error_response_handler(
|
||||
protocol,
|
||||
src_addr,
|
||||
*message.header(),
|
||||
message.query().clone(),
|
||||
ResponseCode::Refused,
|
||||
Box::new(error),
|
||||
response_handler,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
(Err(ProtoError { kind, .. }), _) if kind.as_form_error().is_some() => {
|
||||
// We failed to parse the request due to some issue in the message, but the header is available, so we can respond
|
||||
let (header, error) = kind
|
||||
.into_form_error()
|
||||
.expect("as form_error already confirmed this is a FormError");
|
||||
let query = LowerQuery::query(Query::default());
|
||||
|
||||
// debug for more info on why the message parsing failed
|
||||
debug!(
|
||||
"request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:FormError:{error}",
|
||||
id = header.id(),
|
||||
proto = protocol,
|
||||
addr = src_addr.ip(),
|
||||
port = src_addr.port(),
|
||||
message_type= header.message_type(),
|
||||
op = header.op_code(),
|
||||
error = error,
|
||||
);
|
||||
|
||||
// The reporter will handle making sure to log the result of the request
|
||||
let mut reporter = ReportingResponseHandler {
|
||||
request_header: header,
|
||||
query,
|
||||
error_response_handler(
|
||||
protocol,
|
||||
src_addr,
|
||||
handler: response_handler,
|
||||
};
|
||||
|
||||
let response = MessageResponseBuilder::new(None);
|
||||
let result = reporter
|
||||
.send_response(response.error_msg(&header, ResponseCode::FormErr))
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
warn!("failed to return FormError to client: {}", e);
|
||||
}
|
||||
header,
|
||||
query,
|
||||
ResponseCode::FormErr,
|
||||
error,
|
||||
response_handler,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(e) => warn!("failed to read message: {}", e),
|
||||
(Err(e), _) => warn!("failed to read message: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user