mirror of
https://github.com/matrix-org/synapse.git
synced 2025-01-09 17:56:43 +00:00
Add some type hints to datastore. (#12477)
This commit is contained in:
parent
147f098fb4
commit
989fa33096
1
changelog.d/12477.misc
Normal file
1
changelog.d/12477.misc
Normal file
@ -0,0 +1 @@
|
||||
Add some type hints to datastore.
|
@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
@ -106,7 +107,7 @@ class EventContext:
|
||||
incomplete state.
|
||||
"""
|
||||
|
||||
rejected: Union[bool, str] = False
|
||||
rejected: Union[Literal[False], str] = False
|
||||
_state_group: Optional[int] = None
|
||||
state_group_before_event: Optional[int] = None
|
||||
prev_group: Optional[int] = None
|
||||
|
@ -49,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
|
||||
from synapse.storage.engines.postgres import PostgresEngine
|
||||
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
|
||||
from synapse.storage.util.sequence import SequenceGenerator
|
||||
from synapse.types import StateMap, get_domain_from_id
|
||||
from synapse.types import JsonDict, StateMap, get_domain_from_id
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.iterutils import batch_iter, sorted_topologically
|
||||
|
||||
@ -235,7 +235,9 @@ class PersistEventsStore:
|
||||
"""
|
||||
results: List[str] = []
|
||||
|
||||
def _get_events_which_are_prevs_txn(txn, batch):
|
||||
def _get_events_which_are_prevs_txn(
|
||||
txn: LoggingTransaction, batch: Collection[str]
|
||||
) -> None:
|
||||
sql = """
|
||||
SELECT prev_event_id, internal_metadata
|
||||
FROM event_edges
|
||||
@ -285,7 +287,9 @@ class PersistEventsStore:
|
||||
# and their prev events.
|
||||
existing_prevs = set()
|
||||
|
||||
def _get_prevs_before_rejected_txn(txn, batch):
|
||||
def _get_prevs_before_rejected_txn(
|
||||
txn: LoggingTransaction, batch: Collection[str]
|
||||
) -> None:
|
||||
to_recursively_check = batch
|
||||
|
||||
while to_recursively_check:
|
||||
@ -515,7 +519,7 @@ class PersistEventsStore:
|
||||
@classmethod
|
||||
def _add_chain_cover_index(
|
||||
cls,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
db_pool: DatabasePool,
|
||||
event_chain_id_gen: SequenceGenerator,
|
||||
event_to_room_id: Dict[str, str],
|
||||
@ -809,7 +813,7 @@ class PersistEventsStore:
|
||||
|
||||
@staticmethod
|
||||
def _allocate_chain_ids(
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
db_pool: DatabasePool,
|
||||
event_chain_id_gen: SequenceGenerator,
|
||||
event_to_room_id: Dict[str, str],
|
||||
@ -943,7 +947,7 @@ class PersistEventsStore:
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
):
|
||||
) -> None:
|
||||
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
|
||||
|
||||
to_insert = []
|
||||
@ -997,7 +1001,7 @@ class PersistEventsStore:
|
||||
txn: LoggingTransaction,
|
||||
state_delta_by_room: Dict[str, DeltaState],
|
||||
stream_id: int,
|
||||
):
|
||||
) -> None:
|
||||
for room_id, delta_state in state_delta_by_room.items():
|
||||
to_delete = delta_state.to_delete
|
||||
to_insert = delta_state.to_insert
|
||||
@ -1155,7 +1159,7 @@ class PersistEventsStore:
|
||||
txn, room_id, members_changed
|
||||
)
|
||||
|
||||
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
|
||||
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
|
||||
"""Update the room version in the database based off current state
|
||||
events.
|
||||
|
||||
@ -1189,7 +1193,7 @@ class PersistEventsStore:
|
||||
txn: LoggingTransaction,
|
||||
new_forward_extremities: Dict[str, Set[str]],
|
||||
max_stream_order: int,
|
||||
):
|
||||
) -> None:
|
||||
for room_id in new_forward_extremities.keys():
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
|
||||
@ -1254,9 +1258,9 @@ class PersistEventsStore:
|
||||
|
||||
def _update_room_depths_txn(
|
||||
self,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
):
|
||||
) -> None:
|
||||
"""Update min_depth for each room
|
||||
|
||||
Args:
|
||||
@ -1385,7 +1389,7 @@ class PersistEventsStore:
|
||||
# nothing to do here
|
||||
return
|
||||
|
||||
def event_dict(event):
|
||||
def event_dict(event: EventBase) -> JsonDict:
|
||||
d = event.get_dict()
|
||||
d.pop("redacted", None)
|
||||
d.pop("redacted_because", None)
|
||||
@ -1476,18 +1480,20 @@ class PersistEventsStore:
|
||||
),
|
||||
)
|
||||
|
||||
def _store_rejected_events_txn(self, txn, events_and_contexts):
|
||||
def _store_rejected_events_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
) -> List[Tuple[EventBase, EventContext]]:
|
||||
"""Add rows to the 'rejections' table for received events which were
|
||||
rejected
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
txn: db connection
|
||||
events_and_contexts: events we are persisting
|
||||
|
||||
Returns:
|
||||
list[(EventBase, EventContext)] new list, without the rejected
|
||||
events.
|
||||
new list, without the rejected events.
|
||||
"""
|
||||
# Remove the rejected events from the list now that we've added them
|
||||
# to the events table and the events_json table.
|
||||
@ -1508,7 +1514,7 @@ class PersistEventsStore:
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
all_events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
inhibit_local_membership_updates: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Update all the miscellaneous tables for new events
|
||||
|
||||
Args:
|
||||
@ -1602,7 +1608,11 @@ class PersistEventsStore:
|
||||
# Prefill the event cache
|
||||
self._add_to_cache(txn, events_and_contexts)
|
||||
|
||||
def _add_to_cache(self, txn, events_and_contexts):
|
||||
def _add_to_cache(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
) -> None:
|
||||
to_prefill = []
|
||||
|
||||
rows = []
|
||||
@ -1633,7 +1643,7 @@ class PersistEventsStore:
|
||||
if not row["rejects"] and not row["redacts"]:
|
||||
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
|
||||
|
||||
def prefill():
|
||||
def prefill() -> None:
|
||||
for cache_entry in to_prefill:
|
||||
self.store._get_event_cache.set(
|
||||
(cache_entry.event.event_id,), cache_entry
|
||||
@ -1663,19 +1673,24 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
def insert_labels_for_event_txn(
|
||||
self, txn, event_id, labels, room_id, topological_ordering
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
event_id: str,
|
||||
labels: List[str],
|
||||
room_id: str,
|
||||
topological_ordering: int,
|
||||
) -> None:
|
||||
"""Store the mapping between an event's ID and its labels, with one row per
|
||||
(event_id, label) tuple.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The transaction to execute.
|
||||
event_id (str): The event's ID.
|
||||
labels (list[str]): A list of text labels.
|
||||
room_id (str): The ID of the room the event was sent to.
|
||||
topological_ordering (int): The position of the event in the room's topology.
|
||||
txn: The transaction to execute.
|
||||
event_id: The event's ID.
|
||||
labels: A list of text labels.
|
||||
room_id: The ID of the room the event was sent to.
|
||||
topological_ordering: The position of the event in the room's topology.
|
||||
"""
|
||||
return self.db_pool.simple_insert_many_txn(
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn=txn,
|
||||
table="event_labels",
|
||||
keys=("event_id", "label", "room_id", "topological_ordering"),
|
||||
@ -1684,25 +1699,32 @@ class PersistEventsStore:
|
||||
],
|
||||
)
|
||||
|
||||
def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
|
||||
def _insert_event_expiry_txn(
|
||||
self, txn: LoggingTransaction, event_id: str, expiry_ts: int
|
||||
) -> None:
|
||||
"""Save the expiry timestamp associated with a given event ID.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The database transaction to use.
|
||||
event_id (str): The event ID the expiry timestamp is associated with.
|
||||
expiry_ts (int): The timestamp at which to expire (delete) the event.
|
||||
txn: The database transaction to use.
|
||||
event_id: The event ID the expiry timestamp is associated with.
|
||||
expiry_ts: The timestamp at which to expire (delete) the event.
|
||||
"""
|
||||
return self.db_pool.simple_insert_txn(
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn=txn,
|
||||
table="event_expiry",
|
||||
values={"event_id": event_id, "expiry_ts": expiry_ts},
|
||||
)
|
||||
|
||||
def _store_room_members_txn(
|
||||
self, txn, events, *, inhibit_local_membership_updates: bool = False
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events: List[EventBase],
|
||||
*,
|
||||
inhibit_local_membership_updates: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Store a room member in the database.
|
||||
|
||||
Args:
|
||||
txn: The transaction to use.
|
||||
events: List of events to store.
|
||||
@ -1742,6 +1764,7 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
for event in events:
|
||||
assert event.internal_metadata.stream_ordering is not None
|
||||
txn.call_after(
|
||||
self.store._membership_stream_cache.entity_has_changed,
|
||||
event.state_key,
|
||||
@ -1838,7 +1861,9 @@ class PersistEventsStore:
|
||||
(parent_id, event.sender),
|
||||
)
|
||||
|
||||
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _handle_insertion_event(
|
||||
self, txn: LoggingTransaction, event: EventBase
|
||||
) -> None:
|
||||
"""Handles keeping track of insertion events and edges/connections.
|
||||
Part of MSC2716.
|
||||
|
||||
@ -1899,7 +1924,7 @@ class PersistEventsStore:
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||
"""Handles inserting the batch edges/connections between the batch event
|
||||
and an insertion event. Part of MSC2716.
|
||||
|
||||
@ -1999,25 +2024,29 @@ class PersistEventsStore:
|
||||
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
|
||||
)
|
||||
|
||||
def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||
if isinstance(event.content.get("topic"), str):
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.topic", event.content["topic"]
|
||||
)
|
||||
|
||||
def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||
if isinstance(event.content.get("name"), str):
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.name", event.content["name"]
|
||||
)
|
||||
|
||||
def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _store_room_message_txn(
|
||||
self, txn: LoggingTransaction, event: EventBase
|
||||
) -> None:
|
||||
if isinstance(event.content.get("body"), str):
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.body", event.content["body"]
|
||||
)
|
||||
|
||||
def _store_retention_policy_for_room_txn(self, txn, event):
|
||||
def _store_retention_policy_for_room_txn(
|
||||
self, txn: LoggingTransaction, event: EventBase
|
||||
) -> None:
|
||||
if not event.is_state():
|
||||
logger.debug("Ignoring non-state m.room.retention event")
|
||||
return
|
||||
@ -2077,8 +2106,11 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
def _set_push_actions_for_event_and_users_txn(
|
||||
self, txn, events_and_contexts, all_events_and_contexts
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
all_events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
) -> None:
|
||||
"""Handles moving push actions from staging table to main
|
||||
event_push_actions table for all events in `events_and_contexts`.
|
||||
|
||||
@ -2086,12 +2118,10 @@ class PersistEventsStore:
|
||||
from the push action staging area.
|
||||
|
||||
Args:
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
all_events_and_contexts (list[(EventBase, EventContext)]): all
|
||||
events that we were going to persist. This includes events
|
||||
we've already persisted, etc, that wouldn't appear in
|
||||
events_and_context.
|
||||
events_and_contexts: events we are persisting
|
||||
all_events_and_contexts: all events that we were going to persist.
|
||||
This includes events we've already persisted, etc, that wouldn't
|
||||
appear in events_and_context.
|
||||
"""
|
||||
|
||||
# Only non outlier events will have push actions associated with them,
|
||||
@ -2160,7 +2190,9 @@ class PersistEventsStore:
|
||||
),
|
||||
)
|
||||
|
||||
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
|
||||
def _remove_push_actions_for_event_id_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, event_id: str
|
||||
) -> None:
|
||||
# Sad that we have to blow away the cache for the whole room here
|
||||
txn.call_after(
|
||||
self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
|
||||
@ -2171,7 +2203,9 @@ class PersistEventsStore:
|
||||
(room_id, event_id),
|
||||
)
|
||||
|
||||
def _store_rejections_txn(self, txn, event_id, reason):
|
||||
def _store_rejections_txn(
|
||||
self, txn: LoggingTransaction, event_id: str, reason: str
|
||||
) -> None:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="rejections",
|
||||
@ -2183,8 +2217,10 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
def _store_event_state_mappings_txn(
|
||||
self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: Collection[Tuple[EventBase, EventContext]],
|
||||
) -> None:
|
||||
state_groups = {}
|
||||
for event, context in events_and_contexts:
|
||||
if event.internal_metadata.is_outlier():
|
||||
@ -2241,7 +2277,9 @@ class PersistEventsStore:
|
||||
state_group_id,
|
||||
)
|
||||
|
||||
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
|
||||
def _update_min_depth_for_room_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, depth: int
|
||||
) -> None:
|
||||
min_depth = self.store._get_min_depth_interaction(txn, room_id)
|
||||
|
||||
if min_depth is not None and depth >= min_depth:
|
||||
@ -2254,7 +2292,9 @@ class PersistEventsStore:
|
||||
values={"min_depth": depth},
|
||||
)
|
||||
|
||||
def _handle_mult_prev_events(self, txn, events):
|
||||
def _handle_mult_prev_events(
|
||||
self, txn: LoggingTransaction, events: List[EventBase]
|
||||
) -> None:
|
||||
"""
|
||||
For the given event, update the event edges table and forward and
|
||||
backward extremities tables.
|
||||
@ -2272,7 +2312,9 @@ class PersistEventsStore:
|
||||
|
||||
self._update_backward_extremeties(txn, events)
|
||||
|
||||
def _update_backward_extremeties(self, txn, events):
|
||||
def _update_backward_extremeties(
|
||||
self, txn: LoggingTransaction, events: List[EventBase]
|
||||
) -> None:
|
||||
"""Updates the event_backward_extremities tables based on the new/updated
|
||||
events being persisted.
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
@ -27,7 +27,7 @@ from synapse.storage.database import (
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
|
||||
)
|
||||
|
||||
async def _background_reindex_search(self, progress, batch_size):
|
||||
async def _background_reindex_search(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
# we work through the events table from highest stream id to lowest
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
|
||||
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
|
||||
|
||||
def reindex_search_txn(txn):
|
||||
def reindex_search_txn(txn: LoggingTransaction) -> int:
|
||||
sql = (
|
||||
"SELECT stream_ordering, event_id, room_id, type, json, "
|
||||
" origin_server_ts FROM events"
|
||||
@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
|
||||
return result
|
||||
|
||||
async def _background_reindex_gin_search(self, progress, batch_size):
|
||||
async def _background_reindex_gin_search(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""This handles old synapses which used GIST indexes, if any;
|
||||
converting them back to be GIN as per the actual schema.
|
||||
"""
|
||||
|
||||
def create_index(conn):
|
||||
def create_index(conn: LoggingDatabaseConnection) -> None:
|
||||
conn.rollback()
|
||||
|
||||
# we have to set autocommit, because postgres refuses to
|
||||
@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
)
|
||||
return 1
|
||||
|
||||
async def _background_reindex_search_order(self, progress, batch_size):
|
||||
async def _background_reindex_search_order(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
|
||||
if not have_added_index:
|
||||
|
||||
def create_index(conn):
|
||||
def create_index(conn: LoggingDatabaseConnection) -> None:
|
||||
conn.rollback()
|
||||
conn.set_session(autocommit=True)
|
||||
c = conn.cursor()
|
||||
@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
pg,
|
||||
)
|
||||
|
||||
def reindex_search_txn(txn):
|
||||
def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
|
||||
sql = (
|
||||
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
|
||||
" origin_server_ts = e.origin_server_ts"
|
||||
@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
else:
|
||||
raise Exception("Unrecognized database engine")
|
||||
|
||||
args.append(limit)
|
||||
# mypy expects to append only a `str`, not an `int`
|
||||
args.append(limit) # type: ignore[arg-type]
|
||||
|
||||
results = await self.db_pool.execute(
|
||||
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
|
||||
@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
A set of strings.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> Set[str]:
|
||||
highlight_words = set()
|
||||
for event in events:
|
||||
# As a hack we simply join values of all possible keys. This is
|
||||
@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
return await self.db_pool.runInteraction("_find_highlights", f)
|
||||
|
||||
|
||||
def _to_postgres_options(options_dict):
|
||||
def _to_postgres_options(options_dict: JsonDict) -> str:
|
||||
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
|
||||
|
||||
|
||||
def _parse_query(database_engine, search_term):
|
||||
def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
|
||||
"""Takes a plain unicode string from the user and converts it into a form
|
||||
that can be passed to database.
|
||||
We use this so that we can add prefix matching, which isn't something
|
||||
|
Loading…
Reference in New Issue
Block a user