change UdpSocket to have Poll based methods

This commit is contained in:
Benjamin Fry 2020-11-12 16:43:51 -08:00
parent 4d10e6f414
commit ecaf8c92d5
4 changed files with 60 additions and 23 deletions

1
Cargo.lock generated
View File

@ -151,6 +151,7 @@ dependencies = [
"async-std",
"async-trait",
"futures-io",
"futures-util",
"trust-dns-resolver",
]

View File

@ -67,6 +67,7 @@ path = "src/lib.rs"
async-std = "1.6"
async-trait = "0.1.36"
futures-io = { version = "0.3.5", default-features = false, features = ["std"] }
futures-util = { version = "0.3.5", default-features = false, features = ["std"] }
trust-dns-resolver = { version = "0.20.0-alpha.3", path = "../resolver", default-features = false }
[dev-dependencies]

View File

@ -8,9 +8,11 @@
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use async_trait::async_trait;
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::future::FutureExt;
use trust_dns_resolver::proto::tcp::{Connect, DnsTcpStream};
use trust_dns_resolver::proto::udp::UdpSocket;
@ -28,12 +30,21 @@ impl UdpSocket for AsyncStdUdpSocket {
.map(AsyncStdUdpSocket)
}
async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.0.recv_from(buf).await
fn poll_recv_from(
&self,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<(usize, SocketAddr)>> {
Box::pin(self.0.recv_from(buf)).poll_unpin(cx)
}
async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
self.0.send_to(buf, target).await
fn poll_send_to(
&self,
cx: &mut Context,
buf: &[u8],
target: &SocketAddr,
) -> Poll<io::Result<usize>> {
Box::pin(self.0.send_to(buf, target)).poll_unpin(cx)
}
}

View File

@ -13,7 +13,7 @@ use std::task::{Context, Poll};
use async_trait::async_trait;
use futures_util::stream::Stream;
use futures_util::{future::Future, ready, FutureExt, TryFutureExt};
use futures_util::{future::Future, ready, TryFutureExt};
use log::debug;
use rand;
use rand::distributions::{uniform::Uniform, Distribution};
@ -32,11 +32,33 @@ where
/// UdpSocket
async fn bind(addr: &SocketAddr) -> io::Result<Self>;
/// Poll once Receive data from the socket and returns the number of bytes read and the address from
/// where the data came on success.
fn poll_recv_from(
&self,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<(usize, SocketAddr)>>;
/// Receive data from the socket and returns the number of bytes read and the address from
/// where the data came on success.
async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)>;
async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
futures_util::future::poll_fn(|cx| self.poll_recv_from(cx, buf)).await
}
/// Poll once to send data to the given address.
fn poll_send_to(
&self,
cx: &mut Context,
buf: &[u8],
target: &SocketAddr,
) -> Poll<io::Result<usize>>;
/// Send data to the given address.
async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize>;
async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
futures_util::future::poll_fn(|cx| self.poll_send_to(cx, buf, target)).await
}
}
/// A UDP stream of DNS binary packets
@ -140,7 +162,7 @@ impl<S: UdpSocket + Send + 'static> Stream for UdpStream<S> {
// meaning that sending will be prefered over receiving...
// TODO: shouldn't this return the error to send to the sender?
ready!(socket.send_to(message.bytes(), addr).poll_unpin(cx))?;
ready!(socket.poll_send_to(cx, message.bytes(), addr))?;
// message sent, need to pop the message
assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
@ -151,7 +173,7 @@ impl<S: UdpSocket + Send + 'static> Stream for UdpStream<S> {
// TODO: this should match edns settings
let mut buf = [0u8; 4096];
let (len, src) = ready!(socket.recv_from(&mut buf).poll_unpin(cx))?;
let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
Poll::Ready(Some(Ok(serial_message)))
@ -230,23 +252,25 @@ impl UdpSocket for tokio::net::UdpSocket {
tokio::net::UdpSocket::bind(addr).await
}
// TODO: add poll_recv_from and poll_send_to to be more efficient in allocations...
async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
futures_util::future::poll_fn(|cx| {
let mut buf = tokio::io::ReadBuf::new(buf);
let addr = ready!(tokio::net::UdpSocket::poll_recv_from(self, cx, &mut buf))?;
let len = buf.filled().len();
fn poll_recv_from(
&self,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<(usize, SocketAddr)>> {
let mut buf = tokio::io::ReadBuf::new(buf);
let addr = ready!(tokio::net::UdpSocket::poll_recv_from(self, cx, &mut buf))?;
let len = buf.filled().len();
Poll::Ready(Ok((len, addr)))
})
.await
Poll::Ready(Ok((len, addr)))
}
async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
futures_util::future::poll_fn(|cx| {
tokio::net::UdpSocket::poll_send_to(self, cx, buf, target)
})
.await
fn poll_send_to(
&self,
cx: &mut Context,
buf: &[u8],
target: &SocketAddr,
) -> Poll<io::Result<usize>> {
tokio::net::UdpSocket::poll_send_to(self, cx, buf, target)
}
}