diff --git a/.gitignore b/.gitignore index 1d4e738..fa3593e 100644 --- a/.gitignore +++ b/.gitignore @@ -128,4 +128,5 @@ dmypy.json # Pyre type checker .pyre/ -.idea/ \ No newline at end of file +.idea/ +.envrc \ No newline at end of file diff --git a/trio_paho_mqtt/client.py b/trio_paho_mqtt/client.py index 3a00b51..209c2cd 100644 --- a/trio_paho_mqtt/client.py +++ b/trio_paho_mqtt/client.py @@ -1,21 +1,31 @@ #!/usr/bin/env python3 import socket +from typing import Any, AsyncGenerator, Generator, List, Optional import paho.mqtt.client as mqtt import trio +from paho.mqtt.properties import Properties +from paho.mqtt.reasoncodes import ReasonCodes +from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel + +# TODO +Message = Any class AsyncClient: def __init__( - self, sync_client: mqtt.Client, parent_nursery: trio.Nursery, max_buffer=100 - ): + self, + sync_client: mqtt.Client, + parent_nursery: trio.Nursery, + max_buffer: int = 100, + ) -> None: self._client = sync_client self._nursery = parent_nursery - self.socket = self._client.socket() + self.socket = self._client.socket() # type: ignore - self._cancel_scopes = [] + self._cancel_scopes: List[CancelScope] = [] self._event_connect = trio.Event() self._event_large_write = trio.Event() @@ -30,6 +40,8 @@ class AsyncClient: self._client.on_socket_register_write = self._on_socket_register_write self._client.on_socket_unregister_write = self._on_socket_unregister_write + self._msg_send_channel: MemorySendChannel[Message] + self._msg_receive_channel: MemoryReceiveChannel[Message] self._msg_send_channel, self._msg_receive_channel = trio.open_memory_channel( max_buffer ) @@ -47,25 +59,25 @@ class AsyncClient: self.username_pw_set = self._client.username_pw_set self.ws_set_options = self._client.ws_set_options - def _start_all_loop(self): + def _start_all_loop(self) -> "AsyncClient": self._nursery.start_soon(self._loop_read) self._nursery.start_soon(self._loop_write) self._nursery.start_soon(self._loop_misc) return self - def _stop_all_loop(self): + def _stop_all_loop(self) -> None: for cs in self._cancel_scopes: cs.cancel() - async def _loop_misc(self): + async def _loop_misc(self) -> None: cs = trio.CancelScope() self._cancel_scopes.append(cs) await self._event_connect.wait() with cs: - while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: + while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: # type: ignore await trio.sleep(1) - async def _loop_read(self): + async def _loop_read(self) -> None: cs = trio.CancelScope() self._cancel_scopes.append(cs) with cs: @@ -74,7 +86,7 @@ class AsyncClient: await trio.lowlevel.wait_readable(self.socket) self._client.loop_read() - async def _loop_write(self): + async def _loop_write(self) -> None: cs = trio.CancelScope() self._cancel_scopes.append(cs) with cs: @@ -84,15 +96,21 @@ class AsyncClient: self._client.loop_write() def connect( - self, host, port=1883, keepalive=60, bind_address="", bind_port=0, **kwargs - ): + self, + host: str, + port: int = 1883, + keepalive: int = 60, + bind_address: str = "", + bind_port: int = 0, + **kwargs: Any + ) -> None: self._start_all_loop() self._client.connect(host, port, keepalive, bind_address, bind_port, **kwargs) - def _on_connect(self, client, userdata, flags, rc): + def _on_connect(self, client: Any, userdata: Any, flags: Any, rc: Any) -> None: self._event_connect.set() - async def messages(self): + async def messages(self) -> AsyncGenerator[Message, None]: self._event_should_read.set() cs = trio.CancelScope() self._cancel_scopes.append(cs) @@ -102,7 +120,7 @@ class AsyncClient: yield msg self._event_should_read.set() - def _on_message(self, client, userdata, msg): + def _on_message(self, client: Any, userdata: Any, msg: Message) -> None: try: self._msg_send_channel.send_nowait(msg) except trio.WouldBlock: @@ -114,25 +132,31 @@ class AsyncClient: # Stop reading until the messages are read off the mem channel self._event_should_read = trio.Event() - def disconnect(self, reasoncode=None, properties=None): + def disconnect( + self, + reasoncode: Optional[ReasonCodes] = None, + properties: Optional[Properties] = None, + ) -> None: self._client.disconnect(reasoncode, properties) self._stop_all_loop() - def _on_disconnect(self, client, userdata, rc): + def _on_disconnect(self, client: Any, userdata: Any, rc: Any) -> None: self._event_connect = trio.Event() self._stop_all_loop() - def _on_socket_open(self, client, userdata, sock): + def _on_socket_open(self, client: Any, userdata: Any, sock: Any) -> None: self.socket = sock self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) - def _on_socket_close(self, client, userdata, sock): + def _on_socket_close(self, client: Any, userdata: Any, sock: Any) -> None: pass - def _on_socket_register_write(self, client, userdata, sock): + def _on_socket_register_write(self, client: Any, userdata: Any, sock: Any) -> None: # large write request - start write loop self._event_large_write.set() - def _on_socket_unregister_write(self, client, userdata, sock): + def _on_socket_unregister_write( + self, client: Any, userdata: Any, sock: Any + ) -> None: # finished large write - stop write loop self._event_large_write = trio.Event()