diff --git a/changelog.d/14812.bugfix b/changelog.d/14812.bugfix new file mode 100644 index 0000000000..94e0d70cbc --- /dev/null +++ b/changelog.d/14812.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would exhaust the stack when processing many federation requests where the remote homeserver has disconencted early. diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 3cb1e7e375..be696c304b 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -310,6 +310,7 @@ class UsernameAvailabilityRestServlet(RestServlet): self.hs = hs self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( + hs.get_reactor(), hs.get_clock(), FederationRatelimitSettings( # Time window of 2s diff --git a/synapse/server.py b/synapse/server.py index f4ab94c4f3..c8752baa5a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -768,6 +768,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_federation_ratelimiter(self) -> FederationRateLimiter: return FederationRateLimiter( + self.get_reactor(), self.get_clock(), config=self.config.ratelimiting.rc_federation, metrics_name="federation_servlets", diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 2aceb1a47f..bd72947bfe 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -34,6 +34,7 @@ from prometheus_client.core import Counter from typing_extensions import ContextManager from twisted.internet import defer +from twisted.internet.interfaces import IReactorTime from synapse.api.errors import LimitExceededError from synapse.config.ratelimiting import FederationRatelimitSettings @@ -146,12 +147,14 @@ class FederationRateLimiter: def __init__( self, + reactor: IReactorTime, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, ): """ Args: + reactor clock config metrics_name: The name of the rate limiter so we can differentiate it @@ -163,7 +166,7 @@ class FederationRateLimiter: def new_limiter() -> "_PerHostRatelimiter": return _PerHostRatelimiter( - clock=clock, config=config, metrics_name=metrics_name + reactor=reactor, clock=clock, config=config, metrics_name=metrics_name ) self.ratelimiters: DefaultDict[ @@ -194,12 +197,14 @@ class FederationRateLimiter: class _PerHostRatelimiter: def __init__( self, + reactor: IReactorTime, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, ): """ Args: + reactor clock config metrics_name: The name of the rate limiter so we can differentiate it @@ -207,6 +212,7 @@ class _PerHostRatelimiter: for this rate limiter. from the rest in the metrics """ + self.reactor = reactor self.clock = clock self.metrics_name = metrics_name @@ -364,12 +370,22 @@ class _PerHostRatelimiter: def _on_exit(self, request_id: object) -> None: logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id)) - self.current_processing.discard(request_id) - try: - # start processing the next item on the queue. - _, deferred = self.ready_request_queue.popitem(last=False) - with PreserveLoggingContext(): - deferred.callback(None) - except KeyError: - pass + # When requests complete synchronously, we will recursively start the next + # request in the queue. To avoid stack exhaustion, we defer starting the next + # request until the next reactor tick. + + def start_next_request() -> None: + # We only remove the completed request from the list when we're about to + # start the next one, otherwise we can allow extra requests through. + self.current_processing.discard(request_id) + try: + # start processing the next item on the queue. + _, deferred = self.ready_request_queue.popitem(last=False) + + with PreserveLoggingContext(): + deferred.callback(None) + except KeyError: + pass + + self.reactor.callLater(0.0, start_next_request) diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 5b327b390e..2f3ea15b96 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional +from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.config.homeserver import HomeServerConfig @@ -29,7 +30,7 @@ class FederationRateLimiterTestCase(TestCase): """A simple test with the default values""" reactor, clock = get_clock() rc_config = build_rc_config() - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter(reactor, clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -39,7 +40,7 @@ class FederationRateLimiterTestCase(TestCase): """Test what happens when we hit the concurrent limit""" reactor, clock = get_clock() rc_config = build_rc_config({"rc_federation": {"concurrent": 2}}) - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter(reactor, clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -57,6 +58,7 @@ class FederationRateLimiterTestCase(TestCase): # ... until we complete an earlier request cm2.__exit__(None, None, None) + reactor.advance(0.0) self.successResultOf(d3) def test_sleep_limit(self) -> None: @@ -65,7 +67,7 @@ class FederationRateLimiterTestCase(TestCase): rc_config = build_rc_config( {"rc_federation": {"sleep_limit": 2, "sleep_delay": 500}} ) - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter(reactor, clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -81,6 +83,43 @@ class FederationRateLimiterTestCase(TestCase): sleep_time = _await_resolution(reactor, d3) self.assertAlmostEqual(sleep_time, 500, places=3) + def test_lots_of_queued_things(self) -> None: + """Tests lots of synchronous things queued up behind a slow thing. + + The stack should *not* explode when the slow thing completes. + """ + reactor, clock = get_clock() + rc_config = build_rc_config( + { + "rc_federation": { + "sleep_limit": 1000000000, # never sleep + "reject_limit": 1000000000, # never reject requests + "concurrent": 1, + } + } + ) + ratelimiter = FederationRateLimiter(reactor, clock, rc_config) + + with ratelimiter.ratelimit("testhost") as d: + # shouldn't block + self.successResultOf(d) + + async def task() -> None: + with ratelimiter.ratelimit("testhost") as d: + await d + + for _ in range(1, 100): + defer.ensureDeferred(task()) + + last_task = defer.ensureDeferred(task()) + + # Upon exiting the context manager, all the synchronous things will resume. + # If a stack overflow occurs, the final task will not complete. + + # Wait for all the things to complete. + reactor.advance(0.0) + self.successResultOf(last_task) + def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float: """advance the clock until the deferred completes.