Type annotations in `synapse.databases.main.devices` (#13025)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
parent
0d1d3e0708
commit
97e9fbe1b2
|
@ -0,0 +1 @@
|
|||
Add type annotations to `synapse.storage.databases.main.devices`.
|
1
mypy.ini
1
mypy.ini
|
@ -27,7 +27,6 @@ exclude = (?x)
|
|||
^(
|
||||
|synapse/storage/databases/__init__.py
|
||||
|synapse/storage/databases/main/cache.py
|
||||
|synapse/storage/databases/main/devices.py
|
||||
|synapse/storage/schema/
|
||||
|
||||
|tests/api/test_auth.py
|
||||
|
|
|
@ -19,13 +19,12 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
|||
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.databases.main.devices import DeviceWorkerStore
|
||||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
|
||||
class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
|
|
|
@ -195,6 +195,7 @@ class DataStore(
|
|||
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
||||
|
||||
def get_device_stream_token(self) -> int:
|
||||
# TODO: shouldn't this be moved to `DeviceWorkerStore`?
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
async def get_users(self) -> List[JsonDict]:
|
||||
|
|
|
@ -28,6 +28,8 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.api.constants import EduTypes
|
||||
from synapse.api.errors import Codes, StoreError
|
||||
from synapse.logging.opentracing import (
|
||||
|
@ -44,6 +46,8 @@ from synapse.storage.database import (
|
|||
LoggingTransaction,
|
||||
make_tuple_comparison_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
|
|||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
||||
|
||||
|
||||
class DeviceWorkerStore(SQLBaseStore):
|
||||
class DeviceWorkerStore(EndToEndKeyWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
|
@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
device_list_max = self._device_list_id_gen.get_current_token()
|
||||
# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
|
||||
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
|
||||
device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
|
||||
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"device_lists_stream",
|
||||
|
@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
# following this stream later.
|
||||
last_processed_stream_id = from_stream_id
|
||||
|
||||
query_map = {}
|
||||
cross_signing_keys_by_user = {}
|
||||
# A map of (user ID, device ID) to (stream ID, context).
|
||||
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
|
||||
cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
|
||||
for user_id, device_id, update_stream_id, update_context in updates:
|
||||
# Calculate the remaining length budget.
|
||||
# Note that, for now, each entry in `cross_signing_keys_by_user`
|
||||
|
@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
txn=txn,
|
||||
table="device_lists_outbound_last_success",
|
||||
key_names=("destination", "user_id"),
|
||||
key_values=((destination, user_id) for user_id, _ in rows),
|
||||
key_values=[(destination, user_id) for user_id, _ in rows],
|
||||
value_names=("stream_id",),
|
||||
value_values=((stream_id,) for _, stream_id in rows),
|
||||
)
|
||||
|
@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
The new stream ID.
|
||||
"""
|
||||
|
||||
async with self._device_list_id_gen.get_next() as stream_id:
|
||||
# TODO: this looks like it's _writing_. Should this be on DeviceStore rather
|
||||
# than DeviceWorkerStore?
|
||||
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
|
||||
await self.db_pool.runInteraction(
|
||||
"add_user_sig_change_to_streams",
|
||||
self._add_user_signature_change_txn,
|
||||
|
@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
} - users_needing_resync
|
||||
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
||||
|
||||
results = {}
|
||||
results: Dict[str, Dict[str, JsonDict]] = {}
|
||||
for user_id, device_id in query_list:
|
||||
if user_id not in user_ids_in_cache:
|
||||
continue
|
||||
|
@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
def get_cached_device_list_changes(
|
||||
self,
|
||||
from_key: int,
|
||||
) -> Optional[Set[str]]:
|
||||
) -> Optional[List[str]]:
|
||||
"""Get set of users whose devices have changed since `from_key`, or None
|
||||
if that information is not in our cache.
|
||||
"""
|
||||
|
@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
async def get_users_whose_devices_changed(
|
||||
self,
|
||||
from_key: int,
|
||||
user_ids: Optional[Iterable[str]] = None,
|
||||
user_ids: Optional[Collection[str]] = None,
|
||||
to_key: Optional[int] = None,
|
||||
) -> Set[str]:
|
||||
"""Get set of users whose devices have changed since `from_key` that
|
||||
|
@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
"""
|
||||
# Get set of users who *may* have changed. Users not in the returned
|
||||
# list have definitely not changed.
|
||||
user_ids_to_check: Optional[Collection[str]]
|
||||
if user_ids is None:
|
||||
# Get set of all users that have had device list changes since 'from_key'
|
||||
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
|
||||
|
@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
return set()
|
||||
|
||||
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
|
||||
changes = set()
|
||||
changes: Set[str] = set()
|
||||
|
||||
stream_id_where_clause = "stream_id > ?"
|
||||
sql_args = [from_key]
|
||||
|
@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
"""
|
||||
|
||||
# Query device changes with a batch of users at a time
|
||||
# Assertion for mypy's benefit; see also
|
||||
# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
|
||||
assert user_ids_to_check is not None
|
||||
for chunk in batch_iter(user_ids_to_check, 100):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "user_id", chunk
|
||||
|
@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def _get_all_device_list_changes_for_remotes(txn):
|
||||
def _get_all_device_list_changes_for_remotes(
|
||||
txn: Cursor,
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
# This query Does The Right Thing where it'll correctly apply the
|
||||
# bounds to the inner queries.
|
||||
sql = """
|
||||
|
@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
desc="get_device_list_last_stream_id_for_remotes",
|
||||
)
|
||||
|
||||
results = {user_id: None for user_id in user_ids}
|
||||
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
|
||||
results.update({row["user_id"]: row["stream_id"] for row in rows})
|
||||
|
||||
return results
|
||||
|
@ -1337,9 +1352,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
|
||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||
# the device exists.
|
||||
self.device_id_exists_cache = LruCache(
|
||||
cache_name="device_id_exists", max_size=10000
|
||||
)
|
||||
self.device_id_exists_cache: LruCache[
|
||||
Tuple[str, str], Literal[True]
|
||||
] = LruCache(cache_name="device_id_exists", max_size=10000)
|
||||
|
||||
async def store_device(
|
||||
self,
|
||||
|
@ -1651,7 +1666,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
context,
|
||||
)
|
||||
|
||||
async with self._device_list_id_gen.get_next_mult(
|
||||
async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
|
||||
len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -1704,7 +1719,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
device_ids: Iterable[str],
|
||||
hosts: Collection[str],
|
||||
stream_ids: List[int],
|
||||
context: Dict[str, str],
|
||||
context: Optional[Dict[str, str]],
|
||||
) -> None:
|
||||
for host in hosts:
|
||||
txn.call_after(
|
||||
|
@ -1875,7 +1890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
[],
|
||||
)
|
||||
|
||||
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
|
||||
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
|
||||
return await self.db_pool.runInteraction(
|
||||
"add_device_list_outbound_pokes",
|
||||
add_device_list_outbound_pokes_txn,
|
||||
|
|
Loading…
Reference in New Issue