Add type annotations
This commit is contained in:
parent
86d0bfa053
commit
288098ac0a
3
.gitignore
vendored
3
.gitignore
vendored
@ -128,4 +128,5 @@ dmypy.json
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
.idea/
|
||||
.idea/
|
||||
.envrc
|
@ -1,21 +1,31 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import socket
|
||||
from typing import Any, AsyncGenerator, Generator, List, Optional
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
import trio
|
||||
from paho.mqtt.properties import Properties
|
||||
from paho.mqtt.reasoncodes import ReasonCodes
|
||||
from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel
|
||||
|
||||
# TODO
|
||||
Message = Any
|
||||
|
||||
|
||||
class AsyncClient:
|
||||
def __init__(
|
||||
self, sync_client: mqtt.Client, parent_nursery: trio.Nursery, max_buffer=100
|
||||
):
|
||||
self,
|
||||
sync_client: mqtt.Client,
|
||||
parent_nursery: trio.Nursery,
|
||||
max_buffer: int = 100,
|
||||
) -> None:
|
||||
self._client = sync_client
|
||||
self._nursery = parent_nursery
|
||||
|
||||
self.socket = self._client.socket()
|
||||
self.socket = self._client.socket() # type: ignore
|
||||
|
||||
self._cancel_scopes = []
|
||||
self._cancel_scopes: List[CancelScope] = []
|
||||
|
||||
self._event_connect = trio.Event()
|
||||
self._event_large_write = trio.Event()
|
||||
@ -30,6 +40,8 @@ class AsyncClient:
|
||||
self._client.on_socket_register_write = self._on_socket_register_write
|
||||
self._client.on_socket_unregister_write = self._on_socket_unregister_write
|
||||
|
||||
self._msg_send_channel: MemorySendChannel[Message]
|
||||
self._msg_receive_channel: MemoryReceiveChannel[Message]
|
||||
self._msg_send_channel, self._msg_receive_channel = trio.open_memory_channel(
|
||||
max_buffer
|
||||
)
|
||||
@ -47,25 +59,25 @@ class AsyncClient:
|
||||
self.username_pw_set = self._client.username_pw_set
|
||||
self.ws_set_options = self._client.ws_set_options
|
||||
|
||||
def _start_all_loop(self):
|
||||
def _start_all_loop(self) -> "AsyncClient":
|
||||
self._nursery.start_soon(self._loop_read)
|
||||
self._nursery.start_soon(self._loop_write)
|
||||
self._nursery.start_soon(self._loop_misc)
|
||||
return self
|
||||
|
||||
def _stop_all_loop(self):
|
||||
def _stop_all_loop(self) -> None:
|
||||
for cs in self._cancel_scopes:
|
||||
cs.cancel()
|
||||
|
||||
async def _loop_misc(self):
|
||||
async def _loop_misc(self) -> None:
|
||||
cs = trio.CancelScope()
|
||||
self._cancel_scopes.append(cs)
|
||||
await self._event_connect.wait()
|
||||
with cs:
|
||||
while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
|
||||
while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: # type: ignore
|
||||
await trio.sleep(1)
|
||||
|
||||
async def _loop_read(self):
|
||||
async def _loop_read(self) -> None:
|
||||
cs = trio.CancelScope()
|
||||
self._cancel_scopes.append(cs)
|
||||
with cs:
|
||||
@ -74,7 +86,7 @@ class AsyncClient:
|
||||
await trio.lowlevel.wait_readable(self.socket)
|
||||
self._client.loop_read()
|
||||
|
||||
async def _loop_write(self):
|
||||
async def _loop_write(self) -> None:
|
||||
cs = trio.CancelScope()
|
||||
self._cancel_scopes.append(cs)
|
||||
with cs:
|
||||
@ -84,15 +96,21 @@ class AsyncClient:
|
||||
self._client.loop_write()
|
||||
|
||||
def connect(
|
||||
self, host, port=1883, keepalive=60, bind_address="", bind_port=0, **kwargs
|
||||
):
|
||||
self,
|
||||
host: str,
|
||||
port: int = 1883,
|
||||
keepalive: int = 60,
|
||||
bind_address: str = "",
|
||||
bind_port: int = 0,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
self._start_all_loop()
|
||||
self._client.connect(host, port, keepalive, bind_address, bind_port, **kwargs)
|
||||
|
||||
def _on_connect(self, client, userdata, flags, rc):
|
||||
def _on_connect(self, client: Any, userdata: Any, flags: Any, rc: Any) -> None:
|
||||
self._event_connect.set()
|
||||
|
||||
async def messages(self):
|
||||
async def messages(self) -> AsyncGenerator[Message, None]:
|
||||
self._event_should_read.set()
|
||||
cs = trio.CancelScope()
|
||||
self._cancel_scopes.append(cs)
|
||||
@ -102,7 +120,7 @@ class AsyncClient:
|
||||
yield msg
|
||||
self._event_should_read.set()
|
||||
|
||||
def _on_message(self, client, userdata, msg):
|
||||
def _on_message(self, client: Any, userdata: Any, msg: Message) -> None:
|
||||
try:
|
||||
self._msg_send_channel.send_nowait(msg)
|
||||
except trio.WouldBlock:
|
||||
@ -114,25 +132,31 @@ class AsyncClient:
|
||||
# Stop reading until the messages are read off the mem channel
|
||||
self._event_should_read = trio.Event()
|
||||
|
||||
def disconnect(self, reasoncode=None, properties=None):
|
||||
def disconnect(
|
||||
self,
|
||||
reasoncode: Optional[ReasonCodes] = None,
|
||||
properties: Optional[Properties] = None,
|
||||
) -> None:
|
||||
self._client.disconnect(reasoncode, properties)
|
||||
self._stop_all_loop()
|
||||
|
||||
def _on_disconnect(self, client, userdata, rc):
|
||||
def _on_disconnect(self, client: Any, userdata: Any, rc: Any) -> None:
|
||||
self._event_connect = trio.Event()
|
||||
self._stop_all_loop()
|
||||
|
||||
def _on_socket_open(self, client, userdata, sock):
|
||||
def _on_socket_open(self, client: Any, userdata: Any, sock: Any) -> None:
|
||||
self.socket = sock
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
|
||||
|
||||
def _on_socket_close(self, client, userdata, sock):
|
||||
def _on_socket_close(self, client: Any, userdata: Any, sock: Any) -> None:
|
||||
pass
|
||||
|
||||
def _on_socket_register_write(self, client, userdata, sock):
|
||||
def _on_socket_register_write(self, client: Any, userdata: Any, sock: Any) -> None:
|
||||
# large write request - start write loop
|
||||
self._event_large_write.set()
|
||||
|
||||
def _on_socket_unregister_write(self, client, userdata, sock):
|
||||
def _on_socket_unregister_write(
|
||||
self, client: Any, userdata: Any, sock: Any
|
||||
) -> None:
|
||||
# finished large write - stop write loop
|
||||
self._event_large_write = trio.Event()
|
||||
|
Loading…
Reference in New Issue
Block a user