diff --git a/common/src/serde/deserializer.rs b/common/src/serde/deserializer.rs index 909bd03..a28c9e3 100644 --- a/common/src/serde/deserializer.rs +++ b/common/src/serde/deserializer.rs @@ -14,6 +14,10 @@ impl<'a> Deserializer<'a> { } } + pub fn pos(&self) -> usize { + self.offset + } + pub fn read_bool(&mut self) -> bool { self.read_u8() != 0 } diff --git a/remote_send/src/mqtt/misc.rs b/remote_send/src/mqtt/misc.rs index 12c88a5..706376c 100644 --- a/remote_send/src/mqtt/misc.rs +++ b/remote_send/src/mqtt/misc.rs @@ -1,4 +1,37 @@ -use std::sync::atomic::{AtomicU64, Ordering}; +use std::{ + borrow::Cow, + sync::atomic::{AtomicU64, Ordering}, +}; + +use common::serde::{Deserializer, Serializer}; + +pub trait MqttDeserialize<'a> { + fn read_string(&mut self) -> Cow<'a, str>; +} + +pub trait MqttSerializer { + fn write_string(&mut self, data: &str); +} + +impl<'a> MqttDeserialize<'a> for Deserializer<'a> { + fn read_string(&mut self) -> Cow<'a, str> { + let len = self.read_u16(); + let buf = self.read_bytes(len as usize); + String::from_utf8_lossy(buf) + } +} + +impl MqttSerializer for T +where + T: Serializer, +{ + fn write_string(&mut self, data: &str) { + assert!(data.len() <= u16::MAX as usize); + let len = data.len() as u16; + self.write_u16(len); + self.write_bytes(data.as_bytes()); + } +} pub fn next_id() -> u64 { static NEXT_ID: AtomicU64 = AtomicU64::new(0); diff --git a/remote_send/src/mqtt/mod.rs b/remote_send/src/mqtt/mod.rs index 7c63af6..e1d92a0 100644 --- a/remote_send/src/mqtt/mod.rs +++ b/remote_send/src/mqtt/mod.rs @@ -1,19 +1,20 @@ use std::{ - borrow::Cow, collections::HashMap, - io::{Read, Write}, net::{TcpListener, TcpStream}, sync::Arc, thread, }; use anyhow::Result; -use common::serde::Deserializer; use misc::next_id; use packets::{ connect::ConnectPacket, connect_ack::{ConnectAckFlags, ConnectAckPacket, ConnectReturnCode}, + publish::{PublishFlags, PublishPacket}, + publish_ack::PublishAckPacket, subscribe::SubscribePacket, + subscribe_ack::{SubscribeAckPacket, SubscribeReturnCode}, + Packet, }; use parking_lot::{lock_api::MutexGuard, MappedMutexGuard, Mutex}; @@ -86,7 +87,7 @@ fn handle_client(server: Arc, client_id: u64, mut stream: TcpStream) match packet.packet_type { ConnectPacket::PACKET_TYPE => { - let packet = ConnectPacket::from_bytes(&packet.remaining_bytes)?; + let packet = ConnectPacket::from_packet(&packet)?; println!("Connect packet: {:?}", packet); ConnectAckPacket { @@ -97,13 +98,40 @@ fn handle_client(server: Arc, client_id: u64, mut stream: TcpStream) .write(&mut stream)?; } SubscribePacket::PACKET_TYPE => { - let packet = SubscribePacket::from_bytes(&packet.remaining_bytes)?; + let packet = SubscribePacket::from_packet(&packet)?; println!("Subscribe packet: {:?}", packet); + + let return_codes = packet + .filters + .iter() + .map(|(_topic, qos)| SubscribeReturnCode::Success(*qos)) + .collect(); + server .get_client_mut(client_id) .subscriptions .extend(packet.filters.into_iter().map(|x| x.0)); dbg!(&server.get_client_mut(client_id)); + + SubscribeAckPacket { + packet_id: packet.packet_id, + return_codes, + } + .to_packet() + .write(&mut stream)?; + } + PublishPacket::PACKET_TYPE => { + let packet = PublishPacket::from_packet(&packet)?; + println!("Publish packet: {:?}", packet); + println!("{}", String::from_utf8_lossy(&packet.data)); + + if packet.flags.contains(PublishFlags::QOS1) { + PublishAckPacket { + packet_id: packet.packet_id.unwrap(), + } + .to_packet() + .write(&mut stream)?; + } } 0x0E => { println!("Client disconnect: {client_id}"); @@ -116,73 +144,3 @@ fn handle_client(server: Arc, client_id: u64, mut stream: TcpStream) Ok(()) } - -pub struct Packet { - packet_type: u8, - flags: u8, - remaining_length: u32, - remaining_bytes: Vec, -} - -impl Packet { - fn write(&self, stream: &mut Stream) -> Result<()> { - let mut bytes = vec![self.packet_type << 4 | self.flags]; - let mut remaining_length = self.remaining_length; - loop { - let mut byte = (remaining_length % 128) as u8; - remaining_length /= 128; - if remaining_length > 0 { - byte |= 0x80; - } - bytes.push(byte); - if remaining_length == 0 { - break; - } - } - bytes.extend(self.remaining_bytes.iter()); - - stream.write_all(&bytes)?; - Ok(()) - } - - fn read(stream: &mut Stream) -> Result { - let mut header = [0; 2]; - stream.read_exact(&mut header)?; - - let (packet_type, flags) = (header[0] >> 4, header[0] & 0xF); - let mut multiplier = 1; - let mut remaining_length = 0; - let mut pos = 1; - loop { - let byte = header[pos]; - remaining_length += (byte & 0x7F) as u32 * multiplier; - multiplier *= 128; - pos += 1; - if byte & 0x80 == 0 { - break; - } - } - - let mut remaining_bytes = vec![0; remaining_length as usize]; - stream.read_exact(&mut remaining_bytes)?; - - Ok(Self { - packet_type, - flags, - remaining_length, - remaining_bytes, - }) - } -} - -trait MqttDeserialize<'a> { - fn read_string(&mut self) -> Cow<'a, str>; -} - -impl<'a> MqttDeserialize<'a> for Deserializer<'a> { - fn read_string(&mut self) -> Cow<'a, str> { - let len = self.read_u16(); - let buf = self.read_bytes(len as usize); - String::from_utf8_lossy(buf) - } -} diff --git a/remote_send/src/mqtt/packets/connect.rs b/remote_send/src/mqtt/packets/connect.rs index 5bfddbc..838a5f2 100644 --- a/remote_send/src/mqtt/packets/connect.rs +++ b/remote_send/src/mqtt/packets/connect.rs @@ -3,7 +3,9 @@ use bitflags::bitflags; use common::serde::Deserializer; -use crate::mqtt::MqttDeserialize; +use crate::mqtt::misc::MqttDeserialize; + +use super::Packet; #[derive(Debug)] pub struct ConnectPacket { @@ -35,8 +37,8 @@ bitflags! { impl ConnectPacket { pub const PACKET_TYPE: u8 = 0x01; - pub fn from_bytes(bytes: &[u8]) -> Result { - let mut des = Deserializer::new(bytes); + pub fn from_packet(packet: &Packet) -> Result { + let mut des = Deserializer::new(&packet.remaining_bytes); let protocol_name = des.read_string().into_owned(); let protocol_level = des.read_u8(); diff --git a/remote_send/src/mqtt/packets/mod.rs b/remote_send/src/mqtt/packets/mod.rs index d310ea6..0f6a05d 100644 --- a/remote_send/src/mqtt/packets/mod.rs +++ b/remote_send/src/mqtt/packets/mod.rs @@ -1,3 +1,76 @@ +use std::io::{Read, Write}; + +use anyhow::Result; + pub mod connect; pub mod connect_ack; +pub mod publish; +pub mod publish_ack; pub mod subscribe; +pub mod subscribe_ack; + +pub struct Packet { + pub packet_type: u8, + pub flags: u8, + pub remaining_length: u32, + pub remaining_bytes: Vec, +} + +#[derive(Debug, Clone, Copy)] +pub struct QoS(pub u8); + +impl Packet { + pub fn write(&self, stream: &mut Stream) -> Result<()> { + let mut bytes = vec![self.packet_type << 4 | self.flags]; + let mut remaining_length = self.remaining_length; + loop { + let mut byte = (remaining_length % 128) as u8; + remaining_length /= 128; + if remaining_length > 0 { + byte |= 0x80; + } + bytes.push(byte); + if remaining_length == 0 { + break; + } + } + bytes.extend(self.remaining_bytes.iter()); + + stream.write_all(&bytes)?; + Ok(()) + } + + pub fn read(stream: &mut Stream) -> Result { + let mut header = vec![0; 2]; + stream.read_exact(&mut header)?; + + let (packet_type, flags) = (header[0] >> 4, header[0] & 0xF); + let mut multiplier = 1; + let mut remaining_length = 0; + let mut pos = 1; + loop { + if pos == header.len() { + header.resize(header.len() + 1, 0); + stream.read_exact(&mut header[pos..])?; + } + + let byte = header[pos]; + remaining_length += (byte & 0x7F) as u32 * multiplier; + multiplier *= 128; + pos += 1; + if byte & 0x80 == 0 { + break; + } + } + + let mut remaining_bytes = vec![0; remaining_length as usize]; + stream.read_exact(&mut remaining_bytes)?; + + Ok(Self { + packet_type, + flags, + remaining_length, + remaining_bytes, + }) + } +} diff --git a/remote_send/src/mqtt/packets/publish.rs b/remote_send/src/mqtt/packets/publish.rs new file mode 100644 index 0000000..dcf3622 --- /dev/null +++ b/remote_send/src/mqtt/packets/publish.rs @@ -0,0 +1,65 @@ +use anyhow::Result; +use bitflags::bitflags; + +use common::serde::{Deserializer, DynamicSerializer, Serializer}; + +use crate::mqtt::misc::{MqttDeserialize, MqttSerializer}; + +use super::Packet; + +#[derive(Debug)] +pub struct PublishPacket { + pub flags: PublishFlags, + pub topic: String, + pub packet_id: Option, + pub data: Vec, +} + +bitflags! { + #[derive(Debug)] + pub struct PublishFlags: u8 { + const DUP = 0b1000; + const QOS1 = 0b0010; + const QOS2 = 0b0100; + const RETAIN = 0b0001; + } +} + +impl PublishPacket { + pub const PACKET_TYPE: u8 = 0x03; + + pub fn from_packet(packet: &Packet) -> Result { + assert_eq!(packet.packet_type, Self::PACKET_TYPE); + let mut des = Deserializer::new(&packet.remaining_bytes); + + let flags = PublishFlags::from_bits(packet.flags).unwrap(); + let topic = des.read_string().into_owned(); + + let packet_id = (flags.contains(PublishFlags::QOS1) || flags.contains(PublishFlags::QOS2)) + .then(|| des.read_u16()); + let data = des + .read_bytes(packet.remaining_length as usize - des.pos()) + .to_vec(); + + Ok(Self { + flags, + topic, + packet_id, + data, + }) + } + + pub fn to_packet(&self) -> Packet { + let mut ser = DynamicSerializer::new(); + ser.write_string(&self.topic); + ser.write_bytes(&self.data); + + let data = ser.into_inner(); + Packet { + packet_type: Self::PACKET_TYPE, + flags: self.flags.bits(), + remaining_length: data.len() as u32, + remaining_bytes: data, + } + } +} diff --git a/remote_send/src/mqtt/packets/publish_ack.rs b/remote_send/src/mqtt/packets/publish_ack.rs new file mode 100644 index 0000000..0ff5dd4 --- /dev/null +++ b/remote_send/src/mqtt/packets/publish_ack.rs @@ -0,0 +1,24 @@ +use common::serde::{DynamicSerializer, Serializer}; + +use super::Packet; + +pub struct PublishAckPacket { + pub packet_id: u16, +} + +impl PublishAckPacket { + pub const PACKET_TYPE: u8 = 0x04; + + pub fn to_packet(&self) -> Packet { + let mut ser = DynamicSerializer::new(); + ser.write_u16(self.packet_id); + + let data = ser.into_inner(); + Packet { + packet_type: Self::PACKET_TYPE, + flags: 0, + remaining_length: data.len() as u32, + remaining_bytes: data, + } + } +} diff --git a/remote_send/src/mqtt/packets/subscribe.rs b/remote_send/src/mqtt/packets/subscribe.rs index 2e22155..a67f0e3 100644 --- a/remote_send/src/mqtt/packets/subscribe.rs +++ b/remote_send/src/mqtt/packets/subscribe.rs @@ -2,7 +2,8 @@ use anyhow::Result; use common::serde::Deserializer; -use crate::mqtt::MqttDeserialize; +use super::{Packet, QoS}; +use crate::mqtt::misc::MqttDeserialize; #[derive(Debug)] pub struct SubscribePacket { @@ -10,14 +11,11 @@ pub struct SubscribePacket { pub filters: Vec<(String, QoS)>, } -#[derive(Debug)] -pub struct QoS(pub u8); - impl SubscribePacket { pub const PACKET_TYPE: u8 = 0x08; - pub fn from_bytes(bytes: &[u8]) -> Result { - let mut des = Deserializer::new(bytes); + pub fn from_packet(packet: &Packet) -> Result { + let mut des = Deserializer::new(&packet.remaining_bytes); let packet_id = des.read_u16(); let mut filters = Vec::new(); diff --git a/remote_send/src/mqtt/packets/subscribe_ack.rs b/remote_send/src/mqtt/packets/subscribe_ack.rs new file mode 100644 index 0000000..c41954b --- /dev/null +++ b/remote_send/src/mqtt/packets/subscribe_ack.rs @@ -0,0 +1,38 @@ +use common::serde::{DynamicSerializer, Serializer}; + +use crate::mqtt::Packet; + +use super::QoS; + +pub struct SubscribeAckPacket { + pub packet_id: u16, + pub return_codes: Vec, +} + +pub enum SubscribeReturnCode { + Success(QoS), + Failure, +} + +impl SubscribeAckPacket { + pub const PACKET_TYPE: u8 = 0x09; + + pub fn to_packet(&self) -> Packet { + let mut ser = DynamicSerializer::new(); + ser.write_u16(self.packet_id); + for return_code in &self.return_codes { + match return_code { + SubscribeReturnCode::Failure => ser.write_u8(0x80), + SubscribeReturnCode::Success(qos) => ser.write_u8(qos.0), + } + } + + let data = ser.into_inner(); + Packet { + packet_type: Self::PACKET_TYPE, + flags: 0, + remaining_length: data.len() as u32, + remaining_bytes: data, + } + } +}