Make it all work

This commit is contained in:
Vladimir Vukicevic
2023-09-03 15:58:08 -07:00
parent b70d9bbb2f
commit ee4c383e80
5 changed files with 440 additions and 136 deletions

View File

@@ -23,11 +23,12 @@ install the `alive-progress` package for nicer progress bars (`pip3 install aliv
```
$ ./cassini.py status
0: Saturn3Ultra (ELEGOO Saturn 3 Ultra) -- 192.168.x.x
Status: 4 Layers: 27/1002
192.168.7.128: Saturn3Ultra (ELEGOO Saturn 3 Ultra) Status: 1
Print Status: 2 Layers: 19/130
File Transfer Status: 0
```
### Watch live ptogress
### Watch live print progress
```
$ ./cassini.py watch [interval]
@@ -37,18 +38,20 @@ _STL_B_Warriors_1_Sword_Combined_Supported.goo |██████████
### File transfer
```
$ ./cassini.py [--target printer_id] put-file [--start-print] MyFile.goo
$ ./cassini.py [--printer printer_ip] upload MyFile.goo
15:39:15,190 INFO: Using printer Saturn3Ultra (ELEGOO Saturn 3 Ultra)
MyFile.goo |████████████████████████████████████████| 100% [5750174/5750174] (3291238.22/s)
```
### Start a print (of an existing file)
```
$ ./cassini.py [--target printer_id] start-print Myfile.goo
$ ./cassini.py [--printer printer_ip] print Myfile.goo
```
## Protocol Description
The protocol is pretty simple. (There is no encryption or anything that I could find.)
The protocol is pretty simple. There is no encryption or any obfuscation that I could find.
There is a UDP discovery, status, and MQTT connection protocol. When the UDP command to
connect to a MQTT server is given, the printer connects to a MQTT server, and can be
@@ -200,3 +203,7 @@ is connected to. The file needs to be accessible at the specified URL.
},
"Id": "0a69ee780fbd40d7bfb95b312250bf46"
}```
The printer will set CurrentStatus = 2 (Busy), and FileTransferInfo.Status = 0 while the transfer is in progress,
and will give Status updates. When finished, CurrentStatus will return to 0, and FileTransferInfo will be either 2 (success)
or 3 (failure).

View File

@@ -6,21 +6,18 @@
# Copyright (C) 2023 Vladimir Vukicevic
# License: MIT
#
import os
import sys
import socket
import struct
import time
import json
import asyncio
import logging
import random
import argparse
from simple_mqtt_server import SimpleMQTTServer
from simple_http_server import SimpleHTTPServer
from saturn_printer import SaturnPrinter
logging.basicConfig(
level=logging.DEBUG, # .INFO
level=logging.INFO,
format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s",
datefmt="%H:%M:%S",
)
@@ -43,28 +40,27 @@ except ImportError:
async def create_mqtt_server():
mqtt = SimpleMQTTServer('0.0.0.0', 0)
await mqtt.start()
logging.info(f"MQTT Server created, port {mqtt.port}")
mqtt_server_task = asyncio.create_task(mqtt.serve_forever())
return mqtt, mqtt.port, mqtt_server_task
async def create_http_server():
http = SimpleHTTPServer('0.0.0.0', 0)
await http.start()
logging.info(f"HTTP Server created, port {http.port}")
http_server_task = asyncio.create_task(http.serve_forever())
return http, http.port, http_server_task
def print_printer_status(printers):
def do_status(printers):
for i, p in enumerate(printers):
attrs = p.desc['Data']['Attributes']
status = p.desc['Data']['Status']
printInfo = status['PrintInfo']
print(f"{i}: {attrs['Name']} ({attrs['MachineName']})")
print(f" Status: {printInfo['Status']} Layers: {printInfo['CurrentLayer']}/{printInfo['TotalLayer']}")
print_info = status['PrintInfo']
file_info = status['FileTransferInfo']
print(f"{p.addr[0]}: {attrs['Name']} ({attrs['MachineName']}) Status: {status['CurrentStatus']}")
print(f" Print Status: {print_info['Status']} Layers: {print_info['CurrentLayer']}/{print_info['TotalLayer']} File: {print_info['FileName']}")
print(f" File Transfer Status: {file_info['Status']}")
def do_watch(interval):
printers = SaturnPrinter.find_printers()
status = printers[0].status()
def do_watch(printer, interval=5):
status = printer.status()
with alive_bar(total=status['totalLayers'], manual=True, elapsed=False, title=status['filename']) as bar:
while True:
printers = SaturnPrinter.find_printers()
@@ -76,54 +72,111 @@ def do_watch(interval):
break
time.sleep(interval)
async def main():
cmd = None
printers = SaturnPrinter.find_printers()
if len(printers) == 0:
print("No printers found")
return
if len(printers) > 1:
print("More than 1 printer found.")
print("Usage --printer argument to specify the ID. [TODO]")
return
if len(sys.argv) > 1:
cmd = sys.argv[1]
if cmd == 'watch':
do_watch(int(sys.argv[2]) if len(sys.argv) > 2 else 5)
return
if cmd == 'status':
print_printer_status(printers)
return
# Spin up our private servers
async def create_servers():
mqtt, mqtt_port, mqtt_task = await create_mqtt_server()
http, http_port, http_task = await create_http_server()
printer = printers[0]
printer.connect(mqtt, http)
return mqtt, http
await asyncio.sleep(3)
printer.send_command(SATURN_CMD_0)
await asyncio.sleep(1)
printer.send_command(SATURN_CMD_1)
await asyncio.sleep(1)
printer.send_command(SATURN_CMD_SET_MYSTERY_TIME_PERIOD, { 'TimePeriod': 5 })
await asyncio.sleep(1000)
async def do_print(printer, filename):
mqtt, http = await create_servers()
connected = await printer.connect(mqtt, http)
if not connected:
logging.error("Failed to connect to printer")
sys.exit(1)
#printer_task = asyncio.create_task(printer_setup(mqtt.port))
#while True:
# if server_task is not None and server_task.done():
# print("Server task done")
# print(server_task.exception())
# server_task = None
# if printer_task is not None and printer_task.done():
# print("Printer task done")
# print(printer_task.exception())
# printer_task = None
result = await printer.print_file(filename)
if result:
logging.info("Print started")
else:
logging.error("Failed to start print")
sys.exit(1)
asyncio.run(main())
async def do_upload(printer, filename, start_printing=False):
if not os.path.exists(filename):
logging.error(f"{filename} does not exist")
sys.exit(1)
mqtt, http = await create_servers()
connected = await printer.connect(mqtt, http)
if not connected:
logging.error("Failed to connect to printer")
sys.exit(1)
#await printer.upload_file(filename, start_printing=start_printing)
upload_task = asyncio.create_task(printer.upload_file(filename, start_printing=start_printing))
# grab the first one, because we want the file size
basename = filename.split('\\')[-1].split('/')[-1]
file_size = os.path.getsize(filename)
with alive_bar(total=file_size, manual=True, elapsed=False, title=basename) as bar:
while True:
if printer.file_transfer_future is None:
await asyncio.sleep(0.1)
continue
progress = await printer.file_transfer_future
if progress[0] < 0:
logging.error("File upload failed!")
sys.exit(1)
bar(progress[0] / progress[1])
if progress[0] >= progress[1]:
break
await upload_task
def main():
parser = argparse.ArgumentParser(prog='cassini', description='ELEGOO Saturn printer control utility')
parser.add_argument('-p', '--printer', help='ID of printer to target')
parser.add_argument('--debug', help='Enable debug logging', action='store_true')
subparsers = parser.add_subparsers(title="commands", dest="command", required=True)
parser_status = subparsers.add_parser('status', help='Discover and display status of all printers')
parser_watch = subparsers.add_parser('watch', help='Continuously update the status of the selected printer')
parser_watch.add_argument('--interval', type=int, help='Status update interval (seconds)', default=5)
parser_upload = subparsers.add_parser('upload', help='Upload a file to the printer')
parser_upload.add_argument('--start-printing', help='Start printing after upload is complete', action='store_true')
parser_upload.add_argument('filename', help='File to upload')
parser_print = subparsers.add_parser('print', help='Start printing a file already present on the printer')
parser_print.add_argument('filename', help='File to print')
args = parser.parse_args()
if args.debug:
logging.getLogger().setLevel(logging.DEBUG)
printers = []
printer = None
if args.printer:
printer = SaturnPrinter.find_printer(args.printer)
printers = [printer]
if printer is None:
logging.error(f"No response from printer {args.printer}")
sys.exit(1)
else:
printers = SaturnPrinter.find_printers()
if len(printers) == 0:
logging.error("No printers found on network")
sys.exit(1)
printer = printers[0]
if args.command == "status":
do_status(printers)
sys.exit(0)
if args.command == "watch":
do_watch(printer, interval=args.interval)
sys.exit(0)
logging.info(f'Printer: {printer.describe()} ({printer.addr[0]})')
if printer.busy:
logging.error(f'Printer is busy (status: {printer.current_status})')
sys.exit(1)
if args.command == "upload":
asyncio.run(do_upload(printer, args.filename, start_printing=args.start_printing))
elif args.command == "print":
asyncio.run(do_print(printer, args.filename))
main()

View File

@@ -14,40 +14,59 @@ import asyncio
import logging
import random
SATURN_BROADCAST_PORT = 3000
SATURN_UDP_PORT = 3000
SATURN_STATUS_EXPOSURE = 2 # TODO: double check tese
SATURN_STATUS_RETRACTING = 3
SATURN_STATUS_LOWERING = 4
SATURN_STATUS_COMPLETE = 16 # ??
# CurrentStatus field inside Status
SATURN_STATUS_READY = 0
SATURN_STATUS_BUSY = 1 # Printer might be sitting at the "Completed" screen
SATURN_STATUS_BUSY_2 = 1 # post-HEAD call on file transfer, along with SATURN_FILE_STATUS = 3 on error, and 2 on completion
STATUS_NAMES = {
SATURN_STATUS_EXPOSURE: "Exposure",
SATURN_STATUS_RETRACTING: "Retracting",
SATURN_STATUS_LOWERING: "Lowering",
SATURN_STATUS_COMPLETE: "Complete"
# Status field inside PrintInfo
SATURN_PRINT_STATUS_EXPOSURE = 2 # TODO: double check tese
SATURN_PRINT_STATUS_RETRACTING = 3
SATURN_PRINT_STATUS_LOWERING = 4
SATURN_PRINT_STATUS_COMPLETE = 16 # pretty sure this is correct
# Status field inside FileTransferInfo
SATURN_FILE_STATUS_NONE = 0
SATURN_FILE_STATUS_DONE = 2
SATURN_FILE_STATUS_ERROR = 3
SATURN_PRINT_STATUS_NAMES = {
SATURN_PRINT_STATUS_EXPOSURE: "Exposure",
SATURN_PRINT_STATUS_RETRACTING: "Retracting",
SATURN_PRINT_STATUS_LOWERING: "Lowering",
SATURN_PRINT_STATUS_COMPLETE: "Complete"
}
SATURN_CMD_0 = 0 # null data
SATURN_CMD_1 = 1 # null data
SATURN_CMD_SET_MYSTERY_TIME_PERIOD = 512 # "TimePeriod": 5000
SATURN_CMD_DISCONNECT = 64 # Maybe disconnect?
SATURN_CMD_START_PRINTING = 128 # "Filename": "X", "StartLayer": 0
SATURN_CMD_UPLOAD_FILE = 256 # "Check": 0, "CleanCache": 1, "Compress": 0, "FileSize": 3541068, "Filename": "_ResinXP2-ValidationMatrix_v2.goo", "MD5": "205abc8fab0762ad2b0ee1f6b63b1750", "URL": "http://${ipaddr}:58883/f60c0718c8144b0db48b7149d4d85390.goo" },
SATURN_CMD_DISCONNECT = 64 # Maybe disconnect?
SATURN_CMD_SET_MYSTERY_TIME_PERIOD = 512 # "TimePeriod": 5000
def random_hexstr():
return '%032x' % random.getrandbits(128)
class SaturnPrinter:
def __init__(self, addr, desc):
def __init__(self, addr, desc, timeout=5):
self.addr = addr
self.desc = desc
self.timeout = timeout
self.file_transfer_future = None
if desc is not None:
self.set_desc(desc)
else:
self.desc = None
# Class method: UDP broadcast search for all printers
# Broadcast and find all printers, return array of SaturnPrinter objects
def find_printers(timeout=1):
printers = []
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
with sock:
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, timeout)
sock.sendto(b'M99999', ('<broadcast>', SATURN_BROADCAST_PORT))
sock.sendto(b'M99999', ('<broadcast>', SATURN_UDP_PORT))
now = time.time()
while True:
@@ -62,34 +81,213 @@ class SaturnPrinter:
pdata = json.loads(data.decode('utf-8'))
printers.append(SaturnPrinter(addr, pdata))
return printers
# Tell this printer to connect to the given mqtt server
def connect(self, mqtt, http):
# Find a specific printer at the given address, return a SaturnPrinter object
# or None if no response is obtained
def find_printer(addr, timeout=5):
printer = SaturnPrinter(addr, None)
if printer.refresh(timeout):
return printer
return None
# Refresh this SaturnPrinter with latest status
def refresh(self, timeout=5):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
with sock:
sock.settimeout(timeout)
sock.sendto(b'M99999', (self.addr, SATURN_UDP_PORT))
try:
data, addr = sock.recvfrom(1024)
except socket.timeout:
return False
else:
pdata = json.loads(data.decode('utf-8'))
self.set_desc(pdata)
def set_desc(self, desc):
self.desc = desc
self.id = desc['Data']['Attributes']['MainboardID']
self.name = desc['Data']['Attributes']['Name']
self.machine_name = desc['Data']['Attributes']['MachineName']
self.current_status = desc['Data']['Status']['CurrentStatus']
self.busy = self.current_status > 0
# Tell this printer to connect to the specified mqtt and http
# servers, for further control
async def connect(self, mqtt, http):
self.mqtt = mqtt
self.http = http
mainboard = self.desc['Data']['Attributes']['MainboardID']
mqtt.add_handler("/sdcp/saturn/" + mainboard, self.incoming_data)
mqtt.add_handler("/sdcp/response/" + mainboard, self.incoming_data)
# Tell the printer to connect
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
with sock:
sock.sendto(b'M66666 ' + str(mqtt.port).encode('utf-8'), self.addr)
def incoming_data(self, topic, payload):
if topic.startswith("/sdcp/status/"):
self.incoming_status(payload['Data']['Status'])
elif topic.startswith("/sdcp/attributes/"):
# don't think I care about attributes
pass
elif topic.startswith("/sdcp/response/"):
self.incoming_response(payload['Data']['RequestID'], payload['Data']['Cmd'], payload['Data']['Data'])
# wait for the connection
client_id = await asyncio.wait_for(mqtt.client_connection, timeout=self.timeout)
if client_id != self.id:
logging.error(f"Client ID mismatch: {client_id} != {self.id}")
return False
# wait for the client to subscribe to the request topic
topic = await asyncio.wait_for(self.mqtt.client_subscribed, timeout=self.timeout)
logging.debug(f"Client subscribed to {topic}")
await self.send_command_and_wait(SATURN_CMD_0)
await self.send_command_and_wait(SATURN_CMD_1)
await self.send_command_and_wait(SATURN_CMD_SET_MYSTERY_TIME_PERIOD, { 'TimePeriod': 5000 })
return True
async def disconnect(self):
await self.send_command_and_wait(SATURN_CMD_DISCONNECT)
async def upload_file(self, filename, start_printing=False):
try:
await self.upload_file_inner(filename, start_printing)
except Exception as ex:
logging.error(f"Exception during upload: {ex}")
self.file_transfer_future.set_result((-1, -1, filename))
self.file_transfer_future = asyncio.get_running_loop().create_future()
async def upload_file_inner(self, filename, start_printing=False):
# schedule a future that can be used for status, in case this is kicked off as a task
self.file_transfer_future = asyncio.get_running_loop().create_future()
# get base filename and extension
basename = filename.split('\\')[-1].split('/')[-1]
ext = basename.split('.')[-1].lower()
if ext != 'ctb' and ext != 'goo':
logging.warning(f"Unknown file extension: {ext}")
httpname = random_hexstr() + '.' + ext
fileinfo = self.http.register_file_route('/' + httpname, filename)
cmd_data = {
"Check": 0,
"CleanCache": 1,
"Compress": 0,
"FileSize": fileinfo['size'],
"Filename": basename,
"MD5": fileinfo['md5'],
"URL": f"http://${{ipaddr}}:{self.http.port}/{httpname}"
}
await self.send_command_and_wait(SATURN_CMD_UPLOAD_FILE, cmd_data)
# now process status updates from the printer
while True:
reply = await asyncio.wait_for(self.mqtt.next_published_message(), timeout=self.timeout*2)
data = json.loads(reply['payload'])
if reply['topic'] == "/sdcp/response/" + self.id:
logging.warning(f"Got unexpected RESPONSE (no outstanding request), topic: {reply['topic']} data: {data}")
elif reply['topic'] == "/sdcp/status/" + self.id:
self.incoming_status(data['Data']['Status'])
status = data['Data']['Status']
file_info = status['FileTransferInfo']
current_offset = file_info['DownloadOffset']
total_size = file_info['FileTotalSize']
file_name = file_info['Filename']
# We assume that the printer immediately goes into BUSY status after it processes
# the upload command
if status['CurrentStatus'] == SATURN_STATUS_READY:
if file_info['Status'] == SATURN_FILE_STATUS_DONE:
self.file_transfer_future.set_result((total_size, total_size, file_name))
elif file_info['Status'] == SATURN_FILE_STATUS_ERROR:
logging.error("Transfer error!")
self.file_transfer_future.set_result((-1, total_size, file_name))
else:
logging.error(f"Unknown file transfer status code: {file_info['Status']}")
self.file_transfer_future.set_result((-1, total_size, file_name))
break
self.file_transfer_future.set_result((current_offset, total_size, file_name))
self.file_transfer_future = asyncio.get_running_loop().create_future()
elif reply['topic'] == "/sdcp/attributes/" + self.id:
# ignore these
pass
else:
logging.warning(f"Got unknown topic message: {reply['topic']}")
self.file_transfer_future = None
async def send_command_and_wait(self, cmdid, data=None, abort_on_bad_ack=True):
# Send the 0 and 1 messages
req = self.send_command(cmdid, data)
logging.debug(f"Sent command {cmdid} as request {req}")
while True:
reply = await asyncio.wait_for(self.mqtt.next_published_message(), timeout=self.timeout)
data = json.loads(reply['payload'])
if reply['topic'] == "/sdcp/response/" + self.id:
if data['Data']['RequestID'] == req:
logging.debug(f"Got response to {req}")
result = data['Data']['Data']
if abort_on_bad_ack and result['Ack'] != 0:
logging.error(f"Got bad ack in response: {result}")
sys.exit(1)
return result
elif reply['topic'] == "/sdcp/status/" + self.id:
self.incoming_status(data['Data']['Status'])
elif reply['topic'] == "/sdcp/attributes/" + self.id:
# ignore these
pass
else:
logging.warning(f"Got unknown topic message: {reply['topic']}")
async def print_file(self, filename):
cmd_data = {
"Filename": filename,
"StartLayer": 0
}
await self.send_command_and_wait(SATURN_CMD_START_PRINTING, cmd_data)
# process status updates from the printer, enough to know whether printing
# started or failed to start
status_count = 0
while True:
reply = await asyncio.wait_for(self.mqtt.next_published_message(), timeout=self.timeout*2)
data = json.loads(reply['payload'])
if reply['topic'] == "/sdcp/response/" + self.id:
logging.warning(f"Got unexpected RESPONSE (no outstanding request), topic: {reply['topic']} data: {data}")
elif reply['topic'] == "/sdcp/status/" + self.id:
self.incoming_status(data['Data']['Status'])
status_count += 1
status = data['Data']['Status']
print_info = status['PrintInfo']
current_status = status['CurrentStatus']
print_status = print_info['Status']
if current_status == SATURN_STATUS_BUSY and print_status > 0:
return True
logging.debug(status)
logging.debug(print_info)
if status_count >= 5:
logging.warning("Too many status replies without success or failure")
return False
elif reply['topic'] == "/sdcp/attributes/" + self.id:
# ignore these
pass
else:
logging.warning(f"Got unknown topic message: {reply['topic']}")
async def process_responses(self):
while True:
reply = await asyncio.wait_for(self.mqtt.next_published_message(), timeout=self.timeout)
data = json.loads(reply['payload'])
def incoming_status(self, status):
logging.info(f"STATUS: {status}")
logging.debug(f"STATUS: {status}")
def incoming_response(self, id, cmd, data):
logging.info(f"RESPONSE: {id} -- {cmd}: {data}")
logging.debug(f"RESPONSE: {id} -- {cmd}: {data}")
def describe(self):
attrs = self.desc['Data']['Attributes']
@@ -106,19 +304,18 @@ class SaturnPrinter:
def send_command(self, cmdid, data=None):
# generate 16-byte random identifier as a hex string
hexstr = '%032x' % random.getrandbits(128)
hexstr = random_hexstr()
timestamp = int(time.time() * 1000)
mainboard = self.desc['Data']['Attributes']['MainboardID']
cmd_data = {
"Data": {
"Cmd": cmdid,
"Data": data,
"From": 0,
"MainboardID": mainboard,
"MainboardID": self.id,
"RequestID": hexstr,
"TimeStamp": timestamp
},
"Id": self.desc['Id']
}
print("SENDING REQUEST: " + json.dumps(cmd_data))
self.mqtt.outgoing_messages.put_nowait({'topic': '/sdcp/request/' + mainboard, 'payload': json.dumps(cmd_data)})
self.mqtt.publish('/sdcp/request/' + self.id, json.dumps(cmd_data))
return hexstr

View File

@@ -5,12 +5,15 @@
# License: MIT
#
import logging
import asyncio
import os
import hashlib
class SimpleHTTPServer:
def __init__(self, host="127.0.0.1", port=0):
BufferSize = 1024768
def __init__(self, host="0.0.0.0", port=0):
self.host = host
self.port = port
self.server = None
@@ -29,46 +32,70 @@ class SimpleHTTPServer:
self.routes[path] = route
return route
def unregister_file_route(self, path):
del self.routes[path]
async def start(self):
self.server = await asyncio.start_server(self.handle_client, self.host, self.port)
self.port = self.server.sockets[0].getsockname()[1]
logging.debug(f'HTTP Listening on {self.server.sockets[0].getsockname()}')
async def serve_forever(self):
await self.server.serve_forever()
async def handle_client(self, reader, writer):
try:
await self.handle_client_inner(reader, writer)
except Exception as e:
logging.error(f"HTTP Exception handling client: {e}")
async def handle_client_inner(self, reader, writer):
logging.debug(f"HTTP connection from {writer.get_extra_info('peername')}")
data = b''
while True:
data += await reader.read(1024)
if b'\r\n\r\n' in data:
break
logging.debug(f"HTTP request: {data}")
request_line = data.decode().splitlines()[0]
method, path, _ = request_line.split()
if path not in self.routes:
logging.debug(f"HTTP path {path} not found in routes")
logging.debug(self.routes)
writer.write("HTTP/1.1 404 Not Found\r\n".encode())
writer.close()
return
route = self.routes[path]
header = f"HTTP/1.1 200 OK\r\n"
header += f"Content-Type: application/octet-stream\r\n"
header += f"Etag: {route['md5']}\r\n"
header += f"Content-Length: {path['size']}\r\n"
logging.debug(f"HTTP method {method} path {path} route: {route}")
header = f"HTTP/1.1 200 OK\r\n"
#header += f"Content-Type: application/octet-stream\r\n"
header += f"Content-Type: text/plain; charset=utf-8\r\n"
header += f"Etag: {route['md5']}\r\n"
header += f"Content-Length: {route['size']}\r\n"
header += "\r\n"
logging.debug(f"Writing header:\n{header}")
writer.write(header.encode())
if method == "GET":
writer.write(b'\r\n')
total = 0
with open(route['file'], 'rb') as f:
while True:
data = f.read(8192)
data = f.read(self.BufferSize)
if not data:
break
writer.write(data)
logging.debug(f"HTTP wrote {len(data)} bytes")
#await asyncio.sleep(1)
total += len(data)
logging.debug(f"HTTP wrote total {total} bytes")
await writer.drain()
writer.close()
await writer.wait_closed()
logging.debug(f"HTTP connection closed")

View File

@@ -24,41 +24,47 @@ class SimpleMQTTServer:
self.server = None
self.incoming_messages = asyncio.Queue()
self.outgoing_messages = asyncio.Queue()
self.connected_clients = {}
self.next_pack_id_value = 1
self.handlers = {}
def add_handler(self, topic, handler):
if topic not in self.handlers:
self.handlers[topic] = []
self.handlers[topic].append(handler)
async def start(self):
self.server = await asyncio.start_server(self.handle_client, self.host, self.port)
self.port = self.server.sockets[0].getsockname()[1]
logging.debug(f'Listening on {self.server.sockets[0].getsockname()}')
logging.debug(f'MQTT Listening on {self.server.sockets[0].getsockname()}')
async def serve_forever(self):
loop = asyncio.get_event_loop()
self.client_connection = loop.create_future()
self.client_subscribed = loop.create_future()
await self.server.serve_forever()
def publish(self, topic, payload):
self.outgoing_messages.put_nowait({'topic': topic, 'payload': payload})
async def next_published_message(self):
return await self.incoming_messages.get()
async def handle_client(self, reader, writer):
try:
await self.handle_client_inner(reader, writer)
except Exception as e:
logging.error(f"MQTT Exception handling client: {e}")
async def handle_client_inner(self, reader, writer):
addr = writer.get_extra_info('peername')
logging.debug(f'Socket connected from {addr}')
data = b''
subscribed_topics = dict()
client_id = None
read_future = asyncio.ensure_future(reader.read(1024))
outgoing_messages_future = asyncio.ensure_future(self.outgoing_messages.get())
subscribed_topics = dict()
while True:
# get Future representing a reader.read(1024)
# get Future representing a self.incoming_messages.get()
completed, pending = await asyncio.wait([read_future, outgoing_messages_future], return_when=asyncio.FIRST_COMPLETED)
#print(completed)
#print(pending)
if outgoing_messages_future in completed:
#print("Got outgoing message")
outmsg = outgoing_messages_future.result()
topic = outmsg['topic']
payload = outmsg['payload']
@@ -84,8 +90,6 @@ class SimpleMQTTServer:
# must have at least 2 bytes
if len(data) < 2:
break
#print(f"Remaining bytes: {len(data)}")
#print(f"Data: {data}")
msg_type = data[0] >> 4
msg_flags = data[0] & 0xf
@@ -93,7 +97,7 @@ class SimpleMQTTServer:
# TODO -- we could maybe not have enough bytes to decode the length, but assume
# that won't happen
msg_length, len_bytes_consumed = self.decode_length(data[1:])
logging.debug(f" in msg_type: {msg_type} flags: {msg_flags} msg_length {msg_length} bytes_consumed for msg_length {len_bytes_consumed}")
#logging.debug(f"mqtt in msg_type: {msg_type} flags: {msg_flags} msg_length {msg_length} bytes_consumed for msg_length {len_bytes_consumed}")
# is there enough to process the message?
head_len = len_bytes_consumed + 1
@@ -106,17 +110,27 @@ class SimpleMQTTServer:
data = data[head_len+msg_length:]
if msg_type == MQTT_CONNECT:
# ignore the contents of the message, should maybe check for 'MQTT' identifier at least
logging.info(f"Client {addr} connected")
if message[0:6] != b'\x00\x04MQTT':
logging.error(f"MQTT client {addr}: bad CONNECT")
writer.close()
return
client_id_len = struct.unpack("!H", message[10:12])[0]
client_id = message[12:12+client_id_len].decode("utf-8")
logging.debug(f"MQTT client {client_id} at {addr} connected")
self.connected_clients[client_id] = addr
await self.send_msg(writer, MQTT_CONNACK, payload=b'\x00\x00')
self.client_connection.set_result(client_id)
self.client_connection = asyncio.get_event_loop().create_future()
elif msg_type == MQTT_PUBLISH:
qos = (msg_flags >> 1) & 0x3
topic, packid, content = self.parse_publish(message)
logging.info(f"Got DATA on: {topic}")
if topic in self.handlers:
for handler in self.handlers[topic]:
handler(topic, content)
#logging.debug(f"Got DATA on: {topic}")
self.incoming_messages.put_nowait({ 'topic': topic, 'payload': content})
if qos > 0:
await self.send_msg(writer, MQTT_PUBACK, packet_ident=packid)
elif msg_type == MQTT_SUBSCRIBE:
@@ -124,13 +138,19 @@ class SimpleMQTTServer:
packid = message[0] << 8 | message[1]
message = message[2:]
topic = self.parse_subscribe(message)
logging.info(f"Client {addr} subscribed to topic '{topic}', QoS {qos}")
logging.debug(f"Client {addr} subscribed to topic '{topic}', QoS {qos}")
subscribed_topics[topic] = qos
await self.send_msg(writer, MQTT_SUBACK, packet_ident=packid, payload=bytes([qos]))
self.client_subscribed.set_result(topic)
self.client_subscribed = asyncio.get_event_loop().create_future()
elif msg_type == MQTT_DISCONNECT:
logging.info(f"Client {addr} disconnected")
writer.close()
await writer.wait_closed()
if client_id is not None:
del self.connected_clients[client_id]
return
async def send_msg(self, writer, msg_type, flags=0, packet_ident=0, payload=b''):
@@ -142,7 +162,7 @@ class SimpleMQTTServer:
if packet_ident > 0:
head += bytes([packet_ident >> 8, packet_ident & 0xff])
data = head + payload
logging.debug(f" writing {len(data)} bytes: {data}")
#logging.debug(f" writing {len(data)} bytes: {data}")
writer.write(data)
await writer.drain()