Add type annotations

This commit is contained in:
Olivier 'reivilibre' 2022-08-25 12:03:20 +01:00
parent 86d0bfa053
commit 288098ac0a
2 changed files with 47 additions and 22 deletions

3
.gitignore vendored
View File

@ -128,4 +128,5 @@ dmypy.json
# Pyre type checker # Pyre type checker
.pyre/ .pyre/
.idea/ .idea/
.envrc

View File

@ -1,21 +1,31 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import socket import socket
from typing import Any, AsyncGenerator, Generator, List, Optional
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
import trio import trio
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCodes
from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel
# TODO
Message = Any
class AsyncClient: class AsyncClient:
def __init__( def __init__(
self, sync_client: mqtt.Client, parent_nursery: trio.Nursery, max_buffer=100 self,
): sync_client: mqtt.Client,
parent_nursery: trio.Nursery,
max_buffer: int = 100,
) -> None:
self._client = sync_client self._client = sync_client
self._nursery = parent_nursery self._nursery = parent_nursery
self.socket = self._client.socket() self.socket = self._client.socket() # type: ignore
self._cancel_scopes = [] self._cancel_scopes: List[CancelScope] = []
self._event_connect = trio.Event() self._event_connect = trio.Event()
self._event_large_write = trio.Event() self._event_large_write = trio.Event()
@ -30,6 +40,8 @@ class AsyncClient:
self._client.on_socket_register_write = self._on_socket_register_write self._client.on_socket_register_write = self._on_socket_register_write
self._client.on_socket_unregister_write = self._on_socket_unregister_write self._client.on_socket_unregister_write = self._on_socket_unregister_write
self._msg_send_channel: MemorySendChannel[Message]
self._msg_receive_channel: MemoryReceiveChannel[Message]
self._msg_send_channel, self._msg_receive_channel = trio.open_memory_channel( self._msg_send_channel, self._msg_receive_channel = trio.open_memory_channel(
max_buffer max_buffer
) )
@ -47,25 +59,25 @@ class AsyncClient:
self.username_pw_set = self._client.username_pw_set self.username_pw_set = self._client.username_pw_set
self.ws_set_options = self._client.ws_set_options self.ws_set_options = self._client.ws_set_options
def _start_all_loop(self): def _start_all_loop(self) -> "AsyncClient":
self._nursery.start_soon(self._loop_read) self._nursery.start_soon(self._loop_read)
self._nursery.start_soon(self._loop_write) self._nursery.start_soon(self._loop_write)
self._nursery.start_soon(self._loop_misc) self._nursery.start_soon(self._loop_misc)
return self return self
def _stop_all_loop(self): def _stop_all_loop(self) -> None:
for cs in self._cancel_scopes: for cs in self._cancel_scopes:
cs.cancel() cs.cancel()
async def _loop_misc(self): async def _loop_misc(self) -> None:
cs = trio.CancelScope() cs = trio.CancelScope()
self._cancel_scopes.append(cs) self._cancel_scopes.append(cs)
await self._event_connect.wait() await self._event_connect.wait()
with cs: with cs:
while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: # type: ignore
await trio.sleep(1) await trio.sleep(1)
async def _loop_read(self): async def _loop_read(self) -> None:
cs = trio.CancelScope() cs = trio.CancelScope()
self._cancel_scopes.append(cs) self._cancel_scopes.append(cs)
with cs: with cs:
@ -74,7 +86,7 @@ class AsyncClient:
await trio.lowlevel.wait_readable(self.socket) await trio.lowlevel.wait_readable(self.socket)
self._client.loop_read() self._client.loop_read()
async def _loop_write(self): async def _loop_write(self) -> None:
cs = trio.CancelScope() cs = trio.CancelScope()
self._cancel_scopes.append(cs) self._cancel_scopes.append(cs)
with cs: with cs:
@ -84,15 +96,21 @@ class AsyncClient:
self._client.loop_write() self._client.loop_write()
def connect( def connect(
self, host, port=1883, keepalive=60, bind_address="", bind_port=0, **kwargs self,
): host: str,
port: int = 1883,
keepalive: int = 60,
bind_address: str = "",
bind_port: int = 0,
**kwargs: Any
) -> None:
self._start_all_loop() self._start_all_loop()
self._client.connect(host, port, keepalive, bind_address, bind_port, **kwargs) self._client.connect(host, port, keepalive, bind_address, bind_port, **kwargs)
def _on_connect(self, client, userdata, flags, rc): def _on_connect(self, client: Any, userdata: Any, flags: Any, rc: Any) -> None:
self._event_connect.set() self._event_connect.set()
async def messages(self): async def messages(self) -> AsyncGenerator[Message, None]:
self._event_should_read.set() self._event_should_read.set()
cs = trio.CancelScope() cs = trio.CancelScope()
self._cancel_scopes.append(cs) self._cancel_scopes.append(cs)
@ -102,7 +120,7 @@ class AsyncClient:
yield msg yield msg
self._event_should_read.set() self._event_should_read.set()
def _on_message(self, client, userdata, msg): def _on_message(self, client: Any, userdata: Any, msg: Message) -> None:
try: try:
self._msg_send_channel.send_nowait(msg) self._msg_send_channel.send_nowait(msg)
except trio.WouldBlock: except trio.WouldBlock:
@ -114,25 +132,31 @@ class AsyncClient:
# Stop reading until the messages are read off the mem channel # Stop reading until the messages are read off the mem channel
self._event_should_read = trio.Event() self._event_should_read = trio.Event()
def disconnect(self, reasoncode=None, properties=None): def disconnect(
self,
reasoncode: Optional[ReasonCodes] = None,
properties: Optional[Properties] = None,
) -> None:
self._client.disconnect(reasoncode, properties) self._client.disconnect(reasoncode, properties)
self._stop_all_loop() self._stop_all_loop()
def _on_disconnect(self, client, userdata, rc): def _on_disconnect(self, client: Any, userdata: Any, rc: Any) -> None:
self._event_connect = trio.Event() self._event_connect = trio.Event()
self._stop_all_loop() self._stop_all_loop()
def _on_socket_open(self, client, userdata, sock): def _on_socket_open(self, client: Any, userdata: Any, sock: Any) -> None:
self.socket = sock self.socket = sock
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
def _on_socket_close(self, client, userdata, sock): def _on_socket_close(self, client: Any, userdata: Any, sock: Any) -> None:
pass pass
def _on_socket_register_write(self, client, userdata, sock): def _on_socket_register_write(self, client: Any, userdata: Any, sock: Any) -> None:
# large write request - start write loop # large write request - start write loop
self._event_large_write.set() self._event_large_write.set()
def _on_socket_unregister_write(self, client, userdata, sock): def _on_socket_unregister_write(
self, client: Any, userdata: Any, sock: Any
) -> None:
# finished large write - stop write loop # finished large write - stop write loop
self._event_large_write = trio.Event() self._event_large_write = trio.Event()