From 639498c78d4c7d80859484a8bc6fffd28dfe455f Mon Sep 17 00:00:00 2001 From: Connor Slade Date: Tue, 25 Jun 2024 18:42:26 -0400 Subject: [PATCH] Match initial query with mqtt client --- Cargo.lock | 6 + remote_send/Cargo.toml | 11 +- remote_send/src/commands/mod.rs | 50 ++++ remote_send/src/lib.rs | 42 +--- remote_send/src/main.rs | 132 +---------- remote_send/src/mqtt/mod.rs | 63 ++--- remote_send/src/mqtt/packets/subscribe_ack.rs | 1 + remote_send/src/mqtt_server.rs | 222 ++++++++++++++++++ 8 files changed, 321 insertions(+), 206 deletions(-) create mode 100644 remote_send/src/commands/mod.rs create mode 100644 remote_send/src/mqtt_server.rs diff --git a/Cargo.lock b/Cargo.lock index 39efcea..b3672a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3232,6 +3232,7 @@ dependencies = [ "rand", "serde", "serde_json", + "soon", ] [[package]] @@ -3545,6 +3546,11 @@ dependencies = [ "winapi", ] +[[package]] +name = "soon" +version = "0.1.0" +source = "git+https://github.com/connorslade/misc#0e48a0c5652e7f58481074690ded4e0b9aa24a9a" + [[package]] name = "spin" version = "0.9.8" diff --git a/remote_send/Cargo.toml b/remote_send/Cargo.toml index 0bec9b7..4939af2 100644 --- a/remote_send/Cargo.toml +++ b/remote_send/Cargo.toml @@ -5,11 +5,12 @@ edition = "2021" [dependencies] anyhow = "1.0.86" -serde = { version = "1.0.203", features = ["derive"] } -serde_json = "1.0.117" - -common = { path = "../common" } bitflags = "2.5.0" +chrono = "0.4.38" parking_lot = "0.12.3" rand = "0.8.5" -chrono = "0.4.38" +serde = { version = "1.0.203", features = ["derive"] } +serde_json = "1.0.117" +soon = { git = "https://github.com/connorslade/misc" } + +common = { path = "../common" } diff --git a/remote_send/src/commands/mod.rs b/remote_send/src/commands/mod.rs new file mode 100644 index 0000000..80c0b7d --- /dev/null +++ b/remote_send/src/commands/mod.rs @@ -0,0 +1,50 @@ +use chrono::Utc; +use serde::Serialize; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "PascalCase")] +pub struct Command { + pub cmd: u16, + pub data: Data, + pub from: u8, + #[serde(rename = "MainboardID")] + pub mainboard_id: String, + #[serde(rename = "RequestID")] + pub request_id: String, + #[serde(rename = "TimeStamp")] + pub time_stamp: u64, +} + +pub trait CommandTrait: Serialize { + const CMD: u16; +} + +impl Command { + pub fn new(cmd: u16, data: Data, mainboard_id: String) -> Self { + let request_id = format!("{:x}", rand::random::()); + let time_stamp = Utc::now().timestamp_millis() as u64; + + Self { + cmd, + data, + from: 0, + mainboard_id, + request_id, + time_stamp, + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "PascalCase")] +pub struct StartPrinting { + pub filename: String, + pub start_layer: u32, +} + +#[derive(Serialize)] +pub struct DisconnectCommand; + +impl CommandTrait for DisconnectCommand { + const CMD: u16 = 64; +} diff --git a/remote_send/src/lib.rs b/remote_send/src/lib.rs index be984f8..1a9bff2 100644 --- a/remote_send/src/lib.rs +++ b/remote_send/src/lib.rs @@ -1,54 +1,18 @@ use anyhow::Result; -use chrono::Utc; use serde::{Deserialize, Deserializer, Serialize}; +pub mod commands; pub mod mqtt; +pub mod mqtt_server; pub mod status; -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "PascalCase")] pub struct Response { pub id: String, pub data: Data, } -#[derive(Debug, Serialize)] -#[serde(rename_all = "PascalCase")] -pub struct Command { - pub cmd: u16, - pub data: Data, - pub from: u8, - #[serde(rename = "MainboardID")] - pub mainboard_id: String, - #[serde(rename = "RequestID")] - pub request_id: String, - #[serde(rename = "TimeStamp")] - pub time_stamp: u64, -} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "PascalCase")] -pub struct StartPrinting { - pub filename: String, - pub start_layer: u32, -} - -impl Command { - pub fn new(cmd: u16, data: Data, mainboard_id: String) -> Self { - let request_id = format!("{:x}", rand::random::()); - let time_stamp = Utc::now().timestamp_millis() as u64; - - Self { - cmd, - data, - from: 0, - mainboard_id, - request_id, - time_stamp, - } - } -} - #[derive(Debug)] pub struct Resolution { pub x: u16, diff --git a/remote_send/src/main.rs b/remote_send/src/main.rs index ba3f7e8..4e2e2cb 100644 --- a/remote_send/src/main.rs +++ b/remote_send/src/main.rs @@ -1,120 +1,13 @@ +use std::{net::UdpSocket, thread}; + use anyhow::Result; -use serde::Serialize; -use std::net::UdpSocket; - -use remote_send::{ - mqtt::{ - packets::{ - connect::ConnectPacket, - connect_ack::{ConnectAckFlags, ConnectAckPacket, ConnectReturnCode}, - publish::{PublishFlags, PublishPacket}, - publish_ack::PublishAckPacket, - subscribe::SubscribePacket, - subscribe_ack::{SubscribeAckPacket, SubscribeReturnCode}, - }, - HandlerCtx, MqttHandler, MqttServer, - }, - status::{Attributes, FullStatusData, StatusData}, - Command, Response, -}; - -struct Mqtt { - // todo: must support multiple clients - status: Attributes, - id: String, -} - -impl MqttHandler for Mqtt { - fn on_connect( - &self, - ctx: &HandlerCtx, - _packet: ConnectPacket, - ) -> Result { - println!("Client `{}` connected", ctx.client_id); - - Ok(ConnectAckPacket { - flags: ConnectAckFlags::empty(), - return_code: ConnectReturnCode::Accepted, - }) - } - - fn on_subscribe( - &self, - ctx: &HandlerCtx, - packet: SubscribePacket, - ) -> Result { - println!( - "Client `{}` subscribed to topics: {:?}", - ctx.client_id, packet.filters - ); - - Ok(SubscribeAckPacket { - packet_id: packet.packet_id, - return_codes: packet - .filters - .iter() - .map(|(_, qos)| SubscribeReturnCode::Success(*qos)) - .collect::>(), - }) - } - - fn on_publish(&self, ctx: &HandlerCtx, packet: PublishPacket) -> Result<()> { - println!( - "Client `{}` published to topic `{}`", - ctx.client_id, packet.topic - ); - - if let Some(board_id) = packet.topic.strip_prefix("/sdcp/status/") { - let status = serde_json::from_slice::>(&packet.data)?; - println!("Got status from `{}`", board_id); - println!("{:?}", status); - } else if let Some(board_id) = packet.topic.strip_prefix("/sdcp/response/") { - println!("Got command response from `{}`", board_id); - println!("{:?}", String::from_utf8_lossy(&packet.data)); - } - - Ok(()) - } - - fn on_publish_ack(&self, _ctx: &HandlerCtx, _packet: PublishAckPacket) -> Result<()> { - Ok(()) - } - - fn on_disconnect(&self, _ctx: &HandlerCtx) -> Result<()> { - Ok(()) - } -} - -impl Mqtt { - fn send_command( - &self, - ctx: &HandlerCtx, - cmd: u16, - command: Data, - ) -> Result<()> { - let id = ctx.next_packet_id(); - - let data = Command::new(cmd, command, self.id.to_owned()); - let data = serde_json::to_vec(&data).unwrap(); - - ctx.server - .send_packet( - ctx.client_id, - PublishPacket { - flags: PublishFlags::QOS1, - topic: format!("/sdcp/request/{}", self.status.mainboard_id), - packet_id: Some(id), - data, - } - .to_packet(), - ) - .unwrap(); - - Ok(()) - } -} +use remote_send::{mqtt::MqttServer, mqtt_server::Mqtt, status::FullStatusData, Response}; fn main() -> Result<()> { + let mqtt = Mqtt::new(); + let mqtt_inner = mqtt.clone(); + MqttServer::new(mqtt).start_async()?; + let socket = UdpSocket::bind("0.0.0.0:3000")?; let msg = b"M99999"; @@ -129,15 +22,12 @@ fn main() -> Result<()> { "Got status from `{}`", response.data.attributes.machine_name ); - - MqttServer::new(Mqtt { - status: response.data.attributes, - id: response.id, - }) - .start_async()?; + mqtt_inner.add_future_client(response.data.attributes, response.id); let msg = b"M66666 1883"; socket.send_to(msg, "192.168.1.233:3000")?; - Ok(()) + loop { + thread::park() + } } diff --git a/remote_send/src/mqtt/mod.rs b/remote_send/src/mqtt/mod.rs index 35223a9..d270a0d 100644 --- a/remote_send/src/mqtt/mod.rs +++ b/remote_send/src/mqtt/mod.rs @@ -1,10 +1,7 @@ use std::{ collections::HashMap, net::{TcpListener, TcpStream}, - sync::{ - atomic::{AtomicU16, Ordering}, - Arc, - }, + sync::Arc, thread, }; @@ -20,6 +17,7 @@ use packets::{ Packet, }; use parking_lot::{lock_api::MutexGuard, MappedMutexGuard, Mutex}; +use soon::Soon; mod misc; pub mod packets; @@ -27,38 +25,33 @@ pub mod packets; pub struct MqttServer { listeners: Mutex>, clients: Mutex>, - handler: H, -} - -pub struct HandlerCtx { - pub server: Arc>, - pub client_id: u64, - next_packet_id: AtomicU16, + handler: Soon, } pub trait MqttHandler where Self: Sized, { - fn on_connect(&self, ctx: &HandlerCtx, packet: ConnectPacket) - -> Result; - fn on_subscribe( - &self, - ctx: &HandlerCtx, - packet: SubscribePacket, - ) -> Result; - fn on_publish(&self, ctx: &HandlerCtx, packet: PublishPacket) -> Result<()>; - fn on_publish_ack(&self, ctx: &HandlerCtx, packet: PublishAckPacket) -> Result<()>; - fn on_disconnect(&self, ctx: &HandlerCtx) -> Result<()>; + fn init(&self, server: Arc>); + fn on_connect(&self, client_id: u64, packet: ConnectPacket) -> Result; + fn on_subscribe(&self, client_id: u64, packet: SubscribePacket) -> Result; + fn on_publish(&self, client_id: u64, packet: PublishPacket) -> Result<()>; + fn on_publish_ack(&self, client_id: u64, packet: PublishAckPacket) -> Result<()>; + fn on_disconnect(&self, client_id: u64) -> Result<()>; } impl MqttServer { pub fn new(handler: H) -> Arc { - Arc::new(Self { + let this = Arc::new(Self { listeners: Mutex::new(None), clients: Mutex::new(HashMap::new()), - handler, - }) + handler: Soon::empty(), + }); + + handler.init(this.clone()); + this.handler.replace(handler); + + this } pub fn start_async(self: Arc) -> Result<()> { @@ -102,22 +95,10 @@ impl MqttServer { } } -impl HandlerCtx { - pub fn next_packet_id(&self) -> u16 { - self.next_packet_id.fetch_add(1, Ordering::Relaxed) - } -} - fn handle_client(server: Arc>, client_id: u64, mut stream: TcpStream) -> Result<()> where H: MqttHandler + Send + Sync + 'static, { - let ctx = HandlerCtx { - server: server.clone(), - client_id, - next_packet_id: AtomicU16::new(0), - }; - loop { let packet = Packet::read(&mut stream)?; @@ -127,7 +108,7 @@ where server .handler - .on_connect(&ctx, packet)? + .on_connect(client_id, packet)? .to_packet() .write(&mut stream)?; } @@ -136,7 +117,7 @@ where server .handler - .on_subscribe(&ctx, packet)? + .on_subscribe(client_id, packet)? .to_packet() .write(&mut stream)?; } @@ -147,7 +128,7 @@ where .contains(PublishFlags::QOS1) .then(|| packet.packet_id.unwrap()); - server.handler.on_publish(&ctx, packet)?; + server.handler.on_publish(client_id, packet)?; if let Some(packet_id) = packet_id { PublishAckPacket { packet_id } @@ -157,10 +138,10 @@ where } PublishAckPacket::PACKET_TYPE => { let packet = PublishAckPacket::from_packet(&packet)?; - server.handler.on_publish_ack(&ctx, packet)?; + server.handler.on_publish_ack(client_id, packet)?; } 0x0E => { - server.handler.on_disconnect(&ctx)?; + server.handler.on_disconnect(client_id)?; server.remove_client(client_id); break; } diff --git a/remote_send/src/mqtt/packets/subscribe_ack.rs b/remote_send/src/mqtt/packets/subscribe_ack.rs index c41954b..c5f1899 100644 --- a/remote_send/src/mqtt/packets/subscribe_ack.rs +++ b/remote_send/src/mqtt/packets/subscribe_ack.rs @@ -9,6 +9,7 @@ pub struct SubscribeAckPacket { pub return_codes: Vec, } +#[derive(Clone, Copy)] pub enum SubscribeReturnCode { Success(QoS), Failure, diff --git a/remote_send/src/mqtt_server.rs b/remote_send/src/mqtt_server.rs new file mode 100644 index 0000000..7de0d86 --- /dev/null +++ b/remote_send/src/mqtt_server.rs @@ -0,0 +1,222 @@ +use std::{ + collections::HashMap, + ops::Deref, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, + }, +}; + +use anyhow::Result; +use parking_lot::RwLock; +use soon::Soon; + +use crate::{ + commands::{Command, CommandTrait, DisconnectCommand}, + mqtt::{ + packets::{ + connect::ConnectPacket, + connect_ack::{ConnectAckFlags, ConnectAckPacket, ConnectReturnCode}, + publish::{PublishFlags, PublishPacket}, + publish_ack::PublishAckPacket, + subscribe::SubscribePacket, + subscribe_ack::{SubscribeAckPacket, SubscribeReturnCode}, + }, + MqttHandler, MqttServer, + }, + status::{Attributes, StatusData}, + Response, +}; + +pub struct MqttInner { + server: Soon>>, + // mainboard_id -> MqttClient + clients: RwLock>, + // client_id -> mainboard_id + client_ids: RwLock>, +} + +pub struct Mqtt { + inner: Arc, +} + +struct MqttClient { + status: Attributes, + machine_id: String, + client_id: Option, + next_packet_id: AtomicU16, +} + +impl MqttHandler for Mqtt { + fn init(&self, server: Arc>) { + self.server.replace(server); + } + + fn on_connect(&self, client_id: u64, _packet: ConnectPacket) -> Result { + println!("Client `{client_id}` connected"); + + Ok(ConnectAckPacket { + flags: ConnectAckFlags::empty(), + return_code: ConnectReturnCode::Accepted, + }) + } + + fn on_subscribe(&self, client_id: u64, packet: SubscribePacket) -> Result { + println!( + "Client `{client_id}` subscribed to topics: {:?}", + packet.filters + ); + + let mut return_codes = vec![SubscribeReturnCode::Failure; packet.filters.len()]; + if let Some((idx, mainboard_id, qos)) = + packet + .filters + .iter() + .enumerate() + .find_map(|(idx, (topic, qos))| { + topic.strip_prefix("/sdcp/request/").map(|x| (idx, x, qos)) + }) + { + if self.clients.read().get(mainboard_id).is_none() { + eprintln!("Client `{mainboard_id}` does not exist."); + return Ok(SubscribeAckPacket { + packet_id: packet.packet_id, + return_codes, + }); + } + + return_codes[idx] = SubscribeReturnCode::Success(*qos); + self.client_ids + .write() + .insert(client_id, mainboard_id.to_owned()); + } + + Ok(SubscribeAckPacket { + packet_id: packet.packet_id, + return_codes, + }) + } + + fn on_publish(&self, client_id: u64, packet: PublishPacket) -> Result<()> { + println!("Client `{client_id}` published to topic `{}`", packet.topic); + + if let Some(board_id) = packet.topic.strip_prefix("/sdcp/status/") { + let status = serde_json::from_slice::>(&packet.data)?; + println!("Got status from `{}`", board_id); + println!("{:?}", status); + } else if let Some(board_id) = packet.topic.strip_prefix("/sdcp/response/") { + println!("Got command response from `{}`", board_id); + println!("{:?}", String::from_utf8_lossy(&packet.data)); + } + + Ok(()) + } + + fn on_publish_ack(&self, _client_id: u64, _packet: PublishAckPacket) -> Result<()> { + Ok(()) + } + + fn on_disconnect(&self, client_id: u64) -> Result<()> { + let machine_id = self.client_ids.write().remove(&client_id); + if let Some(machine_id) = machine_id { + self.clients.write().remove(&machine_id); + println!("Client `{machine_id}` disconnected"); + } + Ok(()) + } +} + +impl Mqtt { + pub fn new() -> Self { + Self { + inner: Arc::new(MqttInner { + server: Soon::empty(), + clients: RwLock::new(HashMap::new()), + client_ids: RwLock::new(HashMap::new()), + }), + } + } + + pub fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } + + pub fn send_command( + &self, + mainboard_id: &str, + command: Data, + ) -> Result<()> { + let clients = self.clients.read(); + let client = clients.get(mainboard_id).unwrap(); + let packet_id = client.next_id(); + + let Some(client_id) = client.client_id else { + eprintln!("Client `{mainboard_id}` is not connected. Command not sent."); + return Ok(()); + }; + + let data = Response { + data: Command::new(Data::CMD, command, client.machine_id.to_owned()), + id: String::new(), + }; + let data = serde_json::to_vec(&data).unwrap(); + + self.server + .send_packet( + client_id, + PublishPacket { + flags: PublishFlags::QOS1, + topic: format!("/sdcp/request/{}", client.status.mainboard_id), + packet_id: Some(packet_id), + data, + } + .to_packet(), + ) + .unwrap(); + + Ok(()) + } + + pub fn add_future_client(&self, attributes: Attributes, machine_id: String) { + if self.clients.read().contains_key(&attributes.mainboard_id) { + println!("Client `{}` already exists.", attributes.mainboard_id); + return; + } + + let mainboard_id = attributes.mainboard_id.clone(); + let client = MqttClient { + status: attributes, + machine_id: machine_id.clone(), + client_id: None, + next_packet_id: AtomicU16::new(0), + }; + + let mut clients = self.clients.write(); + clients.insert(mainboard_id, client); + } +} + +impl MqttClient { + fn next_id(&self) -> u16 { + self.next_packet_id.fetch_add(1, Ordering::Relaxed) + } +} + +impl Deref for Mqtt { + type Target = MqttInner; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl Drop for Mqtt { + fn drop(&mut self) { + for mainboard_id in self.clients.read().keys() { + println!("Disconnecting `{mainboard_id}`"); + let _ = self.send_command(mainboard_id, DisconnectCommand); + } + } +}