move to new mpsc channels

This commit is contained in:
Benjamin Fry 2016-11-23 23:11:57 -08:00
parent 3f5f770741
commit cbfa7bea5a
17 changed files with 153 additions and 130 deletions

View File

@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Changed
- Split Server and Client into separate crates, #43
- Moved many integration tests to `tests` from `src`, #52
- Migrated all handles to new futures::sync::mpsc impls
### Fixed
- Flush TcpStream after fully sending Message

View File

@ -10,15 +10,15 @@ use std::io;
use std::time::Duration;
use chrono::UTC;
use futures;
use futures::{Async, Complete, Future, Poll, task};
use futures::IntoFuture;
use futures::stream::{Peekable, Fuse as StreamFuse, Stream};
use futures::sync::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::sync::oneshot;
use futures::task::park;
use rand::Rng;
use rand;
use tokio_core::reactor::{Handle, Timeout};
use tokio_core::channel::{channel, Sender, Receiver};
use ::error::*;
use ::op::{Message, MessageType, OpCode, Query, UpdateMessage};
@ -29,15 +29,15 @@ use ::rr::rdata::NULL;
const QOS_MAX_RECEIVE_MSGS: usize = 100; // max number of messages to receive from the UDP socket
/// A reference to a Sender of bytes returned from the creation of a UdpClientStream or TcpClientStream
pub type StreamHandle = Sender<Vec<u8>>;
pub type StreamHandle = UnboundedSender<Vec<u8>>;
pub trait ClientStreamHandle {
fn send(&self, buffer: Vec<u8>) -> io::Result<()>;
fn send(&mut self, buffer: Vec<u8>) -> io::Result<()>;
}
impl ClientStreamHandle for StreamHandle {
fn send(&self, buffer: Vec<u8>) -> io::Result<()> {
Sender::send(self, buffer)
fn send(&mut self, buffer: Vec<u8>) -> io::Result<()> {
UnboundedSender::send(self, buffer).map_err(|_| io::Error::new(io::ErrorKind::Other, "unknown"))
}
}
@ -52,7 +52,7 @@ pub struct ClientFuture<S: Stream<Item=Vec<u8>, Error=io::Error>> {
timeout_duration: Duration,
// TODO genericize and remove this Box
stream_handle: Box<ClientStreamHandle>,
new_receiver: Peekable<StreamFuse<Receiver<(Message, Complete<ClientResult<Message>>)>>>,
new_receiver: Peekable<StreamFuse<UnboundedReceiver<(Message, Complete<ClientResult<Message>>)>>>,
active_requests: HashMap<u16, (Complete<ClientResult<Message>>, Timeout)>,
// TODO: Maybe make a typed version of ClientFuture for Updates?
signer: Option<Signer>,
@ -93,7 +93,7 @@ impl<S: Stream<Item=Vec<u8>, Error=io::Error> + 'static> ClientFuture<S> {
loop_handle: Handle,
timeout_duration: Duration,
signer: Option<Signer>) -> BasicClientHandle {
let (sender, rx) = channel(&loop_handle).expect("could not get channel!");
let (sender, rx) = unbounded();
let loop_handle_clone = loop_handle.clone();
loop_handle.spawn(
@ -192,8 +192,8 @@ impl<S: Stream<Item=Vec<u8>, Error=io::Error> + 'static> Future for ClientFuture
}
},
Ok(_) => None,
Err(e) => {
warn!("receiver was shutdown? {}", e);
Err(()) => {
warn!("receiver was shutdown?");
break
},
};
@ -247,8 +247,8 @@ impl<S: Stream<Item=Vec<u8>, Error=io::Error> + 'static> Future for ClientFuture
}
},
Ok(_) => break,
Err(e) => {
warn!("receiver was shutdown? {}", e);
Err(()) => {
warn!("receiver was shutdown?");
break
},
}
@ -307,24 +307,25 @@ impl<S: Stream<Item=Vec<u8>, Error=io::Error> + 'static> Future for ClientFuture
#[derive(Clone)]
#[must_use = "queries can only be sent through a ClientHandle"]
pub struct BasicClientHandle {
message_sender: Sender<(Message, Complete<ClientResult<Message>>)>,
message_sender: UnboundedSender<(Message, Complete<ClientResult<Message>>)>,
}
impl ClientHandle for BasicClientHandle {
fn send(&self, message: Message) -> Box<Future<Item=Message, Error=ClientError>> {
let (complete, oneshot) = futures::oneshot();
fn send(&mut self, message: Message) -> Box<Future<Item=Message, Error=ClientError>> {
let (complete, receiver) = oneshot::channel();
let message_sender: &mut _ = &mut self.message_sender;
let oneshot = match self.message_sender.send((message, complete)) {
Ok(()) => oneshot,
let receiver = match message_sender.send((message, complete)) {
Ok(()) => receiver,
Err(e) => {
let (complete, oneshot) = futures::oneshot();
let (complete, receiver) = oneshot::channel();
complete.complete(Err(e.into()));
oneshot
receiver
}
};
// conver the oneshot into a Box of a Future message and error.
Box::new(oneshot.map_err(|c| ClientError::from(c)).map(|result| result.into_future()).flatten())
Box::new(receiver.map_err(|c| ClientError::from(c)).map(|result| result.into_future()).flatten())
}
}
@ -338,7 +339,7 @@ pub trait ClientHandle: Clone {
/// * `message` - the fully constructed Message to send, note that most implementations of
/// will most likely be required to rewrite the QueryId, do no rely on that as
/// being stable.
fn send(&self, message: Message) -> Box<Future<Item=Message, Error=ClientError>>;
fn send(&mut self, message: Message) -> Box<Future<Item=Message, Error=ClientError>>;
/// A *classic* DNS query
///
@ -350,7 +351,7 @@ pub trait ClientHandle: Clone {
/// * `name` - the label to lookup
/// * `query_class` - most likely this should always be DNSClass::IN
/// * `query_type` - record type to lookup
fn query(&self, name: domain::Name, query_class: DNSClass, query_type: RecordType)
fn query(&mut self, name: domain::Name, query_class: DNSClass, query_type: RecordType)
-> Box<Future<Item=Message, Error=ClientError>> {
debug!("querying: {} {:?}", name, query_type);
@ -408,7 +409,7 @@ pub trait ClientHandle: Clone {
/// * `zone_origin` - the zone name to update, i.e. SOA name
///
/// The update must go to a zone authority (i.e. the server used in the ClientConnection)
fn create(&self,
fn create(&mut self,
record: Record,
zone_origin: domain::Name)
-> Box<Future<Item=Message, Error=ClientError>> {
@ -472,7 +473,7 @@ pub trait ClientHandle: Clone {
///
/// The update must go to a zone authority (i.e. the server used in the ClientConnection). If
/// the rrset does not exist and must_exist is false, then the RRSet will be created.
fn append(&self,
fn append(&mut self,
record: Record,
zone_origin: domain::Name,
must_exist: bool)
@ -547,7 +548,7 @@ pub trait ClientHandle: Clone {
/// * `zone_origin` - the zone name to update, i.e. SOA name
///
/// The update must go to a zone authority (i.e. the server used in the ClientConnection).
fn compare_and_swap(&self,
fn compare_and_swap(&mut self,
current: Record,
new: Record,
zone_origin: domain::Name)
@ -626,7 +627,7 @@ pub trait ClientHandle: Clone {
///
/// The update must go to a zone authority (i.e. the server used in the ClientConnection). If
/// the rrset does not exist and must_exist is false, then the RRSet will be deleted.
fn delete_by_rdata(&self,
fn delete_by_rdata(&mut self,
mut record: Record,
zone_origin: domain::Name)
-> Box<Future<Item=Message, Error=ClientError>> {
@ -693,7 +694,7 @@ pub trait ClientHandle: Clone {
///
/// The update must go to a zone authority (i.e. the server used in the ClientConnection). If
/// the rrset does not exist and must_exist is false, then the RRSet will be deleted.
fn delete_rrset(&self,
fn delete_rrset(&mut self,
mut record: Record,
zone_origin: domain::Name)
-> Box<Future<Item=Message, Error=ClientError>> {
@ -751,7 +752,7 @@ pub trait ClientHandle: Clone {
/// The update must go to a zone authority (i.e. the server used in the ClientConnection). This
/// operation attempts to delete all resource record sets the the specified name reguardless of
/// the record type.
fn delete_all(&self,
fn delete_all(&mut self,
name_of_records: domain::Name,
zone_origin: domain::Name,
dns_class: DNSClass)

View File

@ -36,23 +36,26 @@ impl<H> MemoizeClientHandle<H> where H: ClientHandle {
}
impl<H> ClientHandle for MemoizeClientHandle<H> where H: ClientHandle {
// TODO: should send be &mut so that we don't need RefCell here?
fn send(&self, message: Message) -> Box<Future<Item=Message, Error=ClientError>> {
fn send(&mut self, message: Message) -> Box<Future<Item=Message, Error=ClientError>> {
let query = message.get_queries().first().expect("no query!").clone();
if let Some(rc_future) = self.active_queries.borrow().get(&query) {
// TODO check TTLs?
// FIXME check TTLs?
return Box::new(rc_future.clone());
}
// TODO: it should be safe to loop here until the entry.or_insert_with returns...
// check if there are active queries
let mut map = self.active_queries.borrow_mut();
let rc_future = map.entry(query).or_insert_with(move ||{
rc_future(self.client.send(message))
});
{
let map = self.active_queries.borrow();
let request = map.get(&query);
if request.is_some() { return Box::new(request.unwrap().clone()) }
}
return Box::new(rc_future.clone());
let request = rc_future(self.client.send(message));
let mut map = self.active_queries.borrow_mut();
map.insert(query, request.clone());
return Box::new(request);
}
}
@ -69,7 +72,7 @@ mod test {
struct TestClient { i: Cell<u16> }
impl ClientHandle for TestClient {
fn send(&self, _: Message) -> Box<Future<Item=Message, Error=ClientError>> {
fn send(&mut self, _: Message) -> Box<Future<Item=Message, Error=ClientError>> {
let mut message = Message::new();
let i = self.i.get();
@ -82,7 +85,7 @@ mod test {
#[test]
fn test_memoized() {
let client = MemoizeClientHandle::new(TestClient{i: Cell::new(0)});
let mut client = MemoizeClientHandle::new(TestClient{i: Cell::new(0)});
let mut test1 = Message::new();
test1.add_query(Query::new().query_type(RecordType::A).clone());

View File

@ -28,7 +28,7 @@ impl<H> RetryClientHandle<H> where H: ClientHandle {
}
impl<H> ClientHandle for RetryClientHandle<H> where H: ClientHandle + 'static {
fn send(&self, message: Message) -> Box<Future<Item=Message, Error=ClientError>> {
fn send(&mut self, message: Message) -> Box<Future<Item=Message, Error=ClientError>> {
// need to clone here so that the retry can resend if necessary...
// obviously it would be nice to be lazy about this...
let future = self.client.send(message.clone());
@ -87,7 +87,7 @@ mod test {
struct TestClient { last_succeed: bool, retries: u16, attempts: Cell<u16> }
impl ClientHandle for TestClient {
fn send(&self, _: Message) -> Box<Future<Item=Message, Error=ClientError>> {
fn send(&mut self, _: Message) -> Box<Future<Item=Message, Error=ClientError>> {
let i = self.attempts.get();
if i > self.retries || self.retries - i == 0 {
@ -105,7 +105,7 @@ mod test {
#[test]
fn test_retry() {
let client = RetryClientHandle::new(TestClient{last_succeed: true, retries: 1, attempts: Cell::new(0)}, 2);
let mut client = RetryClientHandle::new(TestClient{last_succeed: true, retries: 1, attempts: Cell::new(0)}, 2);
let test1 = Message::new();
let result = client.send(test1).wait().ok().expect("should have succeeded");
assert_eq!(result.get_id(), 1); // this is checking the number of iterations the TestCient ran
@ -113,7 +113,7 @@ mod test {
#[test]
fn test_error() {
let client = RetryClientHandle::new(TestClient{last_succeed: false, retries: 1, attempts: Cell::new(0)}, 2);
let mut client = RetryClientHandle::new(TestClient{last_succeed: false, retries: 1, attempts: Cell::new(0)}, 2);
let test1 = Message::new();
assert!(client.send(test1).wait().is_err());

View File

@ -80,7 +80,7 @@ impl<H> SecureClientHandle<H> where H: ClientHandle + 'static {
}
impl<H> ClientHandle for SecureClientHandle<H> where H: ClientHandle + 'static {
fn send(&self, mut message: Message) -> Box<Future<Item=Message, Error=ClientError>> {
fn send(&mut self, mut message: Message) -> Box<Future<Item=Message, Error=ClientError>> {
// backstop, this might need to be configurable at some point
if self.request_depth > 20 {
return Box::new(failed(ClientErrorKind::Message("exceeded max validation depth").into()))
@ -342,7 +342,7 @@ fn verify_rrset<H>(client: SecureClientHandle<H>,
/// as a success. Otherwise, a query is sent to get the DS record, and the DNSKEY is validated
/// against the DS record.
fn verify_dnskey_rrset<H>(
client: SecureClientHandle<H>,
mut client: SecureClientHandle<H>,
rrset: Rrset)
-> Box<Future<Item=Rrset, Error=ClientError>>
where H: ClientHandle
@ -609,7 +609,7 @@ fn verify_default_rrset<H>(
)
.map(|sig| {
let rrset = rrset.clone();
let client = client.clone_with_context();
let mut client = client.clone_with_context();
client.query(sig.get_signer_name().clone(), rrset.record_class, RecordType::DNSKEY)
.and_then(move |message|

View File

@ -15,10 +15,9 @@
*/
use std::io::Error as IoError;
use std::sync::Arc;
use backtrace::Backtrace;
use futures::Canceled;
use futures::sync::mpsc::SendError;
use openssl::error::ErrorStack as SslErrorStack;
use ::op::ResponseCode;
@ -60,6 +59,11 @@ error_chain! {
// the same as `quick_error!`, but the `from()` and `cause()`
// syntax is not supported.
errors {
NoError {
description("no error specified")
display("no error specified")
}
Canceled(c: Canceled) {
description("future was canceled")
display("future was canceled: {:?}", c)
@ -125,9 +129,22 @@ error_chain! {
}
}
impl From<()> for Error {
fn from(_: ()) -> Self {
ErrorKind::NoError.into()
}
}
impl From<Canceled> for Error {
fn from(c: Canceled) -> Self {
Error(ErrorKind::Canceled(c), (None, Arc::new(Backtrace::new())))
ErrorKind::Canceled(c).into()
}
}
impl<T> From<SendError<T>> for Error {
fn from(e: SendError<T>) -> Self {
ErrorKind::Msg(format!("error sending to mpsc: {}", e)).into()
}
}

View File

@ -52,8 +52,8 @@ pub mod serialize;
use std::io;
use std::net::SocketAddr;
use futures::sync::mpsc::UnboundedSender;
use futures::stream::Stream;
use tokio_core::channel::Sender;
use op::Message;
use client::ClientStreamHandle;
@ -62,13 +62,13 @@ use client::ClientStreamHandle;
pub type BufStream = Stream<Item=(Vec<u8>, SocketAddr), Error=io::Error>;
/// A sender to which serialized DNS Messages can be sent
pub type BufStreamHandle = Sender<(Vec<u8>, SocketAddr)>;
pub type BufStreamHandle = UnboundedSender<(Vec<u8>, SocketAddr)>;
/// A stream of messsages
pub type MessageStream = Stream<Item=Message, Error=io::Error>;
/// A sender to which a Message can be sent
pub type MessageStreamHandle = Sender<Message>;
pub type MessageStreamHandle = UnboundedSender<Message>;
pub struct BufClientStreamHandle {
name_server: SocketAddr,
@ -76,8 +76,10 @@ pub struct BufClientStreamHandle {
}
impl ClientStreamHandle for BufClientStreamHandle {
fn send(&self, buffer: Vec<u8>) -> io::Result<()> {
self.sender.send((buffer, self.name_server))
fn send(&mut self, buffer: Vec<u8>) -> io::Result<()> {
let name_server: SocketAddr = self.name_server;
let sender: &mut _ = &mut self.sender;
sender.send((buffer, name_server)).map_err(|_| io::Error::new(io::ErrorKind::Other, "unknown"))
}
}

View File

@ -144,7 +144,7 @@ fn tcp_client_stream_test(server_addr: IpAddr) {
// the tests should run within 5 seconds... right?
// TODO: add timeout here, so that test never hangs...
// let timeout = Timeout::new(Duration::from_secs(5), &io_loop.handle());
let (stream, sender) = TcpClientStream::new(server_addr, io_loop.handle());
let (stream, mut sender) = TcpClientStream::new(server_addr, io_loop.handle());
let mut stream: TcpClientStream = io_loop.run(stream).ok().expect("run failed to get stream");

View File

@ -12,8 +12,8 @@ use std::io::{Read, Write};
use futures::{Async, Future, Poll};
use futures::stream::{Fuse, Peekable, Stream};
use futures::sync::mpsc::{unbounded, UnboundedReceiver};
use tokio_core::net::TcpStream as TokioTcpStream;
use tokio_core::channel::{channel, Receiver};
use tokio_core::reactor::{Handle};
use ::BufStreamHandle;
@ -32,7 +32,7 @@ enum ReadTcpState {
#[must_use = "futures do nothing unless polled"]
pub struct TcpStream {
socket: TokioTcpStream,
outbound_messages: Peekable<Fuse<Receiver<(Vec<u8>, SocketAddr)>>>,
outbound_messages: Peekable<Fuse<UnboundedReceiver<(Vec<u8>, SocketAddr)>>>,
send_state: Option<WriteTcpState>,
read_state: ReadTcpState,
}
@ -42,7 +42,7 @@ impl TcpStream {
/// new TcpClients such that each new client would have a random port (reduce chance of cache
/// poisoning)
pub fn new(name_server: SocketAddr, loop_handle: Handle) -> (Box<Future<Item=TcpStream, Error=io::Error>>, BufStreamHandle) {
let (message_sender, outbound_messages) = channel(&loop_handle).expect("somethings wrong with the event loop");
let (message_sender, outbound_messages) = unbounded();
let tcp = TokioTcpStream::connect(&name_server, &loop_handle);
// This set of futures collapses the next tcp socket into a stream which can be used for
@ -63,8 +63,8 @@ impl TcpStream {
/// Initializes a TcpStream with an existing tokio_core::net::TcpStream.
///
/// This is intended for use with a TcpListener and Incoming.
pub fn with_tcp_stream(stream: TokioTcpStream, loop_handle: Handle) -> (Self, BufStreamHandle) {
let (message_sender, outbound_messages) = channel(&loop_handle).expect("somethings wrong with the event loop");
pub fn with_tcp_stream(stream: TokioTcpStream) -> (Self, BufStreamHandle) {
let (message_sender, outbound_messages) = unbounded();
let stream = TcpStream {
socket: stream,
@ -137,7 +137,7 @@ impl Stream for TcpStream {
};
} else {
// then see if there is more to send
match try!(self.outbound_messages.poll()) {
match try!(self.outbound_messages.poll().map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))) {
// already handled above, here to make sure the poll() pops the next message
Async::Ready(Some((buffer, dst))) => {
// if there is no peer, this connection should die...
@ -334,7 +334,7 @@ fn tcp_client_stream_test(server_addr: IpAddr) {
// the tests should run within 5 seconds... right?
// TODO: add timeout here, so that test never hangs...
// let timeout = Timeout::new(Duration::from_secs(5), &io_loop.handle());
let (stream, sender) = TcpStream::new(server_addr, io_loop.handle());
let (stream, mut sender) = TcpStream::new(server_addr, io_loop.handle());
let mut stream: TcpStream = io_loop.run(stream).ok().expect("run failed to get stream");

View File

@ -131,7 +131,7 @@ fn udp_client_stream_test(server_addr: IpAddr) {
// the tests should run within 5 seconds... right?
// TODO: add timeout here, so that test never hangs...
// let timeout = Timeout::new(Duration::from_secs(5), &io_loop.handle());
let (stream, sender) = UdpClientStream::new(server_addr, io_loop.handle());
let (stream, mut sender) = UdpClientStream::new(server_addr, io_loop.handle());
let mut stream: UdpClientStream = io_loop.run(stream).ok().unwrap();
for _ in 0..send_recv_times {

View File

@ -11,11 +11,11 @@ use std::io;
use futures::{Async, Future, Poll};
use futures::stream::{Fuse, Peekable, Stream};
use futures::sync::mpsc::{unbounded, UnboundedReceiver};
use futures::task::park;
use rand::Rng;
use rand;
use tokio_core;
use tokio_core::channel::{channel, Receiver};
use tokio_core::reactor::{Handle};
use ::BufStreamHandle;
@ -28,7 +28,7 @@ lazy_static!{
#[must_use = "futures do nothing unless polled"]
pub struct UdpStream {
socket: tokio_core::net::UdpSocket,
outbound_messages: Peekable<Fuse<Receiver<(Vec<u8>, SocketAddr)>>>,
outbound_messages: Peekable<Fuse<UnboundedReceiver<(Vec<u8>, SocketAddr)>>>,
}
impl UdpStream {
@ -47,7 +47,7 @@ impl UdpStream {
/// a tuple of a Future Stream which will handle sending and receiving messsages, and a
/// handle which can be used to send messages into the stream.
pub fn new(name_server: SocketAddr, loop_handle: Handle) -> (Box<Future<Item=UdpStream, Error=io::Error>>, BufStreamHandle) {
let (message_sender, outbound_messages) = channel(&loop_handle).expect("somethings wrong with the event loop");
let (message_sender, outbound_messages) = unbounded();
// TODO: allow the bind address to be specified...
// constructs a future for getting the next randomly bound port to a UdpSocket
@ -82,7 +82,7 @@ impl UdpStream {
/// a tuple of a Future Stream which will handle sending and receiving messsages, and a
/// handle which can be used to send messages into the stream.
pub fn with_bound(socket: std::net::UdpSocket, loop_handle: Handle) -> (Self, BufStreamHandle) {
let (message_sender, outbound_messages) = channel(&loop_handle).expect("somethings wrong with the event loop");
let (message_sender, outbound_messages) = unbounded();
// TODO: consider making this return a Result...
let socket = tokio_core::net::UdpSocket::from_socket(socket, &loop_handle).expect("could not register socket to loop");
@ -115,7 +115,7 @@ impl Stream for UdpStream {
// makes this self throttling.
loop {
// first try to send
match try!(self.outbound_messages.peek()) {
match try!(self.outbound_messages.peek().map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))) {
Async::Ready(Some(&(ref buffer, addr))) => {
match self.socket.poll_write() {
Async::NotReady => {
@ -132,7 +132,7 @@ impl Stream for UdpStream {
}
// now pop the request and check if we should break or continue.
match try!(self.outbound_messages.poll()) {
match try!(self.outbound_messages.poll().map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))) {
// already handled above, here to make sure the poll() pops the next message
Async::Ready(Some(_)) => (),
// now we get to drop through to the receives...
@ -261,7 +261,7 @@ fn udp_stream_test(server_addr: std::net::IpAddr) {
};
let socket = std::net::UdpSocket::bind(client_addr).expect("could not create socket"); // some random address...
let (mut stream, sender) = UdpStream::with_bound(socket, io_loop.handle());
let (mut stream, mut sender) = UdpStream::with_bound(socket, io_loop.handle());
//let mut stream: UdpStream = io_loop.run(stream).ok().unwrap();
for _ in 0..send_recv_times {

View File

@ -13,10 +13,10 @@ use std::cmp::Ordering;
use chrono::Duration;
use futures::{Async, Future, finished, Poll};
use futures::stream::{Fuse, Stream};
use futures::sync::mpsc::{unbounded, UnboundedReceiver};
use futures::task::park;
use openssl::crypto::rsa::RSA;
use tokio_core::reactor::{Core, Handle};
use tokio_core::channel::{channel, Receiver};
use tokio_core::reactor::Core;
use trust_dns::client::{ClientFuture, BasicClientHandle, ClientHandle, ClientStreamHandle};
use trust_dns::error::*;
@ -40,11 +40,11 @@ fn test_query_nonet() {
catalog.upsert(authority.get_origin().clone(), authority);
let mut io_loop = Core::new().unwrap();
let (stream, sender) = TestClientStream::new(catalog, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let (stream, sender) = TestClientStream::new(catalog);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
}
#[test]
@ -56,11 +56,11 @@ fn test_query_udp_ipv4() {
let mut io_loop = Core::new().unwrap();
let addr: SocketAddr = ("8.8.8.8",53).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = UdpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
// TODO: timeouts on these requests so that the test doesn't hang
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
}
#[test]
@ -72,11 +72,11 @@ fn test_query_udp_ipv6() {
let mut io_loop = Core::new().unwrap();
let addr: SocketAddr = ("2001:4860:4860::8888",53).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = UdpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
// TODO: timeouts on these requests so that the test doesn't hang
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
}
#[test]
@ -88,11 +88,11 @@ fn test_query_tcp_ipv4() {
let mut io_loop = Core::new().unwrap();
let addr: SocketAddr = ("8.8.8.8",53).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = TcpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
// TODO: timeouts on these requests so that the test doesn't hang
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
}
#[test]
@ -104,15 +104,15 @@ fn test_query_tcp_ipv6() {
let mut io_loop = Core::new().unwrap();
let addr: SocketAddr = ("2001:4860:4860::8888",53).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = TcpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
// TODO: timeouts on these requests so that the test doesn't hang
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
io_loop.run(test_query(&mut client)).unwrap();
}
#[cfg(test)]
fn test_query(client: &BasicClientHandle) -> Box<Future<Item=(), Error=()>> {
fn test_query(client: &mut BasicClientHandle) -> Box<Future<Item=(), Error=()>> {
let name = domain::Name::with_labels(vec!["WWW".to_string(), "example".to_string(), "com".to_string()]);
Box::new(client.query(name.clone(), DNSClass::IN, RecordType::A)
@ -164,7 +164,7 @@ fn create_sig0_ready_client(io_loop: &Core) -> (BasicClientHandle, domain::Name)
let mut catalog = Catalog::new();
catalog.upsert(authority.get_origin().clone(), authority);
let (stream, sender) = TestClientStream::new(catalog, io_loop.handle());
let (stream, sender) = TestClientStream::new(catalog);
let client = ClientFuture::new(stream, sender, io_loop.handle(), Some(signer));
(client, origin)
@ -173,7 +173,7 @@ fn create_sig0_ready_client(io_loop: &Core) -> (BasicClientHandle, domain::Name)
#[test]
fn test_create() {
let mut io_loop = Core::new().unwrap();
let (client, origin) = create_sig0_ready_client(&io_loop);
let (mut client, origin) = create_sig0_ready_client(&io_loop);
// create a record
let mut record = Record::with(domain::Name::with_labels(vec!["new".to_string(), "example".to_string(), "com".to_string()]),
@ -205,7 +205,7 @@ fn create_sig0_ready_client(io_loop: &Core) -> (BasicClientHandle, domain::Name)
#[test]
fn test_append() {
let mut io_loop = Core::new().unwrap();
let (client, origin) = create_sig0_ready_client(&io_loop);
let (mut client, origin) = create_sig0_ready_client(&io_loop);
// append a record
let mut record = Record::with(domain::Name::with_labels(vec!["new".to_string(), "example".to_string(), "com".to_string()]),
@ -253,7 +253,7 @@ fn create_sig0_ready_client(io_loop: &Core) -> (BasicClientHandle, domain::Name)
#[test]
fn test_compare_and_swap() {
let mut io_loop = Core::new().unwrap();
let (client, origin) = create_sig0_ready_client(&io_loop);
let (mut client, origin) = create_sig0_ready_client(&io_loop);
// create a record
let mut record = Record::with(domain::Name::with_labels(vec!["new".to_string(), "example".to_string(), "com".to_string()]),
@ -292,7 +292,7 @@ fn create_sig0_ready_client(io_loop: &Core) -> (BasicClientHandle, domain::Name)
#[test]
fn test_delete_by_rdata() {
let mut io_loop = Core::new().unwrap();
let (client, origin) = create_sig0_ready_client(&io_loop);
let (mut client, origin) = create_sig0_ready_client(&io_loop);
// append a record
let mut record = Record::with(domain::Name::with_labels(vec!["new".to_string(), "example".to_string(), "com".to_string()]),
@ -326,7 +326,7 @@ fn create_sig0_ready_client(io_loop: &Core) -> (BasicClientHandle, domain::Name)
#[test]
fn test_delete_rrset() {
let mut io_loop = Core::new().unwrap();
let (client, origin) = create_sig0_ready_client(&io_loop);
let (mut client, origin) = create_sig0_ready_client(&io_loop);
// append a record
let mut record = Record::with(domain::Name::with_labels(vec!["new".to_string(), "example".to_string(), "com".to_string()]),
@ -359,7 +359,7 @@ fn create_sig0_ready_client(io_loop: &Core) -> (BasicClientHandle, domain::Name)
#[test]
fn test_delete_all() {
let mut io_loop = Core::new().unwrap();
let (client, origin) = create_sig0_ready_client(&io_loop);
let (mut client, origin) = create_sig0_ready_client(&io_loop);
// append a record
let mut record = Record::with(domain::Name::with_labels(vec!["new".to_string(), "example".to_string(), "com".to_string()]),
@ -398,12 +398,12 @@ fn create_sig0_ready_client(io_loop: &Core) -> (BasicClientHandle, domain::Name)
// is no one listening to messages and shutdown...
#[allow(dead_code)]
pub struct NeverReturnsClientStream {
outbound_messages: Fuse<Receiver<Vec<u8>>>,
outbound_messages: Fuse<UnboundedReceiver<Vec<u8>>>,
}
impl NeverReturnsClientStream {
pub fn new(loop_handle: Handle) -> (Box<Future<Item=Self, Error=io::Error>>, Box<ClientStreamHandle>) {
let (message_sender, outbound_messages) = channel(&loop_handle).expect("somethings wrong with the event loop");
pub fn new() -> (Box<Future<Item=Self, Error=io::Error>>, Box<ClientStreamHandle>) {
let (message_sender, outbound_messages) = unbounded();
let stream: Box<Future<Item=NeverReturnsClientStream, Error=io::Error>> = Box::new(finished(
NeverReturnsClientStream {
@ -439,8 +439,8 @@ fn create_sig0_ready_client(io_loop: &Core) -> (BasicClientHandle, domain::Name)
catalog.upsert(authority.get_origin().clone(), authority);
let mut io_loop = Core::new().unwrap();
let (stream, sender) = NeverReturnsClientStream::new(io_loop.handle());
let client = ClientFuture::with_timeout(stream, sender, io_loop.handle(),
let (stream, sender) = NeverReturnsClientStream::new();
let mut client = ClientFuture::with_timeout(stream, sender, io_loop.handle(),
std::time::Duration::from_millis(1), None);
let name = domain::Name::with_labels(vec!["www".to_string(), "example".to_string(), "com".to_string()]);

View File

@ -3,9 +3,8 @@ use std::io;
use futures::{Async, Future, finished, Poll};
use futures::stream::{Fuse, Stream};
use futures::sync::mpsc::{unbounded, UnboundedReceiver};
use futures::task::park;
use tokio_core::reactor::*;
use tokio_core::channel::*;
use trust_dns::client::ClientStreamHandle;
use trust_dns::op::*;
@ -15,12 +14,12 @@ use trust_dns_server::authority::Catalog;
pub struct TestClientStream {
catalog: Catalog,
outbound_messages: Fuse<Receiver<Vec<u8>>>,
outbound_messages: Fuse<UnboundedReceiver<Vec<u8>>>,
}
impl TestClientStream {
pub fn new(catalog: Catalog, loop_handle: Handle) -> (Box<Future<Item=Self, Error=io::Error>>, Box<ClientStreamHandle>) {
let (message_sender, outbound_messages) = channel(&loop_handle).expect("somethings wrong with the event loop");
pub fn new(catalog: Catalog) -> (Box<Future<Item=Self, Error=io::Error>>, Box<ClientStreamHandle>) {
let (message_sender, outbound_messages) = unbounded();
let stream: Box<Future<Item=TestClientStream, Error=io::Error>> = Box::new(finished(
TestClientStream { catalog: catalog, outbound_messages: outbound_messages.fuse() }
@ -35,7 +34,7 @@ impl Stream for TestClientStream {
type Error = io::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
match try!(self.outbound_messages.poll()) {
match try!(self.outbound_messages.poll().map_err(|_| io::Error::new(io::ErrorKind::Interrupted, "Server stopping due to interruption"))) {
// already handled above, here to make sure the poll() pops the next message
Async::Ready(Some(bytes)) => {
let mut decoder = BinDecoder::new(&bytes);

View File

@ -40,7 +40,7 @@ fn test_secure_query_example_tcp() {
with_tcp(test_secure_query_example);
}
fn test_secure_query_example<H>(client: SecureClientHandle<H>, mut io_loop: Core)
fn test_secure_query_example<H>(mut client: SecureClientHandle<H>, mut io_loop: Core)
where H: ClientHandle + 'static {
let name = domain::Name::with_labels(vec!["www".to_string(), "example".to_string(), "com".to_string()]);
let response = io_loop.run(client.query(name.clone(), DNSClass::IN, RecordType::A)).expect("query failed");
@ -78,7 +78,7 @@ fn test_nsec_query_example_tcp() {
with_tcp(test_nsec_query_example);
}
fn test_nsec_query_example<H>(client: SecureClientHandle<H>, mut io_loop: Core)
fn test_nsec_query_example<H>(mut client: SecureClientHandle<H>, mut io_loop: Core)
where H: ClientHandle + 'static {
let name = domain::Name::with_labels(vec!["none".to_string(), "example".to_string(), "com".to_string()]);
@ -104,7 +104,7 @@ fn test_nsec_query_type_tcp() {
with_tcp(test_nsec_query_type);
}
fn test_nsec_query_type<H>(client: SecureClientHandle<H>, mut io_loop: Core)
fn test_nsec_query_type<H>(mut client: SecureClientHandle<H>, mut io_loop: Core)
where H: ClientHandle + 'static {
let name = domain::Name::with_labels(vec!["www".to_string(), "example".to_string(), "com".to_string()]);
@ -132,7 +132,7 @@ fn test_dnssec_rollernet_td_tcp_mixed_case() {
with_tcp(dnssec_rollernet_td_mixed_case_test);
}
fn dnssec_rollernet_td_test<H>(client: SecureClientHandle<H>, mut io_loop: Core)
fn dnssec_rollernet_td_test<H>(mut client: SecureClientHandle<H>, mut io_loop: Core)
where H: ClientHandle + 'static {
let name = domain::Name::parse("rollernet.us.", None).unwrap();
@ -144,7 +144,7 @@ where H: ClientHandle + 'static {
assert!(response.get_answers().is_empty());
}
fn dnssec_rollernet_td_mixed_case_test<H>(client: SecureClientHandle<H>, mut io_loop: Core)
fn dnssec_rollernet_td_mixed_case_test<H>(mut client: SecureClientHandle<H>, mut io_loop: Core)
where H: ClientHandle + 'static {
let name = domain::Name::parse("RollErnet.Us.", None).unwrap();
@ -183,7 +183,7 @@ fn with_nonet<F>(test: F) where F: Fn(SecureClientHandle<MemoizeClientHandle<Bas
trust_anchor.insert_trust_anchor(public_key);
let io_loop = Core::new().unwrap();
let (stream, sender) = TestClientStream::new(catalog, io_loop.handle());
let (stream, sender) = TestClientStream::new(catalog);
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let client = MemoizeClientHandle::new(client);
let secure_client = SecureClientHandle::with_trust_anchor(client, trust_anchor);

View File

@ -82,7 +82,7 @@ pub struct ResponseHandle {
impl ResponseHandle {
/// Serializes and sends a message to to the wrapped handle
pub fn send(&self, response: Message) -> io::Result<()> {
pub fn send(&mut self, response: Message) -> io::Result<()> {
debug!("sending message: {}", response.get_id());
let mut buffer = Vec::with_capacity(512);
let encode_result = {
@ -92,6 +92,6 @@ impl ResponseHandle {
try!(encode_result.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("error encoding message: {}", e))));
self.stream_handle.send((buffer, self.dst))
self.stream_handle.send((buffer, self.dst)).map_err(|_| io::Error::new(io::ErrorKind::Other, "unknown"))
}
}

View File

@ -78,7 +78,7 @@ impl ServerFuture {
.for_each(move |(tcp_stream, src_addr)| {
debug!("accepted request from: {}", src_addr);
// take the created stream...
let (buf_stream, stream_handle) = TcpStream::with_tcp_stream(tcp_stream, handle.clone());
let (buf_stream, stream_handle) = TcpStream::with_tcp_stream(tcp_stream);
let timeout_stream = try!(TimeoutStream::new(buf_stream, timeout, handle.clone()));
let request_stream = RequestStream::new(timeout_stream, stream_handle);
let catalog = catalog.clone();
@ -107,7 +107,7 @@ impl ServerFuture {
Err(io::Error::new(io::ErrorKind::Interrupted, "Server stopping due to interruption"))
}
fn handle_request(request: Request, response_handle: ResponseHandle, catalog: Arc<Catalog>) -> io::Result<()> {
fn handle_request(request: Request, mut response_handle: ResponseHandle, catalog: Arc<Catalog>) -> io::Result<()> {
let response = catalog.handle_request(&request.message);
response_handle.send(response)
}

View File

@ -103,7 +103,7 @@ fn named_test_harness<F, R>(toml: &str, test: F) where F: FnOnce(u16) -> R + Unw
// This only validates that a query to the server works, it shouldn't be used for more than this.
// i.e. more complex checks live with the clients and authorities to validate deeper funcionality
fn query(io_loop: &mut Core, client: BasicClientHandle) -> bool {
fn query(io_loop: &mut Core, client: &mut BasicClientHandle) -> bool {
let name = domain::Name::with_labels(vec!["www".to_string(), "example".to_string(), "com".to_string()]);
println!("sending request");
@ -132,16 +132,16 @@ fn test_example_toml_startup() {
let mut io_loop = Core::new().unwrap();
let addr: SocketAddr = ("127.0.0.1", port).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = TcpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
assert!(query(&mut io_loop, client));
assert!(query(&mut io_loop, &mut client));
// just tests that multiple queries work
let addr: SocketAddr = ("127.0.0.1", port).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = TcpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
assert!(query(&mut io_loop, client));
assert!(query(&mut io_loop, &mut client));
})
}
@ -151,17 +151,17 @@ fn test_ipv4_only_toml_startup() {
let mut io_loop = Core::new().unwrap();
let addr: SocketAddr = ("127.0.0.1", port).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = TcpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
// ipv4 should succeed
assert!(query(&mut io_loop, client));
assert!(query(&mut io_loop, &mut client));
let addr: SocketAddr = ("::1", port).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = TcpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
// ipv6 should fail
assert!(!query(&mut io_loop, client));
assert!(!query(&mut io_loop, &mut client));
})
}
@ -198,17 +198,17 @@ fn test_ipv4_and_ipv6_toml_startup() {
let mut io_loop = Core::new().unwrap();
let addr: SocketAddr = ("127.0.0.1", port).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = TcpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
// ipv4 should succeed
assert!(query(&mut io_loop, client));
assert!(query(&mut io_loop, &mut client));
let addr: SocketAddr = ("::1", port).to_socket_addrs().unwrap().next().unwrap();
let (stream, sender) = TcpClientStream::new(addr, io_loop.handle());
let client = ClientFuture::new(stream, sender, io_loop.handle(), None);
let mut client = ClientFuture::new(stream, sender, io_loop.handle(), None);
// ipv6 should succeed
assert!(query(&mut io_loop, client));
assert!(query(&mut io_loop, &mut client));
assert!(true);
})