diff --git a/README.md b/README.md index 6cf4913..3c9e193 100644 --- a/README.md +++ b/README.md @@ -23,11 +23,12 @@ install the `alive-progress` package for nicer progress bars (`pip3 install aliv ``` $ ./cassini.py status -0: Saturn3Ultra (ELEGOO Saturn 3 Ultra) -- 192.168.x.x - Status: 4 Layers: 27/1002 +192.168.7.128: Saturn3Ultra (ELEGOO Saturn 3 Ultra) Status: 1 + Print Status: 2 Layers: 19/130 + File Transfer Status: 0 ``` -### Watch live ptogress +### Watch live print progress ``` $ ./cassini.py watch [interval] @@ -37,18 +38,20 @@ _STL_B_Warriors_1_Sword_Combined_Supported.goo |██████████ ### File transfer ``` -$ ./cassini.py [--target printer_id] put-file [--start-print] MyFile.goo +$ ./cassini.py [--printer printer_ip] upload MyFile.goo +15:39:15,190 INFO: Using printer Saturn3Ultra (ELEGOO Saturn 3 Ultra) +MyFile.goo |████████████████████████████████████████| 100% [5750174/5750174] (3291238.22/s) ``` ### Start a print (of an existing file) ``` -$ ./cassini.py [--target printer_id] start-print Myfile.goo +$ ./cassini.py [--printer printer_ip] print Myfile.goo ``` ## Protocol Description -The protocol is pretty simple. (There is no encryption or anything that I could find.) +The protocol is pretty simple. There is no encryption or any obfuscation that I could find. There is a UDP discovery, status, and MQTT connection protocol. When the UDP command to connect to a MQTT server is given, the printer connects to a MQTT server, and can be @@ -200,3 +203,7 @@ is connected to. The file needs to be accessible at the specified URL. }, "Id": "0a69ee780fbd40d7bfb95b312250bf46" }``` + +The printer will set CurrentStatus = 2 (Busy), and FileTransferInfo.Status = 0 while the transfer is in progress, +and will give Status updates. When finished, CurrentStatus will return to 0, and FileTransferInfo will be either 2 (success) +or 3 (failure). diff --git a/cassini.py b/cassini.py index e7eda48..fa8115e 100755 --- a/cassini.py +++ b/cassini.py @@ -6,21 +6,18 @@ # Copyright (C) 2023 Vladimir Vukicevic # License: MIT # - +import os import sys -import socket -import struct import time -import json import asyncio import logging -import random +import argparse from simple_mqtt_server import SimpleMQTTServer from simple_http_server import SimpleHTTPServer from saturn_printer import SaturnPrinter logging.basicConfig( - level=logging.DEBUG, # .INFO + level=logging.INFO, format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s", datefmt="%H:%M:%S", ) @@ -43,28 +40,27 @@ except ImportError: async def create_mqtt_server(): mqtt = SimpleMQTTServer('0.0.0.0', 0) await mqtt.start() - logging.info(f"MQTT Server created, port {mqtt.port}") mqtt_server_task = asyncio.create_task(mqtt.serve_forever()) return mqtt, mqtt.port, mqtt_server_task async def create_http_server(): http = SimpleHTTPServer('0.0.0.0', 0) await http.start() - logging.info(f"HTTP Server created, port {http.port}") http_server_task = asyncio.create_task(http.serve_forever()) return http, http.port, http_server_task -def print_printer_status(printers): +def do_status(printers): for i, p in enumerate(printers): attrs = p.desc['Data']['Attributes'] status = p.desc['Data']['Status'] - printInfo = status['PrintInfo'] - print(f"{i}: {attrs['Name']} ({attrs['MachineName']})") - print(f" Status: {printInfo['Status']} Layers: {printInfo['CurrentLayer']}/{printInfo['TotalLayer']}") + print_info = status['PrintInfo'] + file_info = status['FileTransferInfo'] + print(f"{p.addr[0]}: {attrs['Name']} ({attrs['MachineName']}) Status: {status['CurrentStatus']}") + print(f" Print Status: {print_info['Status']} Layers: {print_info['CurrentLayer']}/{print_info['TotalLayer']} File: {print_info['FileName']}") + print(f" File Transfer Status: {file_info['Status']}") -def do_watch(interval): - printers = SaturnPrinter.find_printers() - status = printers[0].status() +def do_watch(printer, interval=5): + status = printer.status() with alive_bar(total=status['totalLayers'], manual=True, elapsed=False, title=status['filename']) as bar: while True: printers = SaturnPrinter.find_printers() @@ -76,54 +72,111 @@ def do_watch(interval): break time.sleep(interval) -async def main(): - cmd = None - printers = SaturnPrinter.find_printers() - - if len(printers) == 0: - print("No printers found") - return - - if len(printers) > 1: - print("More than 1 printer found.") - print("Usage --printer argument to specify the ID. [TODO]") - return - - if len(sys.argv) > 1: - cmd = sys.argv[1] - - if cmd == 'watch': - do_watch(int(sys.argv[2]) if len(sys.argv) > 2 else 5) - return - - if cmd == 'status': - print_printer_status(printers) - return - - # Spin up our private servers +async def create_servers(): mqtt, mqtt_port, mqtt_task = await create_mqtt_server() http, http_port, http_task = await create_http_server() - printer = printers[0] - printer.connect(mqtt, http) + return mqtt, http - await asyncio.sleep(3) - printer.send_command(SATURN_CMD_0) - await asyncio.sleep(1) - printer.send_command(SATURN_CMD_1) - await asyncio.sleep(1) - printer.send_command(SATURN_CMD_SET_MYSTERY_TIME_PERIOD, { 'TimePeriod': 5 }) - await asyncio.sleep(1000) +async def do_print(printer, filename): + mqtt, http = await create_servers() + connected = await printer.connect(mqtt, http) + if not connected: + logging.error("Failed to connect to printer") + sys.exit(1) - #printer_task = asyncio.create_task(printer_setup(mqtt.port)) - #while True: - # if server_task is not None and server_task.done(): - # print("Server task done") - # print(server_task.exception()) - # server_task = None - # if printer_task is not None and printer_task.done(): - # print("Printer task done") - # print(printer_task.exception()) - # printer_task = None + result = await printer.print_file(filename) + if result: + logging.info("Print started") + else: + logging.error("Failed to start print") + sys.exit(1) -asyncio.run(main()) \ No newline at end of file +async def do_upload(printer, filename, start_printing=False): + if not os.path.exists(filename): + logging.error(f"{filename} does not exist") + sys.exit(1) + + mqtt, http = await create_servers() + connected = await printer.connect(mqtt, http) + if not connected: + logging.error("Failed to connect to printer") + sys.exit(1) + + #await printer.upload_file(filename, start_printing=start_printing) + upload_task = asyncio.create_task(printer.upload_file(filename, start_printing=start_printing)) + # grab the first one, because we want the file size + basename = filename.split('\\')[-1].split('/')[-1] + file_size = os.path.getsize(filename) + with alive_bar(total=file_size, manual=True, elapsed=False, title=basename) as bar: + while True: + if printer.file_transfer_future is None: + await asyncio.sleep(0.1) + continue + progress = await printer.file_transfer_future + if progress[0] < 0: + logging.error("File upload failed!") + sys.exit(1) + bar(progress[0] / progress[1]) + if progress[0] >= progress[1]: + break + await upload_task + +def main(): + parser = argparse.ArgumentParser(prog='cassini', description='ELEGOO Saturn printer control utility') + parser.add_argument('-p', '--printer', help='ID of printer to target') + parser.add_argument('--debug', help='Enable debug logging', action='store_true') + + subparsers = parser.add_subparsers(title="commands", dest="command", required=True) + + parser_status = subparsers.add_parser('status', help='Discover and display status of all printers') + + parser_watch = subparsers.add_parser('watch', help='Continuously update the status of the selected printer') + parser_watch.add_argument('--interval', type=int, help='Status update interval (seconds)', default=5) + + parser_upload = subparsers.add_parser('upload', help='Upload a file to the printer') + parser_upload.add_argument('--start-printing', help='Start printing after upload is complete', action='store_true') + parser_upload.add_argument('filename', help='File to upload') + + parser_print = subparsers.add_parser('print', help='Start printing a file already present on the printer') + parser_print.add_argument('filename', help='File to print') + + args = parser.parse_args() + + if args.debug: + logging.getLogger().setLevel(logging.DEBUG) + + printers = [] + printer = None + if args.printer: + printer = SaturnPrinter.find_printer(args.printer) + printers = [printer] + if printer is None: + logging.error(f"No response from printer {args.printer}") + sys.exit(1) + else: + printers = SaturnPrinter.find_printers() + if len(printers) == 0: + logging.error("No printers found on network") + sys.exit(1) + printer = printers[0] + + if args.command == "status": + do_status(printers) + sys.exit(0) + + if args.command == "watch": + do_watch(printer, interval=args.interval) + sys.exit(0) + + logging.info(f'Printer: {printer.describe()} ({printer.addr[0]})') + if printer.busy: + logging.error(f'Printer is busy (status: {printer.current_status})') + sys.exit(1) + + if args.command == "upload": + asyncio.run(do_upload(printer, args.filename, start_printing=args.start_printing)) + elif args.command == "print": + asyncio.run(do_print(printer, args.filename)) + +main() \ No newline at end of file diff --git a/saturn_printer.py b/saturn_printer.py index 4cee4cc..795cd04 100644 --- a/saturn_printer.py +++ b/saturn_printer.py @@ -14,40 +14,59 @@ import asyncio import logging import random -SATURN_BROADCAST_PORT = 3000 +SATURN_UDP_PORT = 3000 -SATURN_STATUS_EXPOSURE = 2 # TODO: double check tese -SATURN_STATUS_RETRACTING = 3 -SATURN_STATUS_LOWERING = 4 -SATURN_STATUS_COMPLETE = 16 # ?? +# CurrentStatus field inside Status +SATURN_STATUS_READY = 0 +SATURN_STATUS_BUSY = 1 # Printer might be sitting at the "Completed" screen +SATURN_STATUS_BUSY_2 = 1 # post-HEAD call on file transfer, along with SATURN_FILE_STATUS = 3 on error, and 2 on completion -STATUS_NAMES = { - SATURN_STATUS_EXPOSURE: "Exposure", - SATURN_STATUS_RETRACTING: "Retracting", - SATURN_STATUS_LOWERING: "Lowering", - SATURN_STATUS_COMPLETE: "Complete" +# Status field inside PrintInfo +SATURN_PRINT_STATUS_EXPOSURE = 2 # TODO: double check tese +SATURN_PRINT_STATUS_RETRACTING = 3 +SATURN_PRINT_STATUS_LOWERING = 4 +SATURN_PRINT_STATUS_COMPLETE = 16 # pretty sure this is correct + +# Status field inside FileTransferInfo +SATURN_FILE_STATUS_NONE = 0 +SATURN_FILE_STATUS_DONE = 2 +SATURN_FILE_STATUS_ERROR = 3 + +SATURN_PRINT_STATUS_NAMES = { + SATURN_PRINT_STATUS_EXPOSURE: "Exposure", + SATURN_PRINT_STATUS_RETRACTING: "Retracting", + SATURN_PRINT_STATUS_LOWERING: "Lowering", + SATURN_PRINT_STATUS_COMPLETE: "Complete" } SATURN_CMD_0 = 0 # null data SATURN_CMD_1 = 1 # null data -SATURN_CMD_SET_MYSTERY_TIME_PERIOD = 512 # "TimePeriod": 5000 +SATURN_CMD_DISCONNECT = 64 # Maybe disconnect? SATURN_CMD_START_PRINTING = 128 # "Filename": "X", "StartLayer": 0 SATURN_CMD_UPLOAD_FILE = 256 # "Check": 0, "CleanCache": 1, "Compress": 0, "FileSize": 3541068, "Filename": "_ResinXP2-ValidationMatrix_v2.goo", "MD5": "205abc8fab0762ad2b0ee1f6b63b1750", "URL": "http://${ipaddr}:58883/f60c0718c8144b0db48b7149d4d85390.goo" }, -SATURN_CMD_DISCONNECT = 64 # Maybe disconnect? +SATURN_CMD_SET_MYSTERY_TIME_PERIOD = 512 # "TimePeriod": 5000 + +def random_hexstr(): + return '%032x' % random.getrandbits(128) class SaturnPrinter: - def __init__(self, addr, desc): + def __init__(self, addr, desc, timeout=5): self.addr = addr - self.desc = desc + self.timeout = timeout + self.file_transfer_future = None + if desc is not None: + self.set_desc(desc) + else: + self.desc = None - # Class method: UDP broadcast search for all printers + # Broadcast and find all printers, return array of SaturnPrinter objects def find_printers(timeout=1): printers = [] sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) with sock: sock.settimeout(timeout) sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, timeout) - sock.sendto(b'M99999', ('', SATURN_BROADCAST_PORT)) + sock.sendto(b'M99999', ('', SATURN_UDP_PORT)) now = time.time() while True: @@ -62,34 +81,213 @@ class SaturnPrinter: pdata = json.loads(data.decode('utf-8')) printers.append(SaturnPrinter(addr, pdata)) return printers - - # Tell this printer to connect to the given mqtt server - def connect(self, mqtt, http): + + # Find a specific printer at the given address, return a SaturnPrinter object + # or None if no response is obtained + def find_printer(addr, timeout=5): + printer = SaturnPrinter(addr, None) + if printer.refresh(timeout): + return printer + return None + + # Refresh this SaturnPrinter with latest status + def refresh(self, timeout=5): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with sock: + sock.settimeout(timeout) + sock.sendto(b'M99999', (self.addr, SATURN_UDP_PORT)) + try: + data, addr = sock.recvfrom(1024) + except socket.timeout: + return False + else: + pdata = json.loads(data.decode('utf-8')) + self.set_desc(pdata) + + def set_desc(self, desc): + self.desc = desc + self.id = desc['Data']['Attributes']['MainboardID'] + self.name = desc['Data']['Attributes']['Name'] + self.machine_name = desc['Data']['Attributes']['MachineName'] + self.current_status = desc['Data']['Status']['CurrentStatus'] + self.busy = self.current_status > 0 + + # Tell this printer to connect to the specified mqtt and http + # servers, for further control + async def connect(self, mqtt, http): self.mqtt = mqtt self.http = http - mainboard = self.desc['Data']['Attributes']['MainboardID'] - mqtt.add_handler("/sdcp/saturn/" + mainboard, self.incoming_data) - mqtt.add_handler("/sdcp/response/" + mainboard, self.incoming_data) - + # Tell the printer to connect sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) with sock: sock.sendto(b'M66666 ' + str(mqtt.port).encode('utf-8'), self.addr) - def incoming_data(self, topic, payload): - if topic.startswith("/sdcp/status/"): - self.incoming_status(payload['Data']['Status']) - elif topic.startswith("/sdcp/attributes/"): - # don't think I care about attributes - pass - elif topic.startswith("/sdcp/response/"): - self.incoming_response(payload['Data']['RequestID'], payload['Data']['Cmd'], payload['Data']['Data']) + # wait for the connection + client_id = await asyncio.wait_for(mqtt.client_connection, timeout=self.timeout) + if client_id != self.id: + logging.error(f"Client ID mismatch: {client_id} != {self.id}") + return False + + # wait for the client to subscribe to the request topic + topic = await asyncio.wait_for(self.mqtt.client_subscribed, timeout=self.timeout) + logging.debug(f"Client subscribed to {topic}") + + await self.send_command_and_wait(SATURN_CMD_0) + await self.send_command_and_wait(SATURN_CMD_1) + await self.send_command_and_wait(SATURN_CMD_SET_MYSTERY_TIME_PERIOD, { 'TimePeriod': 5000 }) + + return True + + async def disconnect(self): + await self.send_command_and_wait(SATURN_CMD_DISCONNECT) + + async def upload_file(self, filename, start_printing=False): + try: + await self.upload_file_inner(filename, start_printing) + except Exception as ex: + logging.error(f"Exception during upload: {ex}") + self.file_transfer_future.set_result((-1, -1, filename)) + self.file_transfer_future = asyncio.get_running_loop().create_future() + + async def upload_file_inner(self, filename, start_printing=False): + # schedule a future that can be used for status, in case this is kicked off as a task + self.file_transfer_future = asyncio.get_running_loop().create_future() + + # get base filename and extension + basename = filename.split('\\')[-1].split('/')[-1] + ext = basename.split('.')[-1].lower() + if ext != 'ctb' and ext != 'goo': + logging.warning(f"Unknown file extension: {ext}") + + httpname = random_hexstr() + '.' + ext + fileinfo = self.http.register_file_route('/' + httpname, filename) + + cmd_data = { + "Check": 0, + "CleanCache": 1, + "Compress": 0, + "FileSize": fileinfo['size'], + "Filename": basename, + "MD5": fileinfo['md5'], + "URL": f"http://${{ipaddr}}:{self.http.port}/{httpname}" + } + + await self.send_command_and_wait(SATURN_CMD_UPLOAD_FILE, cmd_data) + + # now process status updates from the printer + while True: + reply = await asyncio.wait_for(self.mqtt.next_published_message(), timeout=self.timeout*2) + data = json.loads(reply['payload']) + if reply['topic'] == "/sdcp/response/" + self.id: + logging.warning(f"Got unexpected RESPONSE (no outstanding request), topic: {reply['topic']} data: {data}") + elif reply['topic'] == "/sdcp/status/" + self.id: + self.incoming_status(data['Data']['Status']) + + status = data['Data']['Status'] + file_info = status['FileTransferInfo'] + current_offset = file_info['DownloadOffset'] + total_size = file_info['FileTotalSize'] + file_name = file_info['Filename'] + + # We assume that the printer immediately goes into BUSY status after it processes + # the upload command + if status['CurrentStatus'] == SATURN_STATUS_READY: + if file_info['Status'] == SATURN_FILE_STATUS_DONE: + self.file_transfer_future.set_result((total_size, total_size, file_name)) + elif file_info['Status'] == SATURN_FILE_STATUS_ERROR: + logging.error("Transfer error!") + self.file_transfer_future.set_result((-1, total_size, file_name)) + else: + logging.error(f"Unknown file transfer status code: {file_info['Status']}") + self.file_transfer_future.set_result((-1, total_size, file_name)) + break + + self.file_transfer_future.set_result((current_offset, total_size, file_name)) + self.file_transfer_future = asyncio.get_running_loop().create_future() + elif reply['topic'] == "/sdcp/attributes/" + self.id: + # ignore these + pass + else: + logging.warning(f"Got unknown topic message: {reply['topic']}") + + self.file_transfer_future = None + + async def send_command_and_wait(self, cmdid, data=None, abort_on_bad_ack=True): + # Send the 0 and 1 messages + req = self.send_command(cmdid, data) + logging.debug(f"Sent command {cmdid} as request {req}") + while True: + reply = await asyncio.wait_for(self.mqtt.next_published_message(), timeout=self.timeout) + data = json.loads(reply['payload']) + if reply['topic'] == "/sdcp/response/" + self.id: + if data['Data']['RequestID'] == req: + logging.debug(f"Got response to {req}") + result = data['Data']['Data'] + if abort_on_bad_ack and result['Ack'] != 0: + logging.error(f"Got bad ack in response: {result}") + sys.exit(1) + return result + elif reply['topic'] == "/sdcp/status/" + self.id: + self.incoming_status(data['Data']['Status']) + elif reply['topic'] == "/sdcp/attributes/" + self.id: + # ignore these + pass + else: + logging.warning(f"Got unknown topic message: {reply['topic']}") + + async def print_file(self, filename): + cmd_data = { + "Filename": filename, + "StartLayer": 0 + } + + await self.send_command_and_wait(SATURN_CMD_START_PRINTING, cmd_data) + + # process status updates from the printer, enough to know whether printing + # started or failed to start + status_count = 0 + while True: + reply = await asyncio.wait_for(self.mqtt.next_published_message(), timeout=self.timeout*2) + data = json.loads(reply['payload']) + if reply['topic'] == "/sdcp/response/" + self.id: + logging.warning(f"Got unexpected RESPONSE (no outstanding request), topic: {reply['topic']} data: {data}") + elif reply['topic'] == "/sdcp/status/" + self.id: + self.incoming_status(data['Data']['Status']) + status_count += 1 + + status = data['Data']['Status'] + print_info = status['PrintInfo'] + + current_status = status['CurrentStatus'] + print_status = print_info['Status'] + + if current_status == SATURN_STATUS_BUSY and print_status > 0: + return True + + logging.debug(status) + logging.debug(print_info) + + if status_count >= 5: + logging.warning("Too many status replies without success or failure") + return False + + elif reply['topic'] == "/sdcp/attributes/" + self.id: + # ignore these + pass + else: + logging.warning(f"Got unknown topic message: {reply['topic']}") + + async def process_responses(self): + while True: + reply = await asyncio.wait_for(self.mqtt.next_published_message(), timeout=self.timeout) + data = json.loads(reply['payload']) def incoming_status(self, status): - logging.info(f"STATUS: {status}") + logging.debug(f"STATUS: {status}") def incoming_response(self, id, cmd, data): - logging.info(f"RESPONSE: {id} -- {cmd}: {data}") + logging.debug(f"RESPONSE: {id} -- {cmd}: {data}") def describe(self): attrs = self.desc['Data']['Attributes'] @@ -106,19 +304,18 @@ class SaturnPrinter: def send_command(self, cmdid, data=None): # generate 16-byte random identifier as a hex string - hexstr = '%032x' % random.getrandbits(128) + hexstr = random_hexstr() timestamp = int(time.time() * 1000) - mainboard = self.desc['Data']['Attributes']['MainboardID'] cmd_data = { "Data": { "Cmd": cmdid, "Data": data, "From": 0, - "MainboardID": mainboard, + "MainboardID": self.id, "RequestID": hexstr, "TimeStamp": timestamp }, "Id": self.desc['Id'] } - print("SENDING REQUEST: " + json.dumps(cmd_data)) - self.mqtt.outgoing_messages.put_nowait({'topic': '/sdcp/request/' + mainboard, 'payload': json.dumps(cmd_data)}) + self.mqtt.publish('/sdcp/request/' + self.id, json.dumps(cmd_data)) + return hexstr diff --git a/simple_http_server.py b/simple_http_server.py index b7bf608..ffb64ca 100644 --- a/simple_http_server.py +++ b/simple_http_server.py @@ -5,12 +5,15 @@ # License: MIT # +import logging import asyncio import os import hashlib class SimpleHTTPServer: - def __init__(self, host="127.0.0.1", port=0): + BufferSize = 1024768 + + def __init__(self, host="0.0.0.0", port=0): self.host = host self.port = port self.server = None @@ -29,46 +32,70 @@ class SimpleHTTPServer: self.routes[path] = route return route + def unregister_file_route(self, path): + del self.routes[path] + async def start(self): self.server = await asyncio.start_server(self.handle_client, self.host, self.port) self.port = self.server.sockets[0].getsockname()[1] + logging.debug(f'HTTP Listening on {self.server.sockets[0].getsockname()}') async def serve_forever(self): await self.server.serve_forever() async def handle_client(self, reader, writer): + try: + await self.handle_client_inner(reader, writer) + except Exception as e: + logging.error(f"HTTP Exception handling client: {e}") + + async def handle_client_inner(self, reader, writer): + logging.debug(f"HTTP connection from {writer.get_extra_info('peername')}") data = b'' while True: data += await reader.read(1024) if b'\r\n\r\n' in data: break + logging.debug(f"HTTP request: {data}") request_line = data.decode().splitlines()[0] method, path, _ = request_line.split() if path not in self.routes: + logging.debug(f"HTTP path {path} not found in routes") + logging.debug(self.routes) writer.write("HTTP/1.1 404 Not Found\r\n".encode()) writer.close() return route = self.routes[path] - header = f"HTTP/1.1 200 OK\r\n" - header += f"Content-Type: application/octet-stream\r\n" - header += f"Etag: {route['md5']}\r\n" - header += f"Content-Length: {path['size']}\r\n" + logging.debug(f"HTTP method {method} path {path} route: {route}") + header = f"HTTP/1.1 200 OK\r\n" + #header += f"Content-Type: application/octet-stream\r\n" + header += f"Content-Type: text/plain; charset=utf-8\r\n" + header += f"Etag: {route['md5']}\r\n" + header += f"Content-Length: {route['size']}\r\n" + header += "\r\n" + + logging.debug(f"Writing header:\n{header}") writer.write(header.encode()) if method == "GET": - writer.write(b'\r\n') + total = 0 with open(route['file'], 'rb') as f: while True: - data = f.read(8192) + data = f.read(self.BufferSize) if not data: break writer.write(data) + logging.debug(f"HTTP wrote {len(data)} bytes") + #await asyncio.sleep(1) + total += len(data) + logging.debug(f"HTTP wrote total {total} bytes") await writer.drain() writer.close() await writer.wait_closed() + logging.debug(f"HTTP connection closed") diff --git a/simple_mqtt_server.py b/simple_mqtt_server.py index f312cef..5a1f0d5 100644 --- a/simple_mqtt_server.py +++ b/simple_mqtt_server.py @@ -24,41 +24,47 @@ class SimpleMQTTServer: self.server = None self.incoming_messages = asyncio.Queue() self.outgoing_messages = asyncio.Queue() + self.connected_clients = {} self.next_pack_id_value = 1 - self.handlers = {} - - def add_handler(self, topic, handler): - if topic not in self.handlers: - self.handlers[topic] = [] - self.handlers[topic].append(handler) async def start(self): self.server = await asyncio.start_server(self.handle_client, self.host, self.port) self.port = self.server.sockets[0].getsockname()[1] - logging.debug(f'Listening on {self.server.sockets[0].getsockname()}') + logging.debug(f'MQTT Listening on {self.server.sockets[0].getsockname()}') async def serve_forever(self): + loop = asyncio.get_event_loop() + self.client_connection = loop.create_future() + self.client_subscribed = loop.create_future() await self.server.serve_forever() + def publish(self, topic, payload): + self.outgoing_messages.put_nowait({'topic': topic, 'payload': payload}) + + async def next_published_message(self): + return await self.incoming_messages.get() + async def handle_client(self, reader, writer): + try: + await self.handle_client_inner(reader, writer) + except Exception as e: + logging.error(f"MQTT Exception handling client: {e}") + + async def handle_client_inner(self, reader, writer): addr = writer.get_extra_info('peername') logging.debug(f'Socket connected from {addr}') data = b'' + subscribed_topics = dict() + client_id = None + read_future = asyncio.ensure_future(reader.read(1024)) outgoing_messages_future = asyncio.ensure_future(self.outgoing_messages.get()) - subscribed_topics = dict() - while True: - # get Future representing a reader.read(1024) - # get Future representing a self.incoming_messages.get() completed, pending = await asyncio.wait([read_future, outgoing_messages_future], return_when=asyncio.FIRST_COMPLETED) - #print(completed) - #print(pending) if outgoing_messages_future in completed: - #print("Got outgoing message") outmsg = outgoing_messages_future.result() topic = outmsg['topic'] payload = outmsg['payload'] @@ -84,8 +90,6 @@ class SimpleMQTTServer: # must have at least 2 bytes if len(data) < 2: break - #print(f"Remaining bytes: {len(data)}") - #print(f"Data: {data}") msg_type = data[0] >> 4 msg_flags = data[0] & 0xf @@ -93,7 +97,7 @@ class SimpleMQTTServer: # TODO -- we could maybe not have enough bytes to decode the length, but assume # that won't happen msg_length, len_bytes_consumed = self.decode_length(data[1:]) - logging.debug(f" in msg_type: {msg_type} flags: {msg_flags} msg_length {msg_length} bytes_consumed for msg_length {len_bytes_consumed}") + #logging.debug(f"mqtt in msg_type: {msg_type} flags: {msg_flags} msg_length {msg_length} bytes_consumed for msg_length {len_bytes_consumed}") # is there enough to process the message? head_len = len_bytes_consumed + 1 @@ -106,17 +110,27 @@ class SimpleMQTTServer: data = data[head_len+msg_length:] if msg_type == MQTT_CONNECT: - # ignore the contents of the message, should maybe check for 'MQTT' identifier at least - logging.info(f"Client {addr} connected") + if message[0:6] != b'\x00\x04MQTT': + logging.error(f"MQTT client {addr}: bad CONNECT") + writer.close() + return + + client_id_len = struct.unpack("!H", message[10:12])[0] + client_id = message[12:12+client_id_len].decode("utf-8") + + logging.debug(f"MQTT client {client_id} at {addr} connected") + self.connected_clients[client_id] = addr await self.send_msg(writer, MQTT_CONNACK, payload=b'\x00\x00') + + self.client_connection.set_result(client_id) + self.client_connection = asyncio.get_event_loop().create_future() + elif msg_type == MQTT_PUBLISH: qos = (msg_flags >> 1) & 0x3 topic, packid, content = self.parse_publish(message) - logging.info(f"Got DATA on: {topic}") - if topic in self.handlers: - for handler in self.handlers[topic]: - handler(topic, content) + #logging.debug(f"Got DATA on: {topic}") + self.incoming_messages.put_nowait({ 'topic': topic, 'payload': content}) if qos > 0: await self.send_msg(writer, MQTT_PUBACK, packet_ident=packid) elif msg_type == MQTT_SUBSCRIBE: @@ -124,13 +138,19 @@ class SimpleMQTTServer: packid = message[0] << 8 | message[1] message = message[2:] topic = self.parse_subscribe(message) - logging.info(f"Client {addr} subscribed to topic '{topic}', QoS {qos}") + logging.debug(f"Client {addr} subscribed to topic '{topic}', QoS {qos}") subscribed_topics[topic] = qos await self.send_msg(writer, MQTT_SUBACK, packet_ident=packid, payload=bytes([qos])) + + self.client_subscribed.set_result(topic) + self.client_subscribed = asyncio.get_event_loop().create_future() elif msg_type == MQTT_DISCONNECT: logging.info(f"Client {addr} disconnected") writer.close() await writer.wait_closed() + + if client_id is not None: + del self.connected_clients[client_id] return async def send_msg(self, writer, msg_type, flags=0, packet_ident=0, payload=b''): @@ -142,7 +162,7 @@ class SimpleMQTTServer: if packet_ident > 0: head += bytes([packet_ident >> 8, packet_ident & 0xff]) data = head + payload - logging.debug(f" writing {len(data)} bytes: {data}") + #logging.debug(f" writing {len(data)} bytes: {data}") writer.write(data) await writer.drain()