Add type annotations

This commit is contained in:
Olivier 'reivilibre' 2022-08-25 12:03:20 +01:00
parent 86d0bfa053
commit 288098ac0a
2 changed files with 47 additions and 22 deletions

3
.gitignore vendored
View File

@ -128,4 +128,5 @@ dmypy.json
# Pyre type checker
.pyre/
.idea/
.idea/
.envrc

View File

@ -1,21 +1,31 @@
#!/usr/bin/env python3
import socket
from typing import Any, AsyncGenerator, Generator, List, Optional
import paho.mqtt.client as mqtt
import trio
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCodes
from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel
# TODO
Message = Any
class AsyncClient:
def __init__(
self, sync_client: mqtt.Client, parent_nursery: trio.Nursery, max_buffer=100
):
self,
sync_client: mqtt.Client,
parent_nursery: trio.Nursery,
max_buffer: int = 100,
) -> None:
self._client = sync_client
self._nursery = parent_nursery
self.socket = self._client.socket()
self.socket = self._client.socket() # type: ignore
self._cancel_scopes = []
self._cancel_scopes: List[CancelScope] = []
self._event_connect = trio.Event()
self._event_large_write = trio.Event()
@ -30,6 +40,8 @@ class AsyncClient:
self._client.on_socket_register_write = self._on_socket_register_write
self._client.on_socket_unregister_write = self._on_socket_unregister_write
self._msg_send_channel: MemorySendChannel[Message]
self._msg_receive_channel: MemoryReceiveChannel[Message]
self._msg_send_channel, self._msg_receive_channel = trio.open_memory_channel(
max_buffer
)
@ -47,25 +59,25 @@ class AsyncClient:
self.username_pw_set = self._client.username_pw_set
self.ws_set_options = self._client.ws_set_options
def _start_all_loop(self):
def _start_all_loop(self) -> "AsyncClient":
self._nursery.start_soon(self._loop_read)
self._nursery.start_soon(self._loop_write)
self._nursery.start_soon(self._loop_misc)
return self
def _stop_all_loop(self):
def _stop_all_loop(self) -> None:
for cs in self._cancel_scopes:
cs.cancel()
async def _loop_misc(self):
async def _loop_misc(self) -> None:
cs = trio.CancelScope()
self._cancel_scopes.append(cs)
await self._event_connect.wait()
with cs:
while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: # type: ignore
await trio.sleep(1)
async def _loop_read(self):
async def _loop_read(self) -> None:
cs = trio.CancelScope()
self._cancel_scopes.append(cs)
with cs:
@ -74,7 +86,7 @@ class AsyncClient:
await trio.lowlevel.wait_readable(self.socket)
self._client.loop_read()
async def _loop_write(self):
async def _loop_write(self) -> None:
cs = trio.CancelScope()
self._cancel_scopes.append(cs)
with cs:
@ -84,15 +96,21 @@ class AsyncClient:
self._client.loop_write()
def connect(
self, host, port=1883, keepalive=60, bind_address="", bind_port=0, **kwargs
):
self,
host: str,
port: int = 1883,
keepalive: int = 60,
bind_address: str = "",
bind_port: int = 0,
**kwargs: Any
) -> None:
self._start_all_loop()
self._client.connect(host, port, keepalive, bind_address, bind_port, **kwargs)
def _on_connect(self, client, userdata, flags, rc):
def _on_connect(self, client: Any, userdata: Any, flags: Any, rc: Any) -> None:
self._event_connect.set()
async def messages(self):
async def messages(self) -> AsyncGenerator[Message, None]:
self._event_should_read.set()
cs = trio.CancelScope()
self._cancel_scopes.append(cs)
@ -102,7 +120,7 @@ class AsyncClient:
yield msg
self._event_should_read.set()
def _on_message(self, client, userdata, msg):
def _on_message(self, client: Any, userdata: Any, msg: Message) -> None:
try:
self._msg_send_channel.send_nowait(msg)
except trio.WouldBlock:
@ -114,25 +132,31 @@ class AsyncClient:
# Stop reading until the messages are read off the mem channel
self._event_should_read = trio.Event()
def disconnect(self, reasoncode=None, properties=None):
def disconnect(
self,
reasoncode: Optional[ReasonCodes] = None,
properties: Optional[Properties] = None,
) -> None:
self._client.disconnect(reasoncode, properties)
self._stop_all_loop()
def _on_disconnect(self, client, userdata, rc):
def _on_disconnect(self, client: Any, userdata: Any, rc: Any) -> None:
self._event_connect = trio.Event()
self._stop_all_loop()
def _on_socket_open(self, client, userdata, sock):
def _on_socket_open(self, client: Any, userdata: Any, sock: Any) -> None:
self.socket = sock
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
def _on_socket_close(self, client, userdata, sock):
def _on_socket_close(self, client: Any, userdata: Any, sock: Any) -> None:
pass
def _on_socket_register_write(self, client, userdata, sock):
def _on_socket_register_write(self, client: Any, userdata: Any, sock: Any) -> None:
# large write request - start write loop
self._event_large_write.set()
def _on_socket_unregister_write(self, client, userdata, sock):
def _on_socket_unregister_write(
self, client: Any, userdata: Any, sock: Any
) -> None:
# finished large write - stop write loop
self._event_large_write = trio.Event()