163 lines
5.7 KiB
Python
163 lines
5.7 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import logging
import socket
from typing import Any, AsyncGenerator, Generator, List, Optional

import paho.mqtt.client as mqtt
import trio
from paho.mqtt.client import MQTTMessage
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCodes
from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel
|
|
|
|
# Alias for paho's MQTTMessage; used below as the element type of the
# in-memory message channel and of AsyncClient.messages().
Message = MQTTMessage
|
|
|
|
|
|
class AsyncClient:
    """Drive a synchronous paho ``mqtt.Client`` from a Trio nursery.

    The wrapped client's ``loop_read`` / ``loop_write`` / ``loop_misc``
    methods are called from three background tasks started in the nursery
    passed to the constructor.  Incoming messages are pushed onto a bounded
    Trio memory channel and handed out through :meth:`messages`.

    Several configuration and pub/sub methods of the wrapped client
    (``subscribe``, ``publish``, ``tls_set``, ...) are re-exported unchanged
    as attributes of the same name.
    """

    def __init__(
        self,
        sync_client: mqtt.Client,
        parent_nursery: trio.Nursery,
        max_buffer: int = 100,
    ) -> None:
        """Wrap *sync_client* and install this object's paho callbacks on it.

        Args:
            sync_client: The paho client to drive.  Its ``on_*`` callbacks
                listed below are overwritten.
            parent_nursery: Nursery in which the background loops are started
                by :meth:`connect`.
            max_buffer: Capacity of the memory channel that buffers received
                messages until :meth:`messages` consumes them.
        """
        self._client = sync_client
        self._nursery = parent_nursery

        # Refreshed in _on_socket_open once paho actually opens a socket.
        self.socket = self._client.socket()  # type: ignore

        # Cancel scopes of the currently running loops / message iterators.
        self._cancel_scopes: List[CancelScope] = []

        # NOTE: these events are "cleared" by rebinding the attribute to a
        # fresh trio.Event, so waiters must re-read the attribute each time.
        self._event_connect = trio.Event()
        self._event_large_write = trio.Event()
        self._event_should_read = trio.Event()
        self._event_should_read.set()

        self._client.on_connect = self._on_connect
        self._client.on_disconnect = self._on_disconnect
        self._client.on_socket_open = self._on_socket_open
        self._client.on_socket_close = self._on_socket_close
        self._client.on_message = self._on_message
        self._client.on_socket_register_write = self._on_socket_register_write
        self._client.on_socket_unregister_write = self._on_socket_unregister_write

        self._msg_send_channel: MemorySendChannel[Message]
        self._msg_receive_channel: MemoryReceiveChannel[Message]
        self._msg_send_channel, self._msg_receive_channel = trio.open_memory_channel(
            max_buffer
        )

        # Straight pass-throughs to the wrapped synchronous client.
        self.subscribe = self._client.subscribe
        self.publish = self._client.publish
        self.unsubscribe = self._client.unsubscribe
        self.will_set = self._client.will_set
        self.will_clear = self._client.will_clear
        self.proxy_set = self._client.proxy_set
        self.tls_set = self._client.tls_set
        self.tls_insecure_set = self._client.tls_insecure_set
        self.tls_set_context = self._client.tls_set_context
        self.user_data_set = self._client.user_data_set
        self.username_pw_set = self._client.username_pw_set
        self.ws_set_options = self._client.ws_set_options

    def _start_all_loop(self) -> "AsyncClient":
        """Schedule the read, write and misc loops in the nursery; return self."""
        self._nursery.start_soon(self._loop_read)
        self._nursery.start_soon(self._loop_write)
        self._nursery.start_soon(self._loop_misc)
        return self

    def _stop_all_loop(self) -> None:
        """Cancel every registered cancel scope and forget them.

        The list is emptied so that repeated connect/disconnect cycles do not
        accumulate already-cancelled scopes.  Calling this again is a no-op.
        """
        scopes, self._cancel_scopes = self._cancel_scopes, []
        for cs in scopes:
            cs.cancel()

    async def _loop_misc(self) -> None:
        """Call the client's ``loop_misc`` once per second after connecting.

        The loop ends when ``loop_misc`` returns anything other than
        ``MQTT_ERR_SUCCESS`` or when the scope is cancelled.
        """
        cs = trio.CancelScope()
        self._cancel_scopes.append(cs)
        with cs:
            # Wait inside the scope: otherwise a task still waiting for the
            # connection could never be cancelled by _stop_all_loop().
            await self._event_connect.wait()
            while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:  # type: ignore
                await trio.sleep(1)

    async def _loop_read(self) -> None:
        """Call ``loop_read`` whenever reading is enabled and the socket is readable."""
        cs = trio.CancelScope()
        self._cancel_scopes.append(cs)
        with cs:
            while True:
                # Re-read the attribute each pass; _on_message may have
                # replaced the event to pause reading.
                await self._event_should_read.wait()
                await trio.lowlevel.wait_readable(self.socket)
                self._client.loop_read()

    async def _loop_write(self) -> None:
        """Call ``loop_write`` while paho has registered interest in writing."""
        cs = trio.CancelScope()
        self._cancel_scopes.append(cs)
        with cs:
            while True:
                await self._event_large_write.wait()
                await trio.lowlevel.wait_writable(self.socket)
                self._client.loop_write()

    def connect(
        self,
        host: str,
        port: int = 1883,
        keepalive: int = 60,
        bind_address: str = "",
        bind_port: int = 0,
        **kwargs: Any
    ) -> None:
        """Connect the wrapped client, then start the background loops.

        All arguments are forwarded to the wrapped client's ``connect``.
        The call is synchronous.  The loops are only scheduled once it has
        returned, so a failed connection attempt raises its own error and
        leaves no loop tasks behind running against a missing socket.
        """
        self._client.connect(host, port, keepalive, bind_address, bind_port, **kwargs)
        self._start_all_loop()

    def _on_connect(
        self,
        client: Any,
        userdata: Any,
        flags: Any,
        rc: Any,
        properties: Any = None,
    ) -> None:
        """paho ``on_connect`` callback: release tasks waiting for the connection.

        ``properties`` is optional so the callback accepts both the 4-argument
        call and the 5-argument call paho makes for MQTTv5 clients.
        """
        self._event_connect.set()

    async def messages(self) -> AsyncGenerator[Message, None]:
        """Yield received messages in arrival order.

        Re-enables socket reading on entry and after each message has been
        handed to the consumer.  Iteration ends when the registered cancel
        scope is cancelled (see ``_stop_all_loop``).
        """
        self._event_should_read.set()
        cs = trio.CancelScope()
        self._cancel_scopes.append(cs)
        # NOTE(review): Trio's documentation advises against yielding inside
        # a cancel scope in an async generator; left as is to keep the
        # current termination behaviour -- worth revisiting.
        with cs:
            while True:
                msg = await self._msg_receive_channel.receive()
                yield msg
                self._event_should_read.set()

    def _on_message(self, client: Any, userdata: Any, msg: Message) -> None:
        """paho ``on_message`` callback: queue *msg* for :meth:`messages`.

        When the channel is full the oldest buffered message is discarded to
        make room, and socket reading is paused until a consumer takes a
        message off the channel.
        """
        try:
            self._msg_send_channel.send_nowait(msg)
        except trio.WouldBlock:
            logging.getLogger(__name__).warning("Buffer full. Discarding an old msg!")
            try:
                # Drop the oldest buffered message to make room.
                # TODO: Store this old msg?
                self._msg_receive_channel.receive_nowait()
            except trio.WouldBlock:
                # Nothing buffered to evict (zero-capacity channel with no
                # waiting consumer): the new message cannot be stored.
                pass
            else:
                self._msg_send_channel.send_nowait(msg)
            # Stop reading until the messages are read off the mem channel
            self._event_should_read = trio.Event()

    def disconnect(
        self,
        reasoncode: Optional[ReasonCodes] = None,
        properties: Optional[Properties] = None,
    ) -> None:
        """Disconnect the wrapped client and cancel all background loops."""
        self._client.disconnect(reasoncode, properties)
        self._stop_all_loop()

    def _on_disconnect(
        self,
        client: Any,
        userdata: Any,
        rc: Any,
        properties: Any = None,
    ) -> None:
        """paho ``on_disconnect`` callback: reset connect state, stop the loops.

        ``properties`` is optional so the callback accepts both the 3-argument
        call and the 4-argument call paho makes for MQTTv5 clients.
        """
        self._event_connect = trio.Event()
        self._stop_all_loop()

    def _on_socket_open(self, client: Any, userdata: Any, sock: Any) -> None:
        """paho ``on_socket_open`` callback: remember the socket the loops wait on."""
        self.socket = sock
        # Request a small (2048-byte) send buffer.  Presumably this makes
        # larger publishes go through the register-write path served by
        # _loop_write -- TODO confirm.
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)

    def _on_socket_close(self, client: Any, userdata: Any, sock: Any) -> None:
        """paho ``on_socket_close`` callback: intentionally does nothing."""
        pass

    def _on_socket_register_write(self, client: Any, userdata: Any, sock: Any) -> None:
        """paho ``on_socket_register_write`` callback: wake the write loop."""
        # large write request - start write loop
        self._event_large_write.set()

    def _on_socket_unregister_write(
        self, client: Any, userdata: Any, sock: Any
    ) -> None:
        """paho ``on_socket_unregister_write`` callback: park the write loop."""
        # finished large write - stop write loop
        self._event_large_write = trio.Event()
|