From 888a29f4127723a8d048ce47cff37ee8a7a6f1b9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Jun 2022 16:02:53 +0100 Subject: [PATCH] Wait for lazy join to complete when getting current state (#12872) --- changelog.d/12872.misc | 1 + synapse/events/third_party_rules.py | 3 +- synapse/federation/federation_server.py | 4 +- synapse/handlers/device.py | 2 +- synapse/handlers/directory.py | 7 +- synapse/handlers/federation.py | 7 +- synapse/handlers/message.py | 2 +- synapse/handlers/presence.py | 6 +- synapse/handlers/register.py | 3 +- synapse/handlers/room.py | 13 +- synapse/handlers/room_list.py | 3 +- synapse/handlers/room_member.py | 5 +- synapse/handlers/room_summary.py | 11 +- synapse/handlers/stats.py | 6 +- synapse/handlers/sync.py | 13 +- synapse/handlers/user_directory.py | 6 +- synapse/module_api/__init__.py | 19 ++- synapse/push/mailer.py | 4 +- synapse/rest/admin/rooms.py | 3 +- synapse/storage/_base.py | 2 +- synapse/storage/controllers/__init__.py | 4 +- synapse/storage/controllers/persist_events.py | 4 +- synapse/storage/controllers/state.py | 112 +++++++++++++++++- synapse/storage/databases/main/room.py | 18 +++ synapse/storage/databases/main/state.py | 38 ++---- .../storage/databases/main/state_deltas.py | 4 +- .../storage/databases/main/user_directory.py | 4 +- .../util/partial_state_events_tracker.py | 60 ++++++++++ tests/handlers/test_federation.py | 6 +- tests/handlers/test_federation_event.py | 4 +- tests/handlers/test_typing.py | 2 +- tests/rest/client/test_upgrade_room.py | 8 +- .../util/test_partial_state_events_tracker.py | 59 ++++++++- 33 files changed, 361 insertions(+), 82 deletions(-) create mode 100644 changelog.d/12872.misc diff --git a/changelog.d/12872.misc b/changelog.d/12872.misc new file mode 100644 index 0000000000..f60a756f21 --- /dev/null +++ b/changelog.d/12872.misc @@ -0,0 +1 @@ +Faster room joins: when querying the current state of the room, wait for state to be populated. diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 9f4ff9799c..35f3f3690f 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -152,6 +152,7 @@ class ThirdPartyEventRules: self.third_party_rules = None self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] @@ -463,7 +464,7 @@ class ThirdPartyEventRules: Returns: A dict mapping (event type, state key) to state event. """ - state_ids = await self.store.get_filtered_current_state_ids(room_id) + state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) room_state_events = await self.store.get_events(state_ids.values()) state_events = {} diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 12591dc8db..f4af121c4d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -118,6 +118,8 @@ class FederationServer(FederationBase): self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() + self._state_storage_controller = hs.get_storage_controllers().state + self.device_handler = hs.get_device_handler() # Ensure the following handlers are loaded since they register callbacks @@ -1221,7 +1223,7 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = await self.store.get_current_state_ids(room_id) + state_ids = await self._state_storage_controller.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 72faf2ee38..a0cbeedc30 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -166,7 +166,7 @@ class DeviceWorkerHandler: possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = await self.store.get_current_state_ids(room_id) + current_state_ids = await self._state_storage.get_current_state_ids(room_id) # The user may have left the room # TODO: Check if they actually did or if we were just invited. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4aa33df884..44e84698c4 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -45,6 +45,7 @@ class DirectoryHandler: self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases @@ -463,7 +464,11 @@ class DirectoryHandler: making_public = visibility == "public" if making_public: room_aliases = await self.store.get_aliases_for_room(room_id) - canonical_alias = await self.store.get_canonical_alias_for_room(room_id) + canonical_alias = ( + await self._storage_controllers.state.get_canonical_alias_for_room( + room_id + ) + ) if canonical_alias: room_aliases.append(canonical_alias) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 659f279441..b212ee2172 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -750,7 +750,9 @@ class FederationHandler: # Note that this requires the /send_join request to come back to the # same server. if room_version.msc3083_join_rules: - state_ids = await self.store.get_current_state_ids(room_id) + state_ids = await self._state_storage_controller.get_current_state_ids( + room_id + ) if await self._event_auth_handler.has_restricted_join_rules( state_ids, room_version ): @@ -1552,6 +1554,9 @@ class FederationHandler: success = await self.store.clear_partial_state_room(room_id) if success: logger.info("State resync complete for %s", room_id) + self._storage_controllers.state.notify_room_un_partial_stated( + room_id + ) # TODO(faster_joins) update room stats and user directory? return diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ac911a2ddc..081625f0bd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -217,7 +217,7 @@ class MessageHandler: ) if membership == Membership.JOIN: - state_ids = await self.store.get_filtered_current_state_ids( + state_ids = await self._state_storage_controller.get_current_state_ids( room_id, state_filter=state_filter ) room_state = await self.store.get_events(state_ids.values()) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index bf112b9e1e..895ea63ed3 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -1348,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler): self._event_pos, room_max_stream_ordering, ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 05bb1e0225..338204287f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -87,6 +87,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() @@ -528,7 +529,7 @@ class RegistrationHandler: if requires_invite: # If the server is in the room, check if the room is public. - state = await self.store.get_filtered_current_state_ids( + state = await self._storage_controllers.state.get_current_state_ids( room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e1341dd9bb..e2b0e519d4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -107,6 +107,7 @@ class EventContext: class RoomCreationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.auth = hs.get_auth() self.clock = hs.get_clock() self.hs = hs @@ -480,8 +481,10 @@ class RoomCreationHandler: if room_type == RoomTypes.SPACE: types_to_copy.append((EventTypes.SpaceChild, None)) - old_room_state_ids = await self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types(types_to_copy) + old_room_state_ids = ( + await self._storage_controllers.state.get_current_state_ids( + old_room_id, StateFilter.from_types(types_to_copy) + ) ) # map from event_id to BaseEvent old_room_state_events = await self.store.get_events(old_room_state_ids.values()) @@ -558,8 +561,10 @@ class RoomCreationHandler: ) # Transfer membership events - old_room_member_state_ids = await self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) + old_room_member_state_ids = ( + await self._storage_controllers.state.get_current_state_ids( + old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) # map from event_id to BaseEvent diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index f3577b5d5a..183d4ae3c4 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ @@ -274,7 +275,7 @@ class RoomListHandler: if aliases: result["aliases"] = aliases - current_state_ids = await self.store.get_current_state_ids( + current_state_ids = await self._storage_controllers.state.get_current_state_ids( room_id, on_invalidate=cache_context.invalidate ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 00662dc961..70c674ff8e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -68,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config @@ -994,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. room_version = await self.store.get_room_version(room_id) - current_state_ids = await self.store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) # If restricted join rules are not being used, a local join can always # be used. diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 75aee6a111..13098f56ed 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -90,6 +90,7 @@ class RoomSummaryHandler: def __init__(self, hs: "HomeServer"): self._event_auth_handler = hs.get_event_auth_handler() self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() @@ -537,7 +538,7 @@ class RoomSummaryHandler: Returns: True if the room is accessible to the requesting user or server. """ - state_ids = await self._store.get_current_state_ids(room_id) + state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) # If there's no state for the room, it isn't known. if not state_ids: @@ -702,7 +703,9 @@ class RoomSummaryHandler: # there should always be an entry assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - current_state_ids = await self._store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) create_event = await self._store.get_event( current_state_ids[(EventTypes.Create, "")] ) @@ -760,7 +763,9 @@ class RoomSummaryHandler: """ # look for child rooms/spaces. - current_state_ids = await self._store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) events = await self._store.get_events_as_list( [ diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 436cd971ce..f45e06eb0e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -40,6 +40,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() @@ -105,7 +106,10 @@ class StatsHandler: logger.debug( "Processing room stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self.pos, room_max_stream_ordering ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a1d41358d9..b4ead79f97 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -506,8 +506,10 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): - current_state_ids_map = await self.store.get_current_state_ids( - room_id + current_state_ids_map = ( + await self._state_storage_controller.get_current_state_ids( + room_id + ) ) current_state_ids = frozenset(current_state_ids_map.values()) @@ -574,8 +576,11 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids_map = await self.store.get_current_state_ids( - room_id + # FIXME(faster_joins): We use the partial state here as + # we don't want to block `/sync` on finishing a lazy join. + # Is this the correct way of doing it? + current_state_ids_map = ( + await self.store.get_partial_current_state_ids(room_id) ) current_state_ids = frozenset(current_state_ids_map.values()) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 74f7fdfe6c..8c3c52e1ca 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler): super().__init__(hs) self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -174,7 +175,10 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug( "Processing user stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self.pos, room_max_stream_ordering ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b7451fc870..a8ad575fcd 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -194,6 +194,7 @@ class ModuleApi: self._store: Union[ DataStore, "GenericWorkerSlavedStore" ] = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -911,7 +912,7 @@ class ModuleApi: The filtered state events in the room. """ state_ids = yield defer.ensureDeferred( - self._store.get_filtered_current_state_ids( + self._storage_controllers.state.get_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) ) @@ -1289,20 +1290,16 @@ class ModuleApi: # regardless of their state key ] """ + state_filter = None if event_filter: # If a filter was provided, turn it into a StateFilter and retrieve a filtered # view of the state. state_filter = StateFilter.from_types(event_filter) - state_ids = await self._store.get_filtered_current_state_ids( - room_id, - state_filter, - ) - else: - # If no filter was provided, get the whole state. We could also reuse the call - # to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`, - # but get_filtered_current_state_ids isn't cached and `get_current_state_ids` - # is, so using the latter when we can is better for perf. - state_ids = await self._store.get_current_state_ids(room_id) + + state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id, + state_filter, + ) state_events = await self._store.get_events(state_ids.values()) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 63aefd07f5..015c19b2d9 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -255,7 +255,9 @@ class Mailer: user_display_name = user_id async def _fetch_room_state(room_id: str) -> None: - room_state = await self.store.get_current_state_ids(room_id) + room_state = await self._state_storage_controller.get_current_state_ids( + room_id + ) state_by_room[room_id] = room_state # Run at most 3 of these at once: sync does 10 at a time but email diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 356d6f74d7..1cacd1a4f0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -418,6 +418,7 @@ class RoomStateRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() @@ -430,7 +431,7 @@ class RoomStateRestServlet(RestServlet): if not ret: raise NotFoundError("Room not found") - event_ids = await self.store.get_current_state_ids(room_id) + event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() room_state = self._event_serializer.serialize_events(events.values(), now) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8df80664a2..57bd74700e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -77,7 +77,7 @@ class SQLBaseStore(metaclass=ABCMeta): # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) - self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) + self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py index 992261d07b..55649719f6 100644 --- a/synapse/storage/controllers/__init__.py +++ b/synapse/storage/controllers/__init__.py @@ -18,7 +18,7 @@ from synapse.storage.controllers.persist_events import ( EventsPersistenceStorageController, ) from synapse.storage.controllers.purge_events import PurgeEventsStorageController -from synapse.storage.controllers.state import StateGroupStorageController +from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore @@ -39,7 +39,7 @@ class StorageControllers: self.main = stores.main self.purge_events = PurgeEventsStorageController(hs, stores) - self.state = StateGroupStorageController(hs, stores) + self.state = StateStorageController(hs, stores) self.persistence = None if stores.persist_events: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index ef8c135b12..4caaa81808 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -994,7 +994,7 @@ class EventsPersistenceStorageController: Assumes that we are only persisting events for one room at a time. """ - existing_state = await self.main_store.get_current_state_ids(room_id) + existing_state = await self.main_store.get_partial_current_state_ids(room_id) to_delete = [key for key in existing_state if key not in current_state] @@ -1083,7 +1083,7 @@ class EventsPersistenceStorageController: # The server will leave the room, so we go and find out which remote # users will still be joined when we leave. if current_state is None: - current_state = await self.main_store.get_current_state_ids(room_id) + current_state = await self.main_store.get_partial_current_state_ids(room_id) current_state = dict(current_state) for key in delta.to_delete: current_state.pop(key, None) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 0f09953086..9952b00493 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -14,7 +14,9 @@ import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, + Callable, Collection, Dict, Iterable, @@ -24,9 +26,13 @@ from typing import ( Tuple, ) +from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.storage.state import StateFilter -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: @@ -36,17 +42,27 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class StateGroupStorageController: - """High level interface to fetching state for event.""" +class StateStorageController: + """High level interface to fetching state for an event, or the current state + in a room. + """ def __init__(self, hs: "HomeServer", stores: "Databases"): self._is_mine_id = hs.is_mine_id self.stores = stores self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main) def notify_event_un_partial_stated(self, event_id: str) -> None: self._partial_state_events_tracker.notify_un_partial_stated(event_id) + def notify_room_un_partial_stated(self, room_id: str) -> None: + """Notify that the room no longer has any partial state. + + Must be called after `DataStore.clear_partial_state_room` + """ + self._partial_state_room_tracker.notify_un_partial_stated(room_id) + async def get_state_group_delta( self, state_group: int ) -> Tuple[Optional[int], Optional[StateMap[str]]]: @@ -349,3 +365,93 @@ class StateGroupStorageController: return await self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) + + async def get_current_state_ids( + self, + room_id: str, + state_filter: Optional[StateFilter] = None, + on_invalidate: Optional[Callable[[], None]] = None, + ) -> StateMap[str]: + """Get the current state event ids for a room based on the + current_state_events table. + + If a state filter is given (that is not `StateFilter.all()`) the query + result is *not* cached. + + Args: + room_id: The room to get the state IDs of. state_filter: The state + filter used to fetch state from the + database. + on_invalidate: Callback for when the `get_current_state_ids` cache + for the room gets invalidated. + + Returns: + The current state of the room. + """ + if not state_filter or state_filter.must_await_full_state(self._is_mine_id): + await self._partial_state_room_tracker.await_full_state(room_id) + + if state_filter and not state_filter.is_full(): + return await self.stores.main.get_partial_filtered_current_state_ids( + room_id, state_filter + ) + else: + return await self.stores.main.get_partial_current_state_ids( + room_id, on_invalidate=on_invalidate + ) + + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: + """Get canonical alias for room, if any + + Args: + room_id: The room ID + + Returns: + The canonical alias, if any + """ + + state = await self.get_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) + + event_id = state.get((EventTypes.CanonicalAlias, "")) + if not event_id: + return None + + event = await self.stores.main.get_event(event_id, allow_none=True) + if not event: + return None + + return event.content.get("canonical_alias") + + async def get_current_state_deltas( + self, prev_stream_id: int, max_stream_id: int + ) -> Tuple[int, List[Dict[str, Any]]]: + """Fetch a list of room state changes since the given stream id + + Each entry in the result contains the following fields: + - stream_id (int) + - room_id (str) + - type (str): event type + - state_key (str): + - event_id (str|None): new event_id for this state key. None if the + state has been deleted. + - prev_event_id (str|None): previous event_id for this state key. None + if it's new state. + + Args: + prev_stream_id: point to get changes since (exclusive) + max_stream_id: the point that we know has been correctly persisted + - ie, an upper limit to return changes from. + + Returns: + A tuple consisting of: + - the stream id which these results go up to + - list of current_state_delta_stream rows. If it is empty, we are + up to date. + """ + # FIXME(faster_joins): what do we do here? + + return await self.stores.main.get_partial_current_state_deltas( + prev_stream_id, max_stream_id + ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index cfd8ce1624..68d4fc2e64 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1139,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): keyvalues={"room_id": room_id}, ) + async def is_partial_state_room(self, room_id: str) -> bool: + """Checks if this room has partial state. + + Returns true if this is a "partial-state" room, which means that the state + at events in the room, and `current_state_events`, may not yet be + complete. + """ + + entry = await self.db_pool.simple_select_one_onecol( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcol="room_id", + allow_none=True, + desc="is_partial_state_room", + ) + + return entry is not None + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 3f2be3854b..bdd00273cd 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -242,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Raises: NotFoundError if the room is unknown """ - state_ids = await self.get_current_state_ids(room_id) + state_ids = await self.get_partial_current_state_ids(room_id) if not state_ids: raise NotFoundError(f"Current state for room {room_id} is empty") @@ -258,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=100000, iterable=True) - async def get_current_state_ids(self, room_id: str) -> StateMap[str]: + async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. + This may be the partial state if we're lazy joining the room. + Args: room_id: The room to get the state IDs of. @@ -280,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} return await self.db_pool.runInteraction( - "get_current_state_ids", _get_current_state_ids_txn + "get_partial_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - async def get_filtered_current_state_ids( + async def get_partial_filtered_current_state_ids( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + This may be the partial state if we're lazy joining the room. + Args: room_id state_filter: The state filter used to fetch state @@ -306,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not where_clause: # We delegate to the cached version - return await self.get_current_state_ids(room_id) + return await self.get_partial_current_state_ids(room_id) def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, @@ -334,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: - """Get canonical alias for room, if any - - Args: - room_id: The room ID - - Returns: - The canonical alias, if any - """ - - state = await self.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) - ) - - event_id = state.get((EventTypes.CanonicalAlias, "")) - if not event_id: - return None - - event = await self.get_event(event_id, allow_none=True) - if not event: - return None - - return event.content.get("canonical_alias") - @cached(max_entries=50000) async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 188afec332..445213e12a 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore): # attribute. TODO: can we get static analysis to enforce this? _curr_state_delta_stream_cache: StreamChangeCache - async def get_current_state_deltas( + async def get_partial_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[Dict[str, Any]]]: """Fetch a list of room state changes since the given stream id @@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore): - prev_event_id (str|None): previous event_id for this state key. None if it's new state. + This may be the partial state if we're lazy joining the room. + Args: prev_stream_id: point to get changes since (exclusive) max_stream_id: the point that we know has been correctly persisted diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 2282242e9d..ddb25b5cea 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] + # Getting the partial state is fine, as we're not looking at membership + # events. + current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py index a61a951ef0..211437cfaa 100644 --- a/synapse/storage/util/partial_state_events_tracker.py +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.room import RoomWorkerStore from synapse.util import unwrapFirstError logger = logging.getLogger(__name__) @@ -118,3 +119,62 @@ class PartialStateEventsTracker: observer_set.discard(observer) if not observer_set: del self._observers[event_id] + + +class PartialCurrentStateTracker: + """Keeps track of which rooms have partial state, after partial-state joins""" + + def __init__(self, store: RoomWorkerStore): + self._store = store + + # a map from room id to a set of Deferreds which are waiting for that room to be + # un-partial-stated. + self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set) + + def notify_un_partial_stated(self, room_id: str) -> None: + """Notify that we now have full current state for a given room + + Unblocks any callers to await_full_state() for that room. + + Args: + room_id: the room that now has full current state. + """ + observers = self._observers.pop(room_id, None) + if not observers: + return + logger.info( + "Notifying %i things waiting for un-partial-stating of room %s", + len(observers), + room_id, + ) + with PreserveLoggingContext(): + for o in observers: + o.callback(None) + + async def await_full_state(self, room_id: str) -> None: + # We add the deferred immediately so that the DB call to check for + # partial state doesn't race when we unpartial the room. + d: Deferred[None] = Deferred() + self._observers.setdefault(room_id, set()).add(d) + + try: + # Check if the room has partial current state or not. + has_partial_state = await self._store.is_partial_state_room(room_id) + if not has_partial_state: + return + + logger.info( + "Awaiting un-partial-stating of room %s", + room_id, + ) + + await make_deferred_yieldable(d) + + logger.info("Room has un-partial-stated") + finally: + # Remove the added observer, and remove the room entry if its empty. + ds = self._observers.get(room_id) + if ds is not None: + ds.discard(d) + if not ds: + self._observers.pop(room_id, None) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 500c9ccfbc..e0eda545b9 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) current_state = self.get_success( self.store.get_events_as_list( - (self.get_success(self.store.get_current_state_ids(room_id))).values() + ( + self.get_success(self.store.get_partial_current_state_ids(room_id)) + ).values() ) ) @@ -512,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): self.get_success(d) # sanity-check: the room should show that the new user is a member - r = self.get_success(self.store.get_current_state_ids(room_id)) + r = self.get_success(self.store.get_partial_current_state_ids(room_id)) self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id) return join_event diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 1d5b2492c0..1a36c25c41 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") ) - initial_state_map = self.get_success(main_store.get_current_state_ids(room_id)) + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) auth_event_ids = [ initial_state_map[("m.room.create", "")], diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 057256cecd..14a0ee4922 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -146,7 +146,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.datastore.get_current_state_deltas = Mock(return_value=(0, None)) + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = ( diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index a21cbe9fa8..98c1039d33 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): new_space_id = channel.json_body["replacement_room"] - state_ids = self.get_success(self.store.get_current_state_ids(new_space_id)) + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_space_id) + ) # Ensure the new room is still a space. create_event = self.get_success( @@ -284,7 +286,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): new_room_id = channel.json_body["replacement_room"] - state_ids = self.get_success(self.store.get_current_state_ids(new_room_id)) + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_room_id) + ) # Ensure the new room is the same type as the old room. create_event = self.get_success( diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index 303e190b6c..cae14151c0 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -17,8 +17,12 @@ from unittest import mock from twisted.internet.defer import CancelledError, ensureDeferred -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) +from tests.test_utils import make_awaitable from tests.unittest import TestCase @@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase): self.tracker.notify_un_partial_stated("event1") self.successResultOf(d2) + + +class PartialCurrentStateTrackerTestCase(TestCase): + def setUp(self) -> None: + self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) + + self.tracker = PartialCurrentStateTracker(self.mock_store) + + def test_does_not_block_for_full_state_rooms(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(False) + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_blocks_for_partial_room_state(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d = ensureDeferred(self.tracker.await_full_state("room_id")) + + # there should be no result yet + self.assertNoResult(d) + + # notifying that the room has been de-partial-stated should unblock + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d) + + def test_un_partial_state_race(self): + # We should correctly handle race between awaiting the state and us + # un-partialling the state + async def is_partial_state_room(events): + self.tracker.notify_un_partial_stated("room_id") + return True + + self.mock_store.is_partial_state_room.side_effect = is_partial_state_room + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_cancellation(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d1 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d1) + + d2 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d2) + + d1.cancel() + self.assertFailure(d1, CancelledError) + + # d2 should still be waiting! + self.assertNoResult(d2) + + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d2)