diff --git a/changelog.d/12399.misc b/changelog.d/12399.misc new file mode 100644 index 0000000000..cd2e09626d --- /dev/null +++ b/changelog.d/12399.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: Implement a tracking mechanism to allow functions to wait for full room state to arrive. diff --git a/docker/complement/conf/homeserver.yaml b/docker/complement/conf/homeserver.yaml index c9d6a31240..174f87f52e 100644 --- a/docker/complement/conf/homeserver.yaml +++ b/docker/complement/conf/homeserver.yaml @@ -103,8 +103,10 @@ experimental_features: spaces_enabled: true # Enable history backfilling support msc2716_enabled: true - # server-side support for partial state in /send_join + # server-side support for partial state in /send_join responses msc3706_enabled: true + # client-side support for partial state in /send_join responses + faster_joins: true # Enable jump to date endpoint msc3030_enabled: true diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index d34e9f3554..e0feba05fa 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -64,4 +64,4 @@ docker build -t $COMPLEMENT_BASE_IMAGE -f "docker/complement/$COMPLEMENT_DOCKERF # Run the tests! echo "Images built; running complement" cd "$COMPLEMENT_DIR" -go test -v -tags synapse_blacklist,msc2716,msc3030 -count=1 "$@" ./tests/... +go test -v -tags synapse_blacklist,msc2716,msc3030,faster_joins -count=1 "$@" ./tests/... diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 32bf02818c..693b544286 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -515,6 +515,7 @@ class FederationEventHandler: ) return await self._store.update_state_for_partial_state_event(event, context) + self._state_store.notify_event_un_partial_stated(event.event_id) async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 60876204bd..6d6e146ff1 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1974,7 +1974,15 @@ class EventsWorkerStore(SQLBaseStore): async def get_partial_state_events( self, event_ids: Collection[str] ) -> Dict[str, bool]: - """Checks which of the given events have partial state""" + """Checks which of the given events have partial state + + Args: + event_ids: the events we want to check for partial state. + + Returns: + a dict mapping from event id to partial-stateness. We return True for + any of the events which are unknown (or are outliers). + """ result = await self.db_pool.simple_select_many_batch( table="partial_state_events", column="event_id", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 7a1b013fa3..e653841fe5 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -396,6 +396,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # TODO(faster_joins): need to do something about workers here + txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,)) txn.call_after( self._get_state_group_for_event.prefill, (event.event_id,), diff --git a/synapse/storage/state.py b/synapse/storage/state.py index cda194e8c8..d1d5859214 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -31,6 +31,7 @@ from frozendict import frozendict from synapse.api.constants import EventTypes from synapse.events import EventBase +from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker from synapse.types import MutableStateMap, StateKey, StateMap if TYPE_CHECKING: @@ -542,6 +543,10 @@ class StateGroupStorage: def __init__(self, hs: "HomeServer", stores: "Databases"): self.stores = stores + self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + + def notify_event_un_partial_stated(self, event_id: str) -> None: + self._partial_state_events_tracker.notify_un_partial_stated(event_id) async def get_state_group_delta( self, state_group: int @@ -579,7 +584,7 @@ class StateGroupStorage: if not event_ids: return {} - event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups(groups) @@ -668,7 +673,7 @@ class StateGroupStorage: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( @@ -709,7 +714,7 @@ class StateGroupStorage: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( @@ -785,6 +790,23 @@ class StateGroupStorage: groups, state_filter or StateFilter.all() ) + async def _get_state_group_for_events( + self, + event_ids: Collection[str], + await_full_state: bool = True, + ) -> Mapping[str, int]: + """Returns mapping event_id -> state_group + + Args: + event_ids: events to get state groups for + await_full_state: if true, will block if we do not yet have complete + state at this event. + """ + if await_full_state: + await self._partial_state_events_tracker.await_full_state(event_ids) + + return await self.stores.main._get_state_group_for_events(event_ids) + async def store_state_group( self, event_id: str, diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py new file mode 100644 index 0000000000..a61a951ef0 --- /dev/null +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -0,0 +1,120 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import defaultdict +from typing import Collection, Dict, Set + +from twisted.internet import defer +from twisted.internet.defer import Deferred + +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util import unwrapFirstError + +logger = logging.getLogger(__name__) + + +class PartialStateEventsTracker: + """Keeps track of which events have partial state, after a partial-state join""" + + def __init__(self, store: EventsWorkerStore): + self._store = store + # a map from event id to a set of Deferreds which are waiting for that event to be + # un-partial-stated. + self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set) + + def notify_un_partial_stated(self, event_id: str) -> None: + """Notify that we now have full state for a given event + + Called by the state-resynchronization loop whenever we resynchronize the state + for a particular event. Unblocks any callers to await_full_state() for that + event. + + Args: + event_id: the event that now has full state. + """ + observers = self._observers.pop(event_id, None) + if not observers: + return + logger.info( + "Notifying %i things waiting for un-partial-stating of event %s", + len(observers), + event_id, + ) + with PreserveLoggingContext(): + for o in observers: + o.callback(None) + + async def await_full_state(self, event_ids: Collection[str]) -> None: + """Wait for all the given events to have full state. + + Args: + event_ids: the list of event ids that we want full state for + """ + # first try the happy path: if there are no partial-state events, we can return + # quickly + partial_state_event_ids = [ + ev + for ev, p in (await self._store.get_partial_state_events(event_ids)).items() + if p + ] + + if not partial_state_event_ids: + return + + logger.info( + "Awaiting un-partial-stating of events %s", + partial_state_event_ids, + stack_info=True, + ) + + # create an observer for each lazy-joined event + observers: Dict[str, Deferred[None]] = { + event_id: Deferred() for event_id in partial_state_event_ids + } + for event_id, observer in observers.items(): + self._observers[event_id].add(observer) + + try: + # some of them may have been un-lazy-joined between us checking the db and + # registering the observer, in which case we'd wait forever for the + # notification. Call back the observers now. + for event_id, partial in ( + await self._store.get_partial_state_events(observers.keys()) + ).items(): + # there may have been a call to notify_un_partial_stated during the + # db query, so the observers may already have been called. + if not partial and not observers[event_id].called: + observers[event_id].callback(None) + + await make_deferred_yieldable( + defer.gatherResults( + observers.values(), + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + logger.info("Events %s all un-partial-stated", observers.keys()) + finally: + # remove any observers we created. This should happen when the notification + # is received, but that might not happen for two reasons: + # (a) we're bailing out early on an exception (including us being + # cancelled during the await) + # (b) the event got de-lazy-joined before we set up the observer. + for event_id, observer in observers.items(): + observer_set = self._observers.get(event_id) + if observer_set: + observer_set.discard(observer) + if not observer_set: + del self._observers[event_id] diff --git a/tests/storage/util/__init__.py b/tests/storage/util/__init__.py new file mode 100644 index 0000000000..3a5f22c022 --- /dev/null +++ b/tests/storage/util/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py new file mode 100644 index 0000000000..303e190b6c --- /dev/null +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -0,0 +1,117 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from unittest import mock + +from twisted.internet.defer import CancelledError, ensureDeferred + +from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker + +from tests.unittest import TestCase + + +class PartialStateEventsTrackerTestCase(TestCase): + def setUp(self) -> None: + # the results to be returned by the mocked get_partial_state_events + self._events_dict: Dict[str, bool] = {} + + async def get_partial_state_events(events): + return {e: self._events_dict[e] for e in events} + + self.mock_store = mock.Mock(spec_set=["get_partial_state_events"]) + self.mock_store.get_partial_state_events.side_effect = get_partial_state_events + + self.tracker = PartialStateEventsTracker(self.mock_store) + + def test_does_not_block_for_full_state_events(self): + self._events_dict = {"event1": False, "event2": False} + + self.successResultOf( + ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) + ) + + self.mock_store.get_partial_state_events.assert_called_once_with( + ["event1", "event2"] + ) + + def test_blocks_for_partial_state_events(self): + self._events_dict = {"event1": True, "event2": False} + + d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) + + # there should be no result yet + self.assertNoResult(d) + + # notifying that the event has been de-partial-stated should unblock + self.tracker.notify_un_partial_stated("event1") + self.successResultOf(d) + + def test_un_partial_state_race(self): + # if the event is un-partial-stated between the initial check and the + # registration of the listener, it should not block. + self._events_dict = {"event1": True, "event2": False} + + async def get_partial_state_events(events): + res = {e: self._events_dict[e] for e in events} + # change the result for next time + self._events_dict = {"event1": False, "event2": False} + return res + + self.mock_store.get_partial_state_events.side_effect = get_partial_state_events + + self.successResultOf( + ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) + ) + + def test_un_partial_state_during_get_partial_state_events(self): + # we should correctly handle a call to notify_un_partial_stated during the + # second call to get_partial_state_events. + + self._events_dict = {"event1": True, "event2": False} + + async def get_partial_state_events1(events): + self.mock_store.get_partial_state_events.side_effect = ( + get_partial_state_events2 + ) + return {e: self._events_dict[e] for e in events} + + async def get_partial_state_events2(events): + self.tracker.notify_un_partial_stated("event1") + self._events_dict["event1"] = False + return {e: self._events_dict[e] for e in events} + + self.mock_store.get_partial_state_events.side_effect = get_partial_state_events1 + + self.successResultOf( + ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) + ) + + def test_cancellation(self): + self._events_dict = {"event1": True, "event2": False} + + d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) + self.assertNoResult(d1) + + d2 = ensureDeferred(self.tracker.await_full_state(["event1"])) + self.assertNoResult(d2) + + d1.cancel() + self.assertFailure(d1, CancelledError) + + # d2 should still be waiting! + self.assertNoResult(d2) + + self.tracker.notify_un_partial_stated("event1") + self.successResultOf(d2)