mirror of
https://github.com/matrix-org/synapse.git
synced 2025-01-12 19:27:55 +00:00
Decouple synapse.api.auth_blocking.AuthBlocking
from synapse.api.auth.Auth
. (#13021)
This commit is contained in:
parent
a164a46038
commit
92103cb2c8
1
changelog.d/13021.misc
Normal file
1
changelog.d/13021.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Decouple `synapse.api.auth_blocking.AuthBlocking` from `synapse.api.auth.Auth`.
|
@ -20,7 +20,6 @@ from netaddr import IPAddress
|
|||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse import event_auth
|
from synapse import event_auth
|
||||||
from synapse.api.auth_blocking import AuthBlocking
|
|
||||||
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
@ -67,8 +66,6 @@ class Auth:
|
|||||||
10000, "token_cache"
|
10000, "token_cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._auth_blocking = AuthBlocking(self.hs)
|
|
||||||
|
|
||||||
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
|
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
|
||||||
self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
|
self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
|
||||||
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
|
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
|
||||||
@ -711,14 +708,3 @@ class Auth:
|
|||||||
"User %s not in room %s, and room previews are disabled"
|
"User %s not in room %s, and room previews are disabled"
|
||||||
% (user_id, room_id),
|
% (user_id, room_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def check_auth_blocking(
|
|
||||||
self,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
threepid: Optional[dict] = None,
|
|
||||||
user_type: Optional[str] = None,
|
|
||||||
requester: Optional[Requester] = None,
|
|
||||||
) -> None:
|
|
||||||
await self._auth_blocking.check_auth_blocking(
|
|
||||||
user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
|
|
||||||
)
|
|
||||||
|
@ -199,6 +199,7 @@ class AuthHandler:
|
|||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self.auth_blocking = hs.get_auth_blocking()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
|
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
|
||||||
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
|
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
|
||||||
@ -985,7 +986,7 @@ class AuthHandler:
|
|||||||
not is_appservice_ghost
|
not is_appservice_ghost
|
||||||
or self.hs.config.appservice.track_appservice_user_ips
|
or self.hs.config.appservice.track_appservice_user_ips
|
||||||
):
|
):
|
||||||
await self.auth.check_auth_blocking(user_id)
|
await self.auth_blocking.check_auth_blocking(user_id)
|
||||||
|
|
||||||
access_token = self.generate_access_token(target_user_id_obj)
|
access_token = self.generate_access_token(target_user_id_obj)
|
||||||
await self.store.add_access_token_to_user(
|
await self.store.add_access_token_to_user(
|
||||||
@ -1439,7 +1440,7 @@ class AuthHandler:
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
|
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
await self.auth.check_auth_blocking(res.user_id)
|
await self.auth_blocking.check_auth_blocking(res.user_id)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def delete_access_token(self, access_token: str) -> None:
|
async def delete_access_token(self, access_token: str) -> None:
|
||||||
|
@ -444,7 +444,7 @@ _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
|
|||||||
class EventCreationHandler:
|
class EventCreationHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth_blocking = hs.get_auth_blocking()
|
||||||
self._event_auth_handler = hs.get_event_auth_handler()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self._storage_controllers = hs.get_storage_controllers()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
@ -605,7 +605,7 @@ class EventCreationHandler:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of created event, Context
|
Tuple of created event, Context
|
||||||
"""
|
"""
|
||||||
await self.auth.check_auth_blocking(requester=requester)
|
await self.auth_blocking.check_auth_blocking(requester=requester)
|
||||||
|
|
||||||
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
|
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
|
||||||
room_version_id = event_dict["content"]["room_version"]
|
room_version_id = event_dict["content"]["room_version"]
|
||||||
|
@ -91,6 +91,7 @@ class RegistrationHandler:
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self.auth_blocking = hs.get_auth_blocking()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self.profile_handler = hs.get_profile_handler()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
self.user_directory_handler = hs.get_user_directory_handler()
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
@ -276,7 +277,7 @@ class RegistrationHandler:
|
|||||||
|
|
||||||
# do not check_auth_blocking if the call is coming through the Admin API
|
# do not check_auth_blocking if the call is coming through the Admin API
|
||||||
if not by_admin:
|
if not by_admin:
|
||||||
await self.auth.check_auth_blocking(threepid=threepid)
|
await self.auth_blocking.check_auth_blocking(threepid=threepid)
|
||||||
|
|
||||||
if localpart is not None:
|
if localpart is not None:
|
||||||
await self.check_username(localpart, guest_access_token=guest_access_token)
|
await self.check_username(localpart, guest_access_token=guest_access_token)
|
||||||
|
@ -110,6 +110,7 @@ class RoomCreationHandler:
|
|||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self._storage_controllers = hs.get_storage_controllers()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self.auth_blocking = hs.get_auth_blocking()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
@ -706,7 +707,7 @@ class RoomCreationHandler:
|
|||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
await self.auth.check_auth_blocking(requester=requester)
|
await self.auth_blocking.check_auth_blocking(requester=requester)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self._server_notices_mxid is not None
|
self._server_notices_mxid is not None
|
||||||
|
@ -237,7 +237,7 @@ class SyncHandler:
|
|||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth_blocking = hs.get_auth_blocking()
|
||||||
self._storage_controllers = hs.get_storage_controllers()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self._state_storage_controller = self._storage_controllers.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
|
|
||||||
@ -280,7 +280,7 @@ class SyncHandler:
|
|||||||
# not been exceeded (if not part of the group by this point, almost certain
|
# not been exceeded (if not part of the group by this point, almost certain
|
||||||
# auth_blocking will occur)
|
# auth_blocking will occur)
|
||||||
user_id = sync_config.user.to_string()
|
user_id = sync_config.user.to_string()
|
||||||
await self.auth.check_auth_blocking(requester=requester)
|
await self.auth_blocking.check_auth_blocking(requester=requester)
|
||||||
|
|
||||||
res = await self.response_cache.wrap(
|
res = await self.response_cache.wrap(
|
||||||
sync_config.request_key,
|
sync_config.request_key,
|
||||||
|
@ -29,6 +29,7 @@ from twisted.web.iweb import IPolicyForHTTPS
|
|||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.api.auth import Auth
|
from synapse.api.auth import Auth
|
||||||
|
from synapse.api.auth_blocking import AuthBlocking
|
||||||
from synapse.api.filtering import Filtering
|
from synapse.api.filtering import Filtering
|
||||||
from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
|
from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
|
||||||
from synapse.appservice.api import ApplicationServiceApi
|
from synapse.appservice.api import ApplicationServiceApi
|
||||||
@ -379,6 +380,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||||||
def get_auth(self) -> Auth:
|
def get_auth(self) -> Auth:
|
||||||
return Auth(self)
|
return Auth(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_auth_blocking(self) -> AuthBlocking:
|
||||||
|
return AuthBlocking(self)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
|
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
|
||||||
if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
|
if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
|
||||||
|
@ -37,7 +37,7 @@ class ResourceLimitsServerNotices:
|
|||||||
self._server_notices_manager = hs.get_server_notices_manager()
|
self._server_notices_manager = hs.get_server_notices_manager()
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
self._storage_controllers = hs.get_storage_controllers()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self._auth = hs.get_auth()
|
self._auth_blocking = hs.get_auth_blocking()
|
||||||
self._config = hs.config
|
self._config = hs.config
|
||||||
self._resouce_limited = False
|
self._resouce_limited = False
|
||||||
self._account_data_handler = hs.get_account_data_handler()
|
self._account_data_handler = hs.get_account_data_handler()
|
||||||
@ -91,7 +91,7 @@ class ResourceLimitsServerNotices:
|
|||||||
# Normally should always pass in user_id to check_auth_blocking
|
# Normally should always pass in user_id to check_auth_blocking
|
||||||
# if you have it, but in this case are checking what would happen
|
# if you have it, but in this case are checking what would happen
|
||||||
# to other users if they were to arrive.
|
# to other users if they were to arrive.
|
||||||
await self._auth.check_auth_blocking()
|
await self._auth_blocking.check_auth_blocking()
|
||||||
except ResourceLimitError as e:
|
except ResourceLimitError as e:
|
||||||
limit_msg = e.msg
|
limit_msg = e.msg
|
||||||
limit_type = e.limit_type
|
limit_type = e.limit_type
|
||||||
|
@ -19,6 +19,7 @@ import pymacaroons
|
|||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.auth import Auth
|
from synapse.api.auth import Auth
|
||||||
|
from synapse.api.auth_blocking import AuthBlocking
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
@ -49,7 +50,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# AuthBlocking reads from the hs' config on initialization. We need to
|
# AuthBlocking reads from the hs' config on initialization. We need to
|
||||||
# modify its config instead of the hs'
|
# modify its config instead of the hs'
|
||||||
self.auth_blocking = self.auth._auth_blocking
|
self.auth_blocking = AuthBlocking(hs)
|
||||||
|
|
||||||
self.test_user = "@foo:bar"
|
self.test_user = "@foo:bar"
|
||||||
self.test_token = b"_test_token_"
|
self.test_token = b"_test_token_"
|
||||||
@ -362,20 +363,22 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
small_number_of_users = 1
|
small_number_of_users = 1
|
||||||
|
|
||||||
# Ensure no error thrown
|
# Ensure no error thrown
|
||||||
self.get_success(self.auth.check_auth_blocking())
|
self.get_success(self.auth_blocking.check_auth_blocking())
|
||||||
|
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
|
|
||||||
self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
|
self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
|
||||||
|
|
||||||
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
|
e = self.get_failure(
|
||||||
|
self.auth_blocking.check_auth_blocking(), ResourceLimitError
|
||||||
|
)
|
||||||
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
|
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
|
||||||
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
self.assertEqual(e.value.code, 403)
|
self.assertEqual(e.value.code, 403)
|
||||||
|
|
||||||
# Ensure does not throw an error
|
# Ensure does not throw an error
|
||||||
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
|
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
|
||||||
self.get_success(self.auth.check_auth_blocking())
|
self.get_success(self.auth_blocking.check_auth_blocking())
|
||||||
|
|
||||||
def test_blocking_mau__depending_on_user_type(self):
|
def test_blocking_mau__depending_on_user_type(self):
|
||||||
self.auth_blocking._max_mau_value = 50
|
self.auth_blocking._max_mau_value = 50
|
||||||
@ -383,15 +386,18 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.store.get_monthly_active_count = simple_async_mock(100)
|
self.store.get_monthly_active_count = simple_async_mock(100)
|
||||||
# Support users allowed
|
# Support users allowed
|
||||||
self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
|
self.get_success(
|
||||||
|
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
|
||||||
|
)
|
||||||
self.store.get_monthly_active_count = simple_async_mock(100)
|
self.store.get_monthly_active_count = simple_async_mock(100)
|
||||||
# Bots not allowed
|
# Bots not allowed
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
|
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
|
||||||
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
self.store.get_monthly_active_count = simple_async_mock(100)
|
self.store.get_monthly_active_count = simple_async_mock(100)
|
||||||
# Real users not allowed
|
# Real users not allowed
|
||||||
self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
|
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
|
||||||
|
|
||||||
def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
|
def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
|
||||||
self.auth_blocking._max_mau_value = 50
|
self.auth_blocking._max_mau_value = 50
|
||||||
@ -419,7 +425,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
app_service=appservice,
|
app_service=appservice,
|
||||||
authenticated_entity="@appservice:server",
|
authenticated_entity="@appservice:server",
|
||||||
)
|
)
|
||||||
self.get_success(self.auth.check_auth_blocking(requester=requester))
|
self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
|
||||||
|
|
||||||
def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
|
def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
|
||||||
self.auth_blocking._max_mau_value = 50
|
self.auth_blocking._max_mau_value = 50
|
||||||
@ -448,7 +454,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
authenticated_entity="@appservice:server",
|
authenticated_entity="@appservice:server",
|
||||||
)
|
)
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.auth.check_auth_blocking(requester=requester), ResourceLimitError
|
self.auth_blocking.check_auth_blocking(requester=requester),
|
||||||
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_reserved_threepid(self):
|
def test_reserved_threepid(self):
|
||||||
@ -459,18 +466,21 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
|
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
|
||||||
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
|
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
|
||||||
|
|
||||||
self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
|
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
|
||||||
|
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
|
self.auth_blocking.check_auth_blocking(threepid=unknown_threepid),
|
||||||
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self.auth.check_auth_blocking(threepid=threepid))
|
self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
|
||||||
|
|
||||||
def test_hs_disabled(self):
|
def test_hs_disabled(self):
|
||||||
self.auth_blocking._hs_disabled = True
|
self.auth_blocking._hs_disabled = True
|
||||||
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
||||||
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
|
e = self.get_failure(
|
||||||
|
self.auth_blocking.check_auth_blocking(), ResourceLimitError
|
||||||
|
)
|
||||||
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
|
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
|
||||||
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
self.assertEqual(e.value.code, 403)
|
self.assertEqual(e.value.code, 403)
|
||||||
@ -485,7 +495,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.auth_blocking._hs_disabled = True
|
self.auth_blocking._hs_disabled = True
|
||||||
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
||||||
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
|
e = self.get_failure(
|
||||||
|
self.auth_blocking.check_auth_blocking(), ResourceLimitError
|
||||||
|
)
|
||||||
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
|
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
|
||||||
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
self.assertEqual(e.value.code, 403)
|
self.assertEqual(e.value.code, 403)
|
||||||
@ -495,4 +507,4 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
user = "@user:server"
|
user = "@user:server"
|
||||||
self.auth_blocking._server_notices_mxid = user
|
self.auth_blocking._server_notices_mxid = user
|
||||||
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
||||||
self.get_success(self.auth.check_auth_blocking(user))
|
self.get_success(self.auth_blocking.check_auth_blocking(user))
|
||||||
|
@ -38,7 +38,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||||||
# MAU tests
|
# MAU tests
|
||||||
# AuthBlocking reads from the hs' config on initialization. We need to
|
# AuthBlocking reads from the hs' config on initialization. We need to
|
||||||
# modify its config instead of the hs'
|
# modify its config instead of the hs'
|
||||||
self.auth_blocking = hs.get_auth()._auth_blocking
|
self.auth_blocking = hs.get_auth_blocking()
|
||||||
self.auth_blocking._max_mau_value = 50
|
self.auth_blocking._max_mau_value = 50
|
||||||
|
|
||||||
self.small_number_of_users = 1
|
self.small_number_of_users = 1
|
||||||
|
@ -699,7 +699,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
if localpart is None:
|
if localpart is None:
|
||||||
raise SynapseError(400, "Request must include user id")
|
raise SynapseError(400, "Request must include user id")
|
||||||
await self.hs.get_auth().check_auth_blocking()
|
await self.hs.get_auth_blocking().check_auth_blocking()
|
||||||
need_register = True
|
need_register = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -45,7 +45,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# AuthBlocking reads from the hs' config on initialization. We need to
|
# AuthBlocking reads from the hs' config on initialization. We need to
|
||||||
# modify its config instead of the hs'
|
# modify its config instead of the hs'
|
||||||
self.auth_blocking = self.hs.get_auth()._auth_blocking
|
self.auth_blocking = self.hs.get_auth_blocking()
|
||||||
|
|
||||||
def test_wait_for_sync_for_user_auth_blocking(self):
|
def test_wait_for_sync_for_user_auth_blocking(self):
|
||||||
user_id1 = "@user1:test"
|
user_id1 = "@user1:test"
|
||||||
|
@ -96,7 +96,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
|
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
|
||||||
"""Test when user has blocked notice, but should have it removed"""
|
"""Test when user has blocked notice, but should have it removed"""
|
||||||
|
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
|
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
||||||
|
return_value=make_awaitable(None)
|
||||||
|
)
|
||||||
mock_event = Mock(
|
mock_event = Mock(
|
||||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||||
)
|
)
|
||||||
@ -112,7 +114,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
Test when user has blocked notice, but notice ought to be there (NOOP)
|
Test when user has blocked notice, but notice ought to be there (NOOP)
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(403, "foo"),
|
side_effect=ResourceLimitError(403, "foo"),
|
||||||
)
|
)
|
||||||
@ -132,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
Test when user does not have blocked notice, but should have one
|
Test when user does not have blocked notice, but should have one
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(403, "foo"),
|
side_effect=ResourceLimitError(403, "foo"),
|
||||||
)
|
)
|
||||||
@ -145,7 +147,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
Test when user does not have blocked notice, nor should they (NOOP)
|
Test when user does not have blocked notice, nor should they (NOOP)
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
|
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
||||||
|
return_value=make_awaitable(None)
|
||||||
|
)
|
||||||
|
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
|
||||||
@ -156,7 +160,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
Test when user is not part of the MAU cohort - this should not ever
|
Test when user is not part of the MAU cohort - this should not ever
|
||||||
happen - but ...
|
happen - but ...
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
|
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
||||||
|
return_value=make_awaitable(None)
|
||||||
|
)
|
||||||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=make_awaitable(None)
|
return_value=make_awaitable(None)
|
||||||
)
|
)
|
||||||
@ -170,7 +176,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
Test that when server is over MAU limit and alerting is suppressed, then
|
Test that when server is over MAU limit and alerting is suppressed, then
|
||||||
an alert message is not sent into the room
|
an alert message is not sent into the room
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||||
@ -185,7 +191,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
Test that when a server is disabled, that MAU limit alerting is ignored.
|
Test that when a server is disabled, that MAU limit alerting is ignored.
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
|
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
|
||||||
@ -202,7 +208,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
When the room is already in a blocked state, test that when alerting
|
When the room is already in a blocked state, test that when alerting
|
||||||
is suppressed that the room is returned to an unblocked state.
|
is suppressed that the room is returned to an unblocked state.
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||||
|
Loading…
Reference in New Issue
Block a user