diff --git a/Cargo.lock b/Cargo.lock index 623a9ce..f3f2d3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3184,6 +3184,7 @@ dependencies = [ "anyhow", "bitflags 2.5.0", "common", + "parking_lot", "serde", "serde_json", ] diff --git a/remote_send/Cargo.toml b/remote_send/Cargo.toml index 7927edd..2334fe4 100644 --- a/remote_send/Cargo.toml +++ b/remote_send/Cargo.toml @@ -10,3 +10,4 @@ serde_json = "1.0.117" common = { path = "../common" } bitflags = "2.5.0" +parking_lot = "0.12.3" diff --git a/remote_send/src/main.rs b/remote_send/src/main.rs index 5345e8e..90947ee 100644 --- a/remote_send/src/main.rs +++ b/remote_send/src/main.rs @@ -1,13 +1,10 @@ -use std::{net::UdpSocket, thread}; - use anyhow::Result; +use std::net::UdpSocket; -use remote_send::{mqtt, status::StatusData, Response}; +use remote_send::{mqtt::MqttServer, status::StatusData, Response}; fn main() -> Result<()> { - thread::spawn(|| { - mqtt::start().unwrap(); - }); + MqttServer::new().start_async()?; let socket = UdpSocket::bind("0.0.0.0:3000")?; diff --git a/remote_send/src/mqtt/misc.rs b/remote_send/src/mqtt/misc.rs new file mode 100644 index 0000000..12c88a5 --- /dev/null +++ b/remote_send/src/mqtt/misc.rs @@ -0,0 +1,6 @@ +use std::sync::atomic::{AtomicU64, Ordering}; + +pub fn next_id() -> u64 { + static NEXT_ID: AtomicU64 = AtomicU64::new(0); + NEXT_ID.fetch_add(1, Ordering::Relaxed) +} diff --git a/remote_send/src/mqtt/mod.rs b/remote_send/src/mqtt/mod.rs index 47d4910..7c63af6 100644 --- a/remote_send/src/mqtt/mod.rs +++ b/remote_send/src/mqtt/mod.rs @@ -1,37 +1,86 @@ use std::{ borrow::Cow, + collections::HashMap, io::{Read, Write}, net::{TcpListener, TcpStream}, + sync::Arc, thread, }; use anyhow::Result; use common::serde::Deserializer; +use misc::next_id; use packets::{ connect::ConnectPacket, connect_ack::{ConnectAckFlags, ConnectAckPacket, ConnectReturnCode}, subscribe::SubscribePacket, }; +use parking_lot::{lock_api::MutexGuard, MappedMutexGuard, Mutex}; +mod misc; pub mod packets; -pub fn start() -> Result<()> { - let socket = TcpListener::bind("0.0.0.0:1883")?; - - for stream in socket.incoming() { - let stream = stream?; - println!("Connection established: {:?}", stream); - thread::spawn(|| { - if let Err(e) = handle_client(stream) { - eprintln!("Error handling client: {:?}", e); - } - }); - } - - Ok(()) +pub struct MqttServer { + clients: Mutex>, } -fn handle_client(mut stream: TcpStream) -> Result<()> { +#[derive(Debug)] +pub struct MqttClient { + stream: TcpStream, + subscriptions: Vec, +} + +impl MqttServer { + pub fn new() -> Arc { + Arc::new(Self { + clients: Mutex::new(HashMap::new()), + }) + } + + pub fn start_async(self: Arc) -> Result<()> { + let socket = TcpListener::bind("0.0.0.0:1883")?; + + thread::spawn(move || { + for stream in socket.incoming() { + let stream = stream.unwrap(); + let client_id = next_id(); + self.clients + .lock() + .insert(client_id, MqttClient::new(stream.try_clone().unwrap())); + + println!("Connection established: {:?}", stream); + + let this_self = self.clone(); + thread::spawn(move || { + if let Err(e) = handle_client(this_self, client_id, stream) { + eprintln!("Error handling client: {:?}", e); + } + }); + } + }); + + Ok(()) + } + + fn get_client_mut(&self, client_id: u64) -> MappedMutexGuard { + MutexGuard::map(self.clients.lock(), |x| x.get_mut(&client_id).unwrap()) + } + + fn remove_client(&self, client_id: u64) { + self.clients.lock().remove(&client_id); + } +} + +impl MqttClient { + fn new(stream: TcpStream) -> Self { + Self { + stream, + subscriptions: Vec::new(), + } + } +} + +fn handle_client(server: Arc, client_id: u64, mut stream: TcpStream) -> Result<()> { loop { let packet = Packet::read(&mut stream)?; @@ -50,11 +99,22 @@ fn handle_client(mut stream: TcpStream) -> Result<()> { SubscribePacket::PACKET_TYPE => { let packet = SubscribePacket::from_bytes(&packet.remaining_bytes)?; println!("Subscribe packet: {:?}", packet); + server + .get_client_mut(client_id) + .subscriptions + .extend(packet.filters.into_iter().map(|x| x.0)); + dbg!(&server.get_client_mut(client_id)); + } + 0x0E => { + println!("Client disconnect: {client_id}"); + server.remove_client(client_id); + break; } ty => eprintln!("Unsupported packet type: 0x{ty:x}"), } } + Ok(()) } pub struct Packet { diff --git a/ui/src/windows/slice_preview.rs b/ui/src/windows/slice_preview.rs index 8d4b0bb..380fc4e 100644 --- a/ui/src/windows/slice_preview.rs +++ b/ui/src/windows/slice_preview.rs @@ -50,7 +50,7 @@ pub fn ui(app: &mut App, ctx: &Context, _frame: &mut Frame) { None }; - egui::Frame::canvas(&ui.style()).show(ui, |ui| { + egui::Frame::canvas(ui.style()).show(ui, |ui| { let available_size = ui.available_size(); let (rect, _response) = ui.allocate_exact_size( Vec2::new(