Make it all work

This commit is contained in:
Vladimir Vukicevic
2023-09-03 15:58:08 -07:00
parent b70d9bbb2f
commit ee4c383e80
5 changed files with 440 additions and 136 deletions

View File

@@ -24,41 +24,47 @@ class SimpleMQTTServer:
self.server = None
self.incoming_messages = asyncio.Queue()
self.outgoing_messages = asyncio.Queue()
self.connected_clients = {}
self.next_pack_id_value = 1
self.handlers = {}
def add_handler(self, topic, handler):
if topic not in self.handlers:
self.handlers[topic] = []
self.handlers[topic].append(handler)
async def start(self):
self.server = await asyncio.start_server(self.handle_client, self.host, self.port)
self.port = self.server.sockets[0].getsockname()[1]
logging.debug(f'Listening on {self.server.sockets[0].getsockname()}')
logging.debug(f'MQTT Listening on {self.server.sockets[0].getsockname()}')
async def serve_forever(self):
loop = asyncio.get_event_loop()
self.client_connection = loop.create_future()
self.client_subscribed = loop.create_future()
await self.server.serve_forever()
def publish(self, topic, payload):
self.outgoing_messages.put_nowait({'topic': topic, 'payload': payload})
async def next_published_message(self):
return await self.incoming_messages.get()
async def handle_client(self, reader, writer):
try:
await self.handle_client_inner(reader, writer)
except Exception as e:
logging.error(f"MQTT Exception handling client: {e}")
async def handle_client_inner(self, reader, writer):
addr = writer.get_extra_info('peername')
logging.debug(f'Socket connected from {addr}')
data = b''
subscribed_topics = dict()
client_id = None
read_future = asyncio.ensure_future(reader.read(1024))
outgoing_messages_future = asyncio.ensure_future(self.outgoing_messages.get())
subscribed_topics = dict()
while True:
# get Future representing a reader.read(1024)
# get Future representing a self.incoming_messages.get()
completed, pending = await asyncio.wait([read_future, outgoing_messages_future], return_when=asyncio.FIRST_COMPLETED)
#print(completed)
#print(pending)
if outgoing_messages_future in completed:
#print("Got outgoing message")
outmsg = outgoing_messages_future.result()
topic = outmsg['topic']
payload = outmsg['payload']
@@ -84,8 +90,6 @@ class SimpleMQTTServer:
# must have at least 2 bytes
if len(data) < 2:
break
#print(f"Remaining bytes: {len(data)}")
#print(f"Data: {data}")
msg_type = data[0] >> 4
msg_flags = data[0] & 0xf
@@ -93,7 +97,7 @@ class SimpleMQTTServer:
# TODO -- we could maybe not have enough bytes to decode the length, but assume
# that won't happen
msg_length, len_bytes_consumed = self.decode_length(data[1:])
logging.debug(f" in msg_type: {msg_type} flags: {msg_flags} msg_length {msg_length} bytes_consumed for msg_length {len_bytes_consumed}")
#logging.debug(f"mqtt in msg_type: {msg_type} flags: {msg_flags} msg_length {msg_length} bytes_consumed for msg_length {len_bytes_consumed}")
# is there enough to process the message?
head_len = len_bytes_consumed + 1
@@ -106,17 +110,27 @@ class SimpleMQTTServer:
data = data[head_len+msg_length:]
if msg_type == MQTT_CONNECT:
# ignore the contents of the message, should maybe check for 'MQTT' identifier at least
logging.info(f"Client {addr} connected")
if message[0:6] != b'\x00\x04MQTT':
logging.error(f"MQTT client {addr}: bad CONNECT")
writer.close()
return
client_id_len = struct.unpack("!H", message[10:12])[0]
client_id = message[12:12+client_id_len].decode("utf-8")
logging.debug(f"MQTT client {client_id} at {addr} connected")
self.connected_clients[client_id] = addr
await self.send_msg(writer, MQTT_CONNACK, payload=b'\x00\x00')
self.client_connection.set_result(client_id)
self.client_connection = asyncio.get_event_loop().create_future()
elif msg_type == MQTT_PUBLISH:
qos = (msg_flags >> 1) & 0x3
topic, packid, content = self.parse_publish(message)
logging.info(f"Got DATA on: {topic}")
if topic in self.handlers:
for handler in self.handlers[topic]:
handler(topic, content)
#logging.debug(f"Got DATA on: {topic}")
self.incoming_messages.put_nowait({ 'topic': topic, 'payload': content})
if qos > 0:
await self.send_msg(writer, MQTT_PUBACK, packet_ident=packid)
elif msg_type == MQTT_SUBSCRIBE:
@@ -124,13 +138,19 @@ class SimpleMQTTServer:
packid = message[0] << 8 | message[1]
message = message[2:]
topic = self.parse_subscribe(message)
logging.info(f"Client {addr} subscribed to topic '{topic}', QoS {qos}")
logging.debug(f"Client {addr} subscribed to topic '{topic}', QoS {qos}")
subscribed_topics[topic] = qos
await self.send_msg(writer, MQTT_SUBACK, packet_ident=packid, payload=bytes([qos]))
self.client_subscribed.set_result(topic)
self.client_subscribed = asyncio.get_event_loop().create_future()
elif msg_type == MQTT_DISCONNECT:
logging.info(f"Client {addr} disconnected")
writer.close()
await writer.wait_closed()
if client_id is not None:
del self.connected_clients[client_id]
return
async def send_msg(self, writer, msg_type, flags=0, packet_ident=0, payload=b''):
@@ -142,7 +162,7 @@ class SimpleMQTTServer:
if packet_ident > 0:
head += bytes([packet_ident >> 8, packet_ident & 0xff])
data = head + payload
logging.debug(f" writing {len(data)} bytes: {data}")
#logging.debug(f" writing {len(data)} bytes: {data}")
writer.write(data)
await writer.drain()