Async mqtt server

This commit is contained in:
Connor Slade
2024-06-24 01:39:04 -04:00
parent b113def198
commit 9e59e5a08a
6 changed files with 87 additions and 22 deletions

1
Cargo.lock generated
View File

@@ -3184,6 +3184,7 @@ dependencies = [
"anyhow",
"bitflags 2.5.0",
"common",
"parking_lot",
"serde",
"serde_json",
]

View File

@@ -10,3 +10,4 @@ serde_json = "1.0.117"
common = { path = "../common" }
bitflags = "2.5.0"
parking_lot = "0.12.3"

View File

@@ -1,13 +1,10 @@
use std::{net::UdpSocket, thread};
use anyhow::Result;
use std::net::UdpSocket;
use remote_send::{mqtt, status::StatusData, Response};
use remote_send::{mqtt::MqttServer, status::StatusData, Response};
fn main() -> Result<()> {
thread::spawn(|| {
mqtt::start().unwrap();
});
MqttServer::new().start_async()?;
let socket = UdpSocket::bind("0.0.0.0:3000")?;

View File

@@ -0,0 +1,6 @@
use std::sync::atomic::{AtomicU64, Ordering};
pub fn next_id() -> u64 {
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
NEXT_ID.fetch_add(1, Ordering::Relaxed)
}

View File

@@ -1,37 +1,86 @@
use std::{
borrow::Cow,
collections::HashMap,
io::{Read, Write},
net::{TcpListener, TcpStream},
sync::Arc,
thread,
};
use anyhow::Result;
use common::serde::Deserializer;
use misc::next_id;
use packets::{
connect::ConnectPacket,
connect_ack::{ConnectAckFlags, ConnectAckPacket, ConnectReturnCode},
subscribe::SubscribePacket,
};
use parking_lot::{lock_api::MutexGuard, MappedMutexGuard, Mutex};
mod misc;
pub mod packets;
pub fn start() -> Result<()> {
let socket = TcpListener::bind("0.0.0.0:1883")?;
for stream in socket.incoming() {
let stream = stream?;
println!("Connection established: {:?}", stream);
thread::spawn(|| {
if let Err(e) = handle_client(stream) {
eprintln!("Error handling client: {:?}", e);
}
});
}
Ok(())
pub struct MqttServer {
clients: Mutex<HashMap<u64, MqttClient>>,
}
fn handle_client(mut stream: TcpStream) -> Result<()> {
#[derive(Debug)]
pub struct MqttClient {
stream: TcpStream,
subscriptions: Vec<String>,
}
impl MqttServer {
pub fn new() -> Arc<Self> {
Arc::new(Self {
clients: Mutex::new(HashMap::new()),
})
}
pub fn start_async(self: Arc<Self>) -> Result<()> {
let socket = TcpListener::bind("0.0.0.0:1883")?;
thread::spawn(move || {
for stream in socket.incoming() {
let stream = stream.unwrap();
let client_id = next_id();
self.clients
.lock()
.insert(client_id, MqttClient::new(stream.try_clone().unwrap()));
println!("Connection established: {:?}", stream);
let this_self = self.clone();
thread::spawn(move || {
if let Err(e) = handle_client(this_self, client_id, stream) {
eprintln!("Error handling client: {:?}", e);
}
});
}
});
Ok(())
}
fn get_client_mut(&self, client_id: u64) -> MappedMutexGuard<MqttClient> {
MutexGuard::map(self.clients.lock(), |x| x.get_mut(&client_id).unwrap())
}
fn remove_client(&self, client_id: u64) {
self.clients.lock().remove(&client_id);
}
}
impl MqttClient {
fn new(stream: TcpStream) -> Self {
Self {
stream,
subscriptions: Vec::new(),
}
}
}
fn handle_client(server: Arc<MqttServer>, client_id: u64, mut stream: TcpStream) -> Result<()> {
loop {
let packet = Packet::read(&mut stream)?;
@@ -50,11 +99,22 @@ fn handle_client(mut stream: TcpStream) -> Result<()> {
SubscribePacket::PACKET_TYPE => {
let packet = SubscribePacket::from_bytes(&packet.remaining_bytes)?;
println!("Subscribe packet: {:?}", packet);
server
.get_client_mut(client_id)
.subscriptions
.extend(packet.filters.into_iter().map(|x| x.0));
dbg!(&server.get_client_mut(client_id));
}
0x0E => {
println!("Client disconnect: {client_id}");
server.remove_client(client_id);
break;
}
ty => eprintln!("Unsupported packet type: 0x{ty:x}"),
}
}
Ok(())
}
pub struct Packet {

View File

@@ -50,7 +50,7 @@ pub fn ui(app: &mut App, ctx: &Context, _frame: &mut Frame) {
None
};
egui::Frame::canvas(&ui.style()).show(ui, |ui| {
egui::Frame::canvas(ui.style()).show(ui, |ui| {
let available_size = ui.available_size();
let (rect, _response) = ui.allocate_exact_size(
Vec2::new(