Add some type hints to `event_federation` datastore (#12753)

Co-authored-by: David Robertson <david.m.robertson1@gmail.com>
This commit is contained in:
Dirk Klimpel 2022-05-18 17:02:10 +02:00 committed by GitHub
parent 682431efbe
commit 50ae4eafe1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 127 additions and 65 deletions

1
changelog.d/12753.misc Normal file
View File

@ -0,0 +1 @@
Add some type hints to datastore.

View File

@ -27,7 +27,6 @@ exclude = (?x)
|synapse/storage/databases/__init__.py |synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py |synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/schema/ |synapse/storage/schema/
|tests/api/test_auth.py |tests/api/test_auth.py

View File

@ -53,6 +53,7 @@ class RoomBatchHandler:
# We want to use the successor event depth so they appear after `prev_event` because # We want to use the successor event depth so they appear after `prev_event` because
# it has a larger `depth` but before the successor event because the `stream_ordering` # it has a larger `depth` but before the successor event because the `stream_ordering`
# is negative before the successor event. # is negative before the successor event.
assert most_recent_prev_event_id is not None
successor_event_ids = await self.store.get_successor_events( successor_event_ids = await self.store.get_successor_events(
most_recent_prev_event_id most_recent_prev_event_id
) )
@ -139,6 +140,7 @@ class RoomBatchHandler:
_, _,
) = await self.store.get_max_depth_of(event_ids) ) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id # mapping from (type, state_key) -> state_event_id
assert most_recent_event_id is not None
prev_state_map = await self.state_store.get_state_ids_for_event( prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_event_id most_recent_event_id
) )

View File

@ -14,7 +14,17 @@
import itertools import itertools
import logging import logging
from queue import Empty, PriorityQueue from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)
import attr import attr
from prometheus_client import Counter, Gauge from prometheus_client import Counter, Gauge
@ -33,7 +43,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover # Check if we have indexed the room so we can use the chain cover
# algorithm. # algorithm.
room = await self.get_room(room_id) room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]: if room["has_auth_chain_index"]:
try: try:
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) )
def _get_auth_chain_ids_using_cover_index_txn( def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool self,
txn: LoggingTransaction,
room_id: str,
event_ids: Collection[str],
include_given: bool,
) -> Set[str]: ) -> Set[str]:
"""Calculates the auth chain IDs using the chain index.""" """Calculates the auth chain IDs using the chain index."""
@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {} chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains. # Add all linked chains reachable from initial set of chains.
for batch in batch_iter(event_chains, 1000): for batch2 in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch txn.database_engine, "origin_chain_id", batch2
) )
txn.execute(sql % (clause,), args) txn.execute(sql % (clause,), args)
@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
front = set(event_ids) front = set(event_ids)
while front: while front:
new_front = set() new_front: Set[str] = set()
for chunk in batch_iter(front, 100): for chunk in batch_iter(front, 100):
# Pull the auth events either from the cache or DB. # Pull the auth events either from the cache or DB.
to_fetch: List[str] = [] # Event IDs to fetch from DB to_fetch: List[str] = [] # Event IDs to fetch from DB
@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Note we need to batch up the results by event ID before # Note we need to batch up the results by event ID before
# adding to the cache. # adding to the cache.
to_cache = {} to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn: for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append( to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth) (auth_event_id, auth_event_depth)
@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover # Check if we have indexed the room so we can use the chain cover
# algorithm. # algorithm.
room = await self.get_room(room_id) room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]: if room["has_auth_chain_index"]:
try: try:
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) )
def _get_auth_chain_difference_using_cover_index_txn( def _get_auth_chain_difference_using_cover_index_txn(
self, txn: Cursor, room_id: str, state_sets: List[Set[str]] self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
) -> Set[str]: ) -> Set[str]:
"""Calculates the auth chain difference using the chain index. """Calculates the auth chain difference using the chain index.
@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# (We need to take a copy of `seen_chains` as we want to mutate it in # (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop) # the loop)
for batch in batch_iter(set(seen_chains), 1000): for batch2 in batch_iter(set(seen_chains), 1000):
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch txn.database_engine, "origin_chain_id", batch2
) )
txn.execute(sql % (clause,), args) txn.execute(sql % (clause,), args)
@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result return result
def _get_auth_chain_difference_txn( def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]] self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]: ) -> Set[str]:
"""Calculates the auth chain difference using a breadth first search. """Calculates the auth chain difference using a breadth first search.
@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# I think building a temporary list with fetchall is more efficient than # I think building a temporary list with fetchall is more efficient than
# just `search.extend(txn)`, but this is unconfirmed # just `search.extend(txn)`, but this is unconfirmed
search.extend(txn.fetchall()) search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
# sort by depth # sort by depth
search.sort() search.sort()
@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# We parse the results and add the to the `found` set and the # We parse the results and add the to the `found` set and the
# cache (note we need to batch up the results by event ID before # cache (note we need to batch up the results by event ID before
# adding to the cache). # adding to the cache).
to_cache = {} to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn: for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append( to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth) (auth_event_id, auth_event_depth)
@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return {eid for eid, n in event_to_missing_sets.items() if n} return {eid for eid, n in event_to_missing_sets.items() if n}
async def get_oldest_event_ids_with_depth_in_room( async def get_oldest_event_ids_with_depth_in_room(
self, room_id self, room_id: str
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
"""Gets the oldest events(backwards extremities) in the room along with the """Gets the oldest events(backwards extremities) in the room along with the
aproximate depth. aproximate depth.
@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples List of (event_id, depth) tuples
""" """
def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): def get_oldest_event_ids_with_depth_in_room_txn(
txn: LoggingTransaction, room_id: str
) -> List[Tuple[str, int]]:
# Assemble a dictionary with event_id -> depth for the oldest events # Assemble a dictionary with event_id -> depth for the oldest events
# we know of in the room. Backwards extremeties are the oldest # we know of in the room. Backwards extremeties are the oldest
# events we know of in the room but we only know of them because # events we know of in the room but we only know of them because
@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id, False)) txn.execute(sql, (room_id, False))
return txn.fetchall() return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_oldest_event_ids_with_depth_in_room", "get_oldest_event_ids_with_depth_in_room",
@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) )
async def get_insertion_event_backward_extremities_in_room( async def get_insertion_event_backward_extremities_in_room(
self, room_id self, room_id: str
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
"""Get the insertion events we know about that we haven't backfilled yet. """Get the insertion events we know about that we haven't backfilled yet.
@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples List of (event_id, depth) tuples
""" """
def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): def get_insertion_event_backward_extremities_in_room_txn(
txn: LoggingTransaction, room_id: str
) -> List[Tuple[str, int]]:
sql = """ sql = """
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
/* We only want insertion events that are also marked as backwards extremities */ /* We only want insertion events that are also marked as backwards extremities */
@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
""" """
txn.execute(sql, (room_id,)) txn.execute(sql, (room_id,))
return txn.fetchall() return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_insertion_event_backward_extremities_in_room", "get_insertion_event_backward_extremities_in_room",
@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id, room_id,
) )
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs """Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args: Args:
@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return max_depth_event_id, current_max_depth return max_depth_event_id, current_max_depth
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs """Returns the event ID and depth for the event that has the min depth from a set of event IDs
Args: Args:
@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
) )
def _get_prev_events_for_room_txn(self, txn, room_id: str): def _get_prev_events_for_room_txn(
self, txn: LoggingTransaction, room_id: str
) -> List[str]:
# we just use the 10 newest events. Older events will become # we just use the 10 newest events. Older events will become
# prev_events of future events. # prev_events of future events.
@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
sorted by extremity count. sorted by extremity count.
""" """
def _get_rooms_with_many_extremities_txn(txn): def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
where_clause = "1=1" where_clause = "1=1"
if room_id_filter: if room_id_filter:
where_clause = "room_id NOT IN (%s)" % ( where_clause = "room_id NOT IN (%s)" % (
@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_min_depth", self._get_min_depth_interaction, room_id "get_min_depth", self._get_min_depth_interaction, room_id
) )
def _get_min_depth_interaction(self, txn, room_id): def _get_min_depth_interaction(
self, txn: LoggingTransaction, room_id: str
) -> Optional[int]:
min_depth = self.db_pool.simple_select_one_onecol_txn( min_depth = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="room_depth", table="room_depth",
@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
""" """
# We want to make the cache more effective, so we clamp to the last # We want to make the cache more effective, so we clamp to the last
# change before the given ordering. # change before the given ordering.
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
# We don't always have a full stream_to_exterm_id table, e.g. after # We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a # the upgrade that introduced it, so we make sure we never ask for a
# stream_ordering from before a restart # stream_ordering from before a restart
last_change = max(self._stream_order_on_start, last_change) last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined]
# provided the last_change is recent enough, we now clamp the requested # provided the last_change is recent enough, we now clamp the requested
# stream_ordering to it. # stream_ordering to it.
if last_change > self.stream_ordering_month_ago: if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
stream_ordering = min(last_change, stream_ordering) stream_ordering = min(last_change, stream_ordering)
return await self._get_forward_extremeties_for_room(room_id, stream_ordering) return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2) @cached(max_entries=5000, num_args=2)
async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): async def _get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
) -> List[str]:
"""For a given room_id and stream_ordering, return the forward """For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time". extremeties of the room at that point in "time".
@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
stream_orderings from that point. stream_orderings from that point.
""" """
if stream_ordering <= self.stream_ordering_month_ago: if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """ sql = """
@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
WHERE room_id = ? WHERE room_id = ?
""" """
def get_forward_extremeties_for_room_txn(txn): def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (stream_ordering, room_id)) txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn] return [event_id for event_id, in txn]
@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
] ]
async def get_backfill_events( async def get_backfill_events(
self, room_id: str, seed_event_id_list: list, limit: int self, room_id: str, seed_event_id_list: List[str], limit: int
): ) -> List[EventBase]:
"""Get a list of Events for a given topic that occurred before (and """Get a list of Events for a given topic that occurred before (and
including) the events in seed_event_id_list. Return a list of max size `limit` including) the events in seed_event_id_list. Return a list of max size `limit`
@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) )
events = await self.get_events_as_list(event_ids) events = await self.get_events_as_list(event_ids)
return sorted( return sorted(
events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering) # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
# But it's never None, because these events were previously persisted to the DB.
events,
key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator]
) )
def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): def _get_backfill_events(
self,
txn: LoggingTransaction,
room_id: str,
seed_event_id_list: List[str],
limit: int,
) -> Set[str]:
""" """
We want to make sure that we do a breadth-first, "depth" ordered search. We want to make sure that we do a breadth-first, "depth" ordered search.
We also handle navigating historical branches of history connected by We also handle navigating historical branches of history connected by
@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
limit, limit,
) )
event_id_results = set() event_id_results: Set[str] = set()
# In a PriorityQueue, the lowest valued entries are retrieved first. # In a PriorityQueue, the lowest valued entries are retrieved first.
# We're using depth as the priority in the queue and tie-break based on # We're using depth as the priority in the queue and tie-break based on
@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# highest and newest-in-time message. We add events to the queue with a # highest and newest-in-time message. We add events to the queue with a
# negative depth so that we process the newest-in-time messages first # negative depth so that we process the newest-in-time messages first
# going backwards in time. stream_ordering follows the same pattern. # going backwards in time. stream_ordering follows the same pattern.
queue = PriorityQueue() queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
for seed_event_id in seed_event_id_list: for seed_event_id in seed_event_id_list:
event_lookup_result = self.db_pool.simple_select_one_txn( event_lookup_result = self.db_pool.simple_select_one_txn(
@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return event_id_results return event_id_results
async def get_missing_events(self, room_id, earliest_events, latest_events, limit): async def get_missing_events(
self,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[EventBase]:
ids = await self.db_pool.runInteraction( ids = await self.db_pool.runInteraction(
"get_missing_events", "get_missing_events",
self._get_missing_events, self._get_missing_events,
@ -1264,11 +1303,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) )
return await self.get_events_as_list(ids) return await self.get_events_as_list(ids)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): def _get_missing_events(
self,
txn: LoggingTransaction,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[str]:
seen_events = set(earliest_events) seen_events = set(earliest_events)
front = set(latest_events) - seen_events front = set(latest_events) - seen_events
event_results = [] event_results: List[str] = []
query = ( query = (
"SELECT prev_event_id FROM event_edges " "SELECT prev_event_id FROM event_edges "
@ -1311,7 +1357,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@wrap_as_background_process("delete_old_forward_extrem_cache") @wrap_as_background_process("delete_old_forward_extrem_cache")
async def _delete_old_forward_extrem_cache(self) -> None: async def _delete_old_forward_extrem_cache(self) -> None:
def _delete_old_forward_extrem_cache_txn(txn): def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
# Delete entries older than a month, while making sure we don't delete # Delete entries older than a month, while making sure we don't delete
# the only entries for a room. # the only entries for a room.
sql = """ sql = """
@ -1324,7 +1370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) AND stream_ordering < ? ) AND stream_ordering < ?
""" """
txn.execute( txn.execute(
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined]
) )
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@ -1382,7 +1428,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
""" """
if self.db_pool.engine.supports_returning: if self.db_pool.engine.supports_returning:
def _remove_received_event_from_staging_txn(txn): def _remove_received_event_from_staging_txn(
txn: LoggingTransaction,
) -> Optional[int]:
sql = """ sql = """
DELETE FROM federation_inbound_events_staging DELETE FROM federation_inbound_events_staging
WHERE origin = ? AND event_id = ? WHERE origin = ? AND event_id = ?
@ -1390,21 +1438,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
""" """
txn.execute(sql, (origin, event_id)) txn.execute(sql, (origin, event_id))
return txn.fetchone() row = cast(Optional[Tuple[int]], txn.fetchone())
row = await self.db_pool.runInteraction(
"remove_received_event_from_staging",
_remove_received_event_from_staging_txn,
db_autocommit=True,
)
if row is None: if row is None:
return None return None
return row[0] return row[0]
return await self.db_pool.runInteraction(
"remove_received_event_from_staging",
_remove_received_event_from_staging_txn,
db_autocommit=True,
)
else: else:
def _remove_received_event_from_staging_txn(txn): def _remove_received_event_from_staging_txn(
txn: LoggingTransaction,
) -> Optional[int]:
received_ts = self.db_pool.simple_select_one_onecol_txn( received_ts = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="federation_inbound_events_staging", table="federation_inbound_events_staging",
@ -1437,7 +1488,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, str]]: ) -> Optional[Tuple[str, str]]:
"""Get the next event ID in the staging area for the given room.""" """Get the next event ID in the staging area for the given room."""
def _get_next_staged_event_id_for_room_txn(txn): def _get_next_staged_event_id_for_room_txn(
txn: LoggingTransaction,
) -> Optional[Tuple[str, str]]:
sql = """ sql = """
SELECT origin, event_id SELECT origin, event_id
FROM federation_inbound_events_staging FROM federation_inbound_events_staging
@ -1448,7 +1501,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id,)) txn.execute(sql, (room_id,))
return txn.fetchone() return cast(Optional[Tuple[str, str]], txn.fetchone())
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
@ -1461,7 +1514,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, EventBase]]: ) -> Optional[Tuple[str, EventBase]]:
"""Get the next event in the staging area for the given room.""" """Get the next event in the staging area for the given room."""
def _get_next_staged_event_for_room_txn(txn): def _get_next_staged_event_for_room_txn(
txn: LoggingTransaction,
) -> Optional[Tuple[str, str, str]]:
sql = """ sql = """
SELECT event_json, internal_metadata, origin SELECT event_json, internal_metadata, origin
FROM federation_inbound_events_staging FROM federation_inbound_events_staging
@ -1471,7 +1526,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
""" """
txn.execute(sql, (room_id,)) txn.execute(sql, (room_id,))
return txn.fetchone() return cast(Optional[Tuple[str, str, str]], txn.fetchone())
row = await self.db_pool.runInteraction( row = await self.db_pool.runInteraction(
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
@ -1599,18 +1654,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) )
@wrap_as_background_process("_get_stats_for_federation_staging") @wrap_as_background_process("_get_stats_for_federation_staging")
async def _get_stats_for_federation_staging(self): async def _get_stats_for_federation_staging(self) -> None:
"""Update the prometheus metrics for the inbound federation staging area.""" """Update the prometheus metrics for the inbound federation staging area."""
def _get_stats_for_federation_staging_txn(txn): def _get_stats_for_federation_staging_txn(
txn: LoggingTransaction,
) -> Tuple[int, int]:
txn.execute("SELECT count(*) FROM federation_inbound_events_staging") txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
(count,) = txn.fetchone() (count,) = cast(Tuple[int], txn.fetchone())
txn.execute( txn.execute(
"SELECT min(received_ts) FROM federation_inbound_events_staging" "SELECT min(received_ts) FROM federation_inbound_events_staging"
) )
(received_ts,) = txn.fetchone() (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
# If there is nothing in the staging area default it to 0. # If there is nothing in the staging area default it to 0.
age = 0 age = 0
@ -1651,19 +1708,21 @@ class EventFederationStore(EventFederationWorkerStore):
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
) )
async def clean_room_for_join(self, room_id): async def clean_room_for_join(self, room_id: str) -> None:
return await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id "clean_room_for_join", self._clean_room_for_join_txn, room_id
) )
def _clean_room_for_join_txn(self, txn, room_id): def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
query = "DELETE FROM event_forward_extremities WHERE room_id = ?" query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
async def _background_delete_non_state_event_auth(self, progress, batch_size): async def _background_delete_non_state_event_auth(
def delete_event_auth(txn): self, progress: JsonDict, batch_size: int
) -> int:
def delete_event_auth(txn: LoggingTransaction) -> bool:
target_min_stream_id = progress.get("target_min_stream_id_inclusive") target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive") max_stream_id = progress.get("max_stream_id_exclusive")

View File

@ -332,6 +332,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
most_recent_prev_event_depth, most_recent_prev_event_depth,
) = self.get_success(self.store.get_max_depth_of(prev_event_ids)) ) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
# mapping from (type, state_key) -> state_event_id # mapping from (type, state_key) -> state_event_id
assert most_recent_prev_event_id is not None
prev_state_map = self.get_success( prev_state_map = self.get_success(
self.state_store.get_state_ids_for_event(most_recent_prev_event_id) self.state_store.get_state_ids_for_event(most_recent_prev_event_id)
) )