This commit is contained in:
Vladimir Vukicevic
2023-09-02 21:43:11 -07:00
parent 6001000324
commit ccb53f05c3
4 changed files with 165 additions and 92 deletions

View File

@@ -1,3 +1,11 @@
#
# Cassini
#
# Copyright (C) 2023 Vladimir Vukicevic
# License: MIT
#
import logging
import asyncio
import struct
@@ -17,18 +25,24 @@ class SimpleMQTTServer:
self.incoming_messages = asyncio.Queue()
self.outgoing_messages = asyncio.Queue()
self.next_pack_id_value = 1
self.handlers = {}
def add_handler(self, topic, handler):
if topic not in self.handlers:
self.handlers[topic] = []
self.handlers[topic].append(handler)
async def start(self):
self.server = await asyncio.start_server(self.handle_client, self.host, self.port)
self.port = self.server.sockets[0].getsockname()[1]
print(f'Listening on {self.server.sockets[0].getsockname()}')
logging.debug(f'Listening on {self.server.sockets[0].getsockname()}')
async def serve_forever(self):
await self.server.serve_forever()
async def handle_client(self, reader, writer):
addr = writer.get_extra_info('peername')
print(f'Socket connected from {addr}')
logging.debug(f'Socket connected from {addr}')
data = b''
read_future = asyncio.ensure_future(reader.read(1024))
@@ -53,7 +67,7 @@ class SimpleMQTTServer:
qos = subscribed_topics[topic]
await self.send_msg(writer, MQTT_PUBLISH, payload=self.encode_publish(topic, payload, self.next_pack_id()))
else:
print(f'SEND: NOT SUBSCRIBED {topic}: {payload}')
logging.debug(f'SEND: NOT SUBSCRIBED {topic}: {payload}')
#msg = (MQTT_PUBLISH, 0, topic.encode('utf-8') + payload.encode('utf-8'))
#await self.send_msg(writer, *msg)
outgoing_messages_future = asyncio.ensure_future(self.outgoing_messages.get())
@@ -79,12 +93,12 @@ class SimpleMQTTServer:
# TODO -- we could maybe not have enough bytes to decode the length, but assume
# that won't happen
msg_length, len_bytes_consumed = self.decode_length(data[1:])
print(f" in msg_type: {msg_type} flags: {msg_flags} msg_length {msg_length} bytes_consumed for msg_length {len_bytes_consumed}")
logging.debug(f" in msg_type: {msg_type} flags: {msg_flags} msg_length {msg_length} bytes_consumed for msg_length {len_bytes_consumed}")
# is there enough to process the message?
head_len = len_bytes_consumed + 1
if msg_length + head_len > len(data):
print("Not enough")
logging.debug("Not enough")
break
# pull the message payload out, and move data to next packet
@@ -93,12 +107,16 @@ class SimpleMQTTServer:
if msg_type == MQTT_CONNECT:
# ignore the contents of the message, should maybe check for 'MQTT' identifier at least
print(f"Client {addr} connected")
logging.info(f"Client {addr} connected")
await self.send_msg(writer, MQTT_CONNACK, payload=b'\x00\x00')
elif msg_type == MQTT_PUBLISH:
qos = (msg_flags >> 1) & 0x3
topic, packid, content = self.parse_publish(message)
print(f"{topic}: {content}")
logging.info(f"Got DATA on: {topic}")
if topic in self.handlers:
for handler in self.handlers[topic]:
handler(topic, content)
if qos > 0:
await self.send_msg(writer, MQTT_PUBACK, packet_ident=packid)
elif msg_type == MQTT_SUBSCRIBE:
@@ -106,11 +124,11 @@ class SimpleMQTTServer:
packid = message[0] << 8 | message[1]
message = message[2:]
topic = self.parse_subscribe(message)
print(f"Client {addr} subscribed to topic '{topic}', QoS {qos}")
logging.info(f"Client {addr} subscribed to topic '{topic}', QoS {qos}")
subscribed_topics[topic] = qos
await self.send_msg(writer, MQTT_SUBACK, packet_ident=packid, payload=bytes([qos]))
elif msg_type == MQTT_DISCONNECT:
print(f"Client {addr} disconnected")
logging.info(f"Client {addr} disconnected")
writer.close()
await writer.wait_closed()
return
@@ -124,7 +142,7 @@ class SimpleMQTTServer:
if packet_ident > 0:
head += bytes([packet_ident >> 8, packet_ident & 0xff])
data = head + payload
print(f" writing {len(data)} bytes: {data}")
logging.debug(f" writing {len(data)} bytes: {data}")
writer.write(data)
await writer.drain()