Type annotations in `synapse.databases.main.devices` (#13025)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
parent
0d1d3e0708
commit
97e9fbe1b2
|
@ -0,0 +1 @@
|
||||||
|
Add type annotations to `synapse.storage.databases.main.devices`.
|
1
mypy.ini
1
mypy.ini
|
@ -27,7 +27,6 @@ exclude = (?x)
|
||||||
^(
|
^(
|
||||||
|synapse/storage/databases/__init__.py
|
|synapse/storage/databases/__init__.py
|
||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/databases/main/devices.py
|
|
||||||
|synapse/storage/schema/
|
|synapse/storage/schema/
|
||||||
|
|
||||||
|tests/api/test_auth.py
|
|tests/api/test_auth.py
|
||||||
|
|
|
@ -19,13 +19,12 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
|
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||||
from synapse.storage.databases.main.devices import DeviceWorkerStore
|
from synapse.storage.databases.main.devices import DeviceWorkerStore
|
||||||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
|
class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: DatabasePool,
|
database: DatabasePool,
|
||||||
|
|
|
@ -195,6 +195,7 @@ class DataStore(
|
||||||
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
||||||
|
|
||||||
def get_device_stream_token(self) -> int:
|
def get_device_stream_token(self) -> int:
|
||||||
|
# TODO: shouldn't this be moved to `DeviceWorkerStore`?
|
||||||
return self._device_list_id_gen.get_current_token()
|
return self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
async def get_users(self) -> List[JsonDict]:
|
async def get_users(self) -> List[JsonDict]:
|
||||||
|
|
|
@ -28,6 +28,8 @@ from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from synapse.api.constants import EduTypes
|
from synapse.api.constants import EduTypes
|
||||||
from synapse.api.errors import Codes, StoreError
|
from synapse.api.errors import Codes, StoreError
|
||||||
from synapse.logging.opentracing import (
|
from synapse.logging.opentracing import (
|
||||||
|
@ -44,6 +46,8 @@ from synapse.storage.database import (
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
make_tuple_comparison_clause,
|
make_tuple_comparison_clause,
|
||||||
)
|
)
|
||||||
|
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
|
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
|
||||||
from synapse.util import json_decoder, json_encoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
|
||||||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
||||||
|
|
||||||
|
|
||||||
class DeviceWorkerStore(SQLBaseStore):
|
class DeviceWorkerStore(EndToEndKeyWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: DatabasePool,
|
database: DatabasePool,
|
||||||
|
@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
device_list_max = self._device_list_id_gen.get_current_token()
|
# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
|
||||||
|
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
|
||||||
|
device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
|
||||||
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
|
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
|
||||||
db_conn,
|
db_conn,
|
||||||
"device_lists_stream",
|
"device_lists_stream",
|
||||||
|
@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
# following this stream later.
|
# following this stream later.
|
||||||
last_processed_stream_id = from_stream_id
|
last_processed_stream_id = from_stream_id
|
||||||
|
|
||||||
query_map = {}
|
# A map of (user ID, device ID) to (stream ID, context).
|
||||||
cross_signing_keys_by_user = {}
|
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
|
||||||
|
cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
|
||||||
for user_id, device_id, update_stream_id, update_context in updates:
|
for user_id, device_id, update_stream_id, update_context in updates:
|
||||||
# Calculate the remaining length budget.
|
# Calculate the remaining length budget.
|
||||||
# Note that, for now, each entry in `cross_signing_keys_by_user`
|
# Note that, for now, each entry in `cross_signing_keys_by_user`
|
||||||
|
@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
txn=txn,
|
txn=txn,
|
||||||
table="device_lists_outbound_last_success",
|
table="device_lists_outbound_last_success",
|
||||||
key_names=("destination", "user_id"),
|
key_names=("destination", "user_id"),
|
||||||
key_values=((destination, user_id) for user_id, _ in rows),
|
key_values=[(destination, user_id) for user_id, _ in rows],
|
||||||
value_names=("stream_id",),
|
value_names=("stream_id",),
|
||||||
value_values=((stream_id,) for _, stream_id in rows),
|
value_values=((stream_id,) for _, stream_id in rows),
|
||||||
)
|
)
|
||||||
|
@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
The new stream ID.
|
The new stream ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async with self._device_list_id_gen.get_next() as stream_id:
|
# TODO: this looks like it's _writing_. Should this be on DeviceStore rather
|
||||||
|
# than DeviceWorkerStore?
|
||||||
|
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_user_sig_change_to_streams",
|
"add_user_sig_change_to_streams",
|
||||||
self._add_user_signature_change_txn,
|
self._add_user_signature_change_txn,
|
||||||
|
@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
} - users_needing_resync
|
} - users_needing_resync
|
||||||
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
||||||
|
|
||||||
results = {}
|
results: Dict[str, Dict[str, JsonDict]] = {}
|
||||||
for user_id, device_id in query_list:
|
for user_id, device_id in query_list:
|
||||||
if user_id not in user_ids_in_cache:
|
if user_id not in user_ids_in_cache:
|
||||||
continue
|
continue
|
||||||
|
@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
def get_cached_device_list_changes(
|
def get_cached_device_list_changes(
|
||||||
self,
|
self,
|
||||||
from_key: int,
|
from_key: int,
|
||||||
) -> Optional[Set[str]]:
|
) -> Optional[List[str]]:
|
||||||
"""Get set of users whose devices have changed since `from_key`, or None
|
"""Get set of users whose devices have changed since `from_key`, or None
|
||||||
if that information is not in our cache.
|
if that information is not in our cache.
|
||||||
"""
|
"""
|
||||||
|
@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
async def get_users_whose_devices_changed(
|
async def get_users_whose_devices_changed(
|
||||||
self,
|
self,
|
||||||
from_key: int,
|
from_key: int,
|
||||||
user_ids: Optional[Iterable[str]] = None,
|
user_ids: Optional[Collection[str]] = None,
|
||||||
to_key: Optional[int] = None,
|
to_key: Optional[int] = None,
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Get set of users whose devices have changed since `from_key` that
|
"""Get set of users whose devices have changed since `from_key` that
|
||||||
|
@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
# Get set of users who *may* have changed. Users not in the returned
|
# Get set of users who *may* have changed. Users not in the returned
|
||||||
# list have definitely not changed.
|
# list have definitely not changed.
|
||||||
|
user_ids_to_check: Optional[Collection[str]]
|
||||||
if user_ids is None:
|
if user_ids is None:
|
||||||
# Get set of all users that have had device list changes since 'from_key'
|
# Get set of all users that have had device list changes since 'from_key'
|
||||||
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
|
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
|
||||||
|
@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
|
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
|
||||||
changes = set()
|
changes: Set[str] = set()
|
||||||
|
|
||||||
stream_id_where_clause = "stream_id > ?"
|
stream_id_where_clause = "stream_id > ?"
|
||||||
sql_args = [from_key]
|
sql_args = [from_key]
|
||||||
|
@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Query device changes with a batch of users at a time
|
# Query device changes with a batch of users at a time
|
||||||
|
# Assertion for mypy's benefit; see also
|
||||||
|
# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
|
||||||
|
assert user_ids_to_check is not None
|
||||||
for chunk in batch_iter(user_ids_to_check, 100):
|
for chunk in batch_iter(user_ids_to_check, 100):
|
||||||
clause, args = make_in_list_sql_clause(
|
clause, args = make_in_list_sql_clause(
|
||||||
txn.database_engine, "user_id", chunk
|
txn.database_engine, "user_id", chunk
|
||||||
|
@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
def _get_all_device_list_changes_for_remotes(txn):
|
def _get_all_device_list_changes_for_remotes(
|
||||||
|
txn: Cursor,
|
||||||
|
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||||
# This query Does The Right Thing where it'll correctly apply the
|
# This query Does The Right Thing where it'll correctly apply the
|
||||||
# bounds to the inner queries.
|
# bounds to the inner queries.
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
desc="get_device_list_last_stream_id_for_remotes",
|
desc="get_device_list_last_stream_id_for_remotes",
|
||||||
)
|
)
|
||||||
|
|
||||||
results = {user_id: None for user_id in user_ids}
|
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
|
||||||
results.update({row["user_id"]: row["stream_id"] for row in rows})
|
results.update({row["user_id"]: row["stream_id"] for row in rows})
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
@ -1337,9 +1352,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
|
|
||||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||||
# the device exists.
|
# the device exists.
|
||||||
self.device_id_exists_cache = LruCache(
|
self.device_id_exists_cache: LruCache[
|
||||||
cache_name="device_id_exists", max_size=10000
|
Tuple[str, str], Literal[True]
|
||||||
)
|
] = LruCache(cache_name="device_id_exists", max_size=10000)
|
||||||
|
|
||||||
async def store_device(
|
async def store_device(
|
||||||
self,
|
self,
|
||||||
|
@ -1651,7 +1666,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
context,
|
context,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._device_list_id_gen.get_next_mult(
|
async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
|
||||||
len(device_ids)
|
len(device_ids)
|
||||||
) as stream_ids:
|
) as stream_ids:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
@ -1704,7 +1719,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
device_ids: Iterable[str],
|
device_ids: Iterable[str],
|
||||||
hosts: Collection[str],
|
hosts: Collection[str],
|
||||||
stream_ids: List[int],
|
stream_ids: List[int],
|
||||||
context: Dict[str, str],
|
context: Optional[Dict[str, str]],
|
||||||
) -> None:
|
) -> None:
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
|
@ -1875,7 +1890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
[],
|
[],
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
|
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"add_device_list_outbound_pokes",
|
"add_device_list_outbound_pokes",
|
||||||
add_device_list_outbound_pokes_txn,
|
add_device_list_outbound_pokes_txn,
|
||||||
|
|
Loading…
Reference in New Issue