Return an immutable value from get_latest_event_ids_in_room. (#16326)

This commit is contained in:
Patrick Cloke 2023-09-18 09:29:05 -04:00 committed by GitHub
parent 63d28a88c1
commit 85bfd4735e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 48 additions and 40 deletions

1
changelog.d/16326.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -103,7 +103,7 @@ class EventBuilder:
async def build( async def build(
self, self,
prev_event_ids: StrCollection, prev_event_ids: List[str],
auth_event_ids: Optional[List[str]], auth_event_ids: Optional[List[str]],
depth: Optional[int] = None, depth: Optional[int] = None,
) -> EventBase: ) -> EventBase:

View File

@ -723,12 +723,11 @@ class FederationEventHandler:
if not prevs - seen: if not prevs - seen:
return return
latest_list = await self._store.get_latest_event_ids_in_room(room_id) latest_frozen = await self._store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest # We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us # list to ensure the remote server doesn't give them to us
latest = set(latest_list) latest = seen | latest_frozen
latest |= seen
logger.info( logger.info(
"Requesting missing events between %s and %s", "Requesting missing events between %s and %s",
@ -1976,8 +1975,7 @@ class FederationEventHandler:
# partial and full state and may not be accurate. # partial and full state and may not be accurate.
return return
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids()) prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids: if extrem_ids == prev_event_ids:

View File

@ -19,6 +19,7 @@ import logging
from collections import deque from collections import deque
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
AbstractSet,
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
@ -618,7 +619,7 @@ class EventsPersistenceStorageController:
) )
for room_id, ev_ctx_rm in events_by_room.items(): for room_id, ev_ctx_rm in events_by_room.items():
latest_event_ids = set( latest_event_ids = (
await self.main_store.get_latest_event_ids_in_room(room_id) await self.main_store.get_latest_event_ids_in_room(room_id)
) )
new_latest_event_ids = await self._calculate_new_extremities( new_latest_event_ids = await self._calculate_new_extremities(
@ -740,7 +741,7 @@ class EventsPersistenceStorageController:
self, self,
room_id: str, room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]], event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: Collection[str], latest_event_ids: AbstractSet[str],
) -> Set[str]: ) -> Set[str]:
"""Calculates the new forward extremities for a room given events to """Calculates the new forward extremities for a room given events to
persist. persist.
@ -758,8 +759,6 @@ class EventsPersistenceStorageController:
and not event.internal_metadata.is_soft_failed() and not event.internal_metadata.is_soft_failed()
] ]
latest_event_ids = set(latest_event_ids)
# start with the existing forward extremities # start with the existing forward extremities
result = set(latest_event_ids) result = set(latest_event_ids)
@ -798,7 +797,7 @@ class EventsPersistenceStorageController:
self, self,
room_id: str, room_id: str,
events_context: List[Tuple[EventBase, EventContext]], events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Set[str], old_latest_event_ids: AbstractSet[str],
new_latest_event_ids: Set[str], new_latest_event_ids: Set[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
"""Calculate the current state dict after adding some new events to """Calculate the current state dict after adding some new events to

View File

@ -19,6 +19,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Collection, Collection,
Dict, Dict,
FrozenSet,
Iterable, Iterable,
List, List,
Optional, Optional,
@ -47,7 +48,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, StrCollection, StrSequence from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -1179,13 +1180,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) )
@cached(max_entries=5000, iterable=True) @cached(max_entries=5000, iterable=True)
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence: async def get_latest_event_ids_in_room(self, room_id: str) -> FrozenSet[str]:
return await self.db_pool.simple_select_onecol( event_ids = await self.db_pool.simple_select_onecol(
table="event_forward_extremities", table="event_forward_extremities",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcol="event_id", retcol="event_id",
desc="get_latest_event_ids_in_room", desc="get_latest_event_ids_in_room",
) )
return frozenset(event_ids)
async def get_min_depth(self, room_id: str) -> Optional[int]: async def get_min_depth(self, room_id: str) -> Optional[int]:
"""For the given room, get the minimum depth we have seen for it.""" """For the given room, get the minimum depth we have seen for it."""

View File

@ -222,7 +222,7 @@ class PersistEventsStore:
for room_id, latest_event_ids in new_forward_extremities.items(): for room_id, latest_event_ids in new_forward_extremities.items():
self.store.get_latest_event_ids_in_room.prefill( self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids) (room_id,), frozenset(latest_event_ids)
) )
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:

View File

@ -1858,7 +1858,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
) )
event = self.get_success( event = self.get_success(
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
) )
self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))

View File

@ -90,7 +90,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
def test_get_latest_event_ids_in_room(self) -> None: def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID) create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate() self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id})
join = self.persist( join = self.persist(
type="m.room.member", type="m.room.member",
@ -99,7 +99,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
prev_events=[(create.event_id, {})], prev_events=[(create.event_id, {})],
) )
self.replicate() self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id})
def test_redactions(self) -> None: def test_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, List, Optional, Sequence from typing import Any, List, Optional
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
) )
# this is the point in the DAG where we make a fork # this is the point in the DAG where we make a fork
fork_point: Sequence[str] = self.get_success( fork_point = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
) )
@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
) )
# this is the point in the DAG where we make a fork # this is the point in the DAG where we make a fork
fork_point: Sequence[str] = self.get_success( fork_point = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
) )
@ -316,14 +316,14 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.test_handler.received_rdata_rows.clear() self.test_handler.received_rdata_rows.clear()
# now roll back all that state by de-modding the users # now roll back all that state by de-modding the users
prev_events = fork_point prev_events = list(fork_point)
pl_events = [] pl_events = []
for u in user_ids: for u in user_ids:
pls["users"][u] = 0 pls["users"][u] = 0
e = self.get_success( e = self.get_success(
inject_event( inject_event(
self.hs, self.hs,
prev_event_ids=list(prev_events), prev_event_ids=prev_events,
type=EventTypes.PowerLevels, type=EventTypes.PowerLevels,
state_key="", state_key="",
sender=self.user_id, sender=self.user_id,

View File

@ -261,7 +261,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
builder = factory.for_room_version(room_version, event_dict) builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success( join_event = self.get_success(
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
) )
self.get_success(federation.on_send_membership_event(remote_server, join_event)) self.get_success(federation.on_send_membership_event(remote_server, join_event))

View File

@ -120,7 +120,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.store.get_latest_event_ids_in_room(self.room_id) self.store.get_latest_event_ids_in_room(self.room_id)
) )
self.assertEqual(latest_event_ids, [event_id_4]) self.assertEqual(latest_event_ids, {event_id_4})
def test_basic_cleanup(self) -> None: def test_basic_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of """Test that extremities are correctly calculated in the presence of
@ -147,7 +147,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success( latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id) self.store.get_latest_event_ids_in_room(self.room_id)
) )
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
# Run the background update and check it did the right thing # Run the background update and check it did the right thing
self.run_background_update() self.run_background_update()
@ -155,7 +155,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success( latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id) self.store.get_latest_event_ids_in_room(self.room_id)
) )
self.assertEqual(latest_event_ids, [event_id_b]) self.assertEqual(latest_event_ids, {event_id_b})
def test_chain_of_fail_cleanup(self) -> None: def test_chain_of_fail_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of """Test that extremities are correctly calculated in the presence of
@ -185,7 +185,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success( latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id) self.store.get_latest_event_ids_in_room(self.room_id)
) )
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
# Run the background update and check it did the right thing # Run the background update and check it did the right thing
self.run_background_update() self.run_background_update()
@ -193,7 +193,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success( latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id) self.store.get_latest_event_ids_in_room(self.room_id)
) )
self.assertEqual(latest_event_ids, [event_id_b]) self.assertEqual(latest_event_ids, {event_id_b})
def test_forked_graph_cleanup(self) -> None: def test_forked_graph_cleanup(self) -> None:
r"""Test that extremities are correctly calculated in the presence of r"""Test that extremities are correctly calculated in the presence of
@ -240,7 +240,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success( latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id) self.store.get_latest_event_ids_in_room(self.room_id)
) )
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c})
# Run the background update and check it did the right thing # Run the background update and check it did the right thing
self.run_background_update() self.run_background_update()
@ -248,7 +248,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success( latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id) self.store.get_latest_event_ids_in_room(self.room_id)
) )
self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) self.assertEqual(latest_event_ids, {event_id_b, event_id_c})
class CleanupExtremDummyEventsTestCase(HomeserverTestCase): class CleanupExtremDummyEventsTestCase(HomeserverTestCase):

View File

@ -51,9 +51,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
# Figure out what the most recent event is # Figure out what the most recent event is
most_recent = self.get_success( most_recent = next(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) iter(
)[0] self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(
self.room_id
)
)
)
)
join_event = make_event_from_dict( join_event = make_event_from_dict(
{ {
@ -100,8 +106,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Make sure we actually joined the room # Make sure we actually joined the room
self.assertEqual( self.assertEqual(
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
"$join:test.serv", {"$join:test.serv"},
) )
def test_cant_hide_direct_ancestors(self) -> None: def test_cant_hide_direct_ancestors(self) -> None:
@ -127,9 +133,11 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.http_client.post_json = post_json self.http_client.post_json = post_json
# Figure out what the most recent event is # Figure out what the most recent event is
most_recent = self.get_success( most_recent = next(
self.store.get_latest_event_ids_in_room(self.room_id) iter(
)[0] self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
)
)
# Now lie about an event # Now lie about an event
lying_event = make_event_from_dict( lying_event = make_event_from_dict(
@ -165,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Make sure the invalid event isn't there # Make sure the invalid event isn't there
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
self.assertEqual(extrem[0], "$join:test.serv") self.assertEqual(extrem, {"$join:test.serv"})
def test_retry_device_list_resync(self) -> None: def test_retry_device_list_resync(self) -> None:
"""Tests that device lists are marked as stale if they couldn't be synced, and """Tests that device lists are marked as stale if they couldn't be synced, and