Add `Retry-After` to M_LIMIT_EXCEEDED error responses (#16136)

Implements MSC4041 behind an experimental configuration flag.
This commit is contained in:
Will Hunt 2023-08-24 15:40:26 +01:00 committed by GitHub
parent e3333bacff
commit 0538e3e2db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 7 deletions

View File

@ -0,0 +1 @@
Return a `Retry-After` with `M_LIMIT_EXCEEDED` error responses.

View File

@ -16,6 +16,7 @@
"""Contains exceptions and error codes.""" """Contains exceptions and error codes."""
import logging import logging
import math
import typing import typing
from enum import Enum from enum import Enum
from http import HTTPStatus from http import HTTPStatus
@ -503,6 +504,8 @@ class InvalidCaptchaError(SynapseError):
class LimitExceededError(SynapseError): class LimitExceededError(SynapseError):
"""A client has sent too many requests and is being throttled.""" """A client has sent too many requests and is being throttled."""
include_retry_after_header = False
def __init__( def __init__(
self, self,
code: int = 429, code: int = 429,
@ -510,7 +513,12 @@ class LimitExceededError(SynapseError):
retry_after_ms: Optional[int] = None, retry_after_ms: Optional[int] = None,
errcode: str = Codes.LIMIT_EXCEEDED, errcode: str = Codes.LIMIT_EXCEEDED,
): ):
super().__init__(code, msg, errcode) headers = (
{"Retry-After": str(math.ceil(retry_after_ms / 1000))}
if self.include_retry_after_header and retry_after_ms is not None
else None
)
super().__init__(code, msg, errcode, headers=headers)
self.retry_after_ms = retry_after_ms self.retry_after_ms = retry_after_ms
def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":

View File

@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Optional
import attr import attr
import attr.validators import attr.validators
from synapse.api.errors import LimitExceededError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config._base import Config, RootConfig from synapse.config._base import Config, RootConfig
@ -406,3 +407,11 @@ class ExperimentalConfig(Config):
self.msc4010_push_rules_account_data = experimental.get( self.msc4010_push_rules_account_data = experimental.get(
"msc4010_push_rules_account_data", False "msc4010_push_rules_account_data", False
) )
# MSC4041: Use HTTP header Retry-After to enable library-assisted retry handling
#
# This is a bit hacky, but the most reasonable way to *alway* include the
# headers.
LimitExceededError.include_retry_after_header = experimental.get(
"msc4041_enabled", False
)

36
tests/api/test_errors.py Normal file
View File

@ -0,0 +1,36 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import LimitExceededError
from tests import unittest
class ErrorsTestCase(unittest.TestCase):
# Create a sub-class to avoid mutating the class-level property.
class LimitExceededErrorHeaders(LimitExceededError):
include_retry_after_header = True
def test_limit_exceeded_header(self) -> None:
err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=100)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100)
assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "1")
def test_limit_exceeded_rounding(self) -> None:
err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=3001)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001)
assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "4")

View File

@ -169,7 +169,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# which sets these values to 10000, but as we're overriding the entire # which sets these values to 10000, but as we're overriding the entire
# rc_login dict here, we need to set this manually as well # rc_login dict here, we need to set this manually as well
"account": {"per_second": 10000, "burst_count": 10000}, "account": {"per_second": 10000, "burst_count": 10000},
} },
"experimental_features": {"msc4041_enabled": True},
} }
) )
def test_POST_ratelimiting_per_address(self) -> None: def test_POST_ratelimiting_per_address(self) -> None:
@ -189,12 +190,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
if i == 5: if i == 5:
self.assertEqual(channel.code, 429, msg=channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
retry_header = channel.headers.getRawHeaders("Retry-After")
else: else:
self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min. # than 1min.
self.assertTrue(retry_after_ms < 6000) self.assertLess(retry_after_ms, 6000)
assert retry_header
self.assertLessEqual(int(retry_header[0]), 6)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0) self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
@ -217,7 +221,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# which sets these values to 10000, but as we're overriding the entire # which sets these values to 10000, but as we're overriding the entire
# rc_login dict here, we need to set this manually as well # rc_login dict here, we need to set this manually as well
"address": {"per_second": 10000, "burst_count": 10000}, "address": {"per_second": 10000, "burst_count": 10000},
} },
"experimental_features": {"msc4041_enabled": True},
} }
) )
def test_POST_ratelimiting_per_account(self) -> None: def test_POST_ratelimiting_per_account(self) -> None:
@ -234,12 +239,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
if i == 5: if i == 5:
self.assertEqual(channel.code, 429, msg=channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
retry_header = channel.headers.getRawHeaders("Retry-After")
else: else:
self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min. # than 1min.
self.assertTrue(retry_after_ms < 6000) self.assertLess(retry_after_ms, 6000)
assert retry_header
self.assertLessEqual(int(retry_header[0]), 6)
self.reactor.advance(retry_after_ms / 1000.0) self.reactor.advance(retry_after_ms / 1000.0)
@ -262,7 +270,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# rc_login dict here, we need to set this manually as well # rc_login dict here, we need to set this manually as well
"address": {"per_second": 10000, "burst_count": 10000}, "address": {"per_second": 10000, "burst_count": 10000},
"failed_attempts": {"per_second": 0.17, "burst_count": 5}, "failed_attempts": {"per_second": 0.17, "burst_count": 5},
} },
"experimental_features": {"msc4041_enabled": True},
} }
) )
def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
@ -279,12 +288,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
if i == 5: if i == 5:
self.assertEqual(channel.code, 429, msg=channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
retry_header = channel.headers.getRawHeaders("Retry-After")
else: else:
self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min. # than 1min.
self.assertTrue(retry_after_ms < 6000) self.assertLess(retry_after_ms, 6000)
assert retry_header
self.assertLessEqual(int(retry_header[0]), 6)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0) self.reactor.advance(retry_after_ms / 1000.0 + 1.0)