Consolidate logic to check for deactivated users. (#15634)
This moves the deactivated user check to the method which all login types call. Additionally updates the application service tests to be more realistic by removing invalid tests and fixing server names.
This commit is contained in:
parent
1df0221bda
commit
7c9b91790c
|
@ -0,0 +1 @@
|
|||
Fix a long-standing bug where deactivated users were able to login in uncommon situations.
|
|
@ -46,6 +46,9 @@ instead.
|
|||
|
||||
If the authentication is unsuccessful, the module must return `None`.
|
||||
|
||||
Note that the user is not automatically registered, the `register_user(..)` method of
|
||||
the [module API](writing_a_module.html) can be used to lazily create users.
|
||||
|
||||
If multiple modules register an auth checker for the same login type but with different
|
||||
fields, Synapse will refuse to start.
|
||||
|
||||
|
|
|
@ -86,6 +86,7 @@ class ApplicationService:
|
|||
url.rstrip("/") if isinstance(url, str) else None
|
||||
) # url must not end with a slash
|
||||
self.hs_token = hs_token
|
||||
# The full Matrix ID for this application service's sender.
|
||||
self.sender = sender
|
||||
self.namespaces = self._check_namespaces(namespaces)
|
||||
self.id = id
|
||||
|
@ -212,7 +213,7 @@ class ApplicationService:
|
|||
True if the application service is interested in the user, False if not.
|
||||
"""
|
||||
return (
|
||||
# User is the appservice's sender_localpart user
|
||||
# User is the appservice's configured sender_localpart user
|
||||
user_id == self.sender
|
||||
# User is in the appservice's user namespace
|
||||
or self.is_user_in_namespace(user_id)
|
||||
|
|
|
@ -52,7 +52,6 @@ from synapse.api.errors import (
|
|||
NotFoundError,
|
||||
StoreError,
|
||||
SynapseError,
|
||||
UserDeactivatedError,
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.handlers.ui_auth import (
|
||||
|
@ -1419,12 +1418,6 @@ class AuthHandler:
|
|||
return None
|
||||
(user_id, password_hash) = lookupres
|
||||
|
||||
# If the password hash is None, the account has likely been deactivated
|
||||
if not password_hash:
|
||||
deactivated = await self.store.get_user_deactivated_status(user_id)
|
||||
if deactivated:
|
||||
raise UserDeactivatedError("This account has been deactivated")
|
||||
|
||||
result = await self.validate_hash(password, password_hash)
|
||||
if not result:
|
||||
logger.warning("Failed password login for user %s", user_id)
|
||||
|
@ -1749,8 +1742,11 @@ class AuthHandler:
|
|||
registered.
|
||||
auth_provider_session_id: The session ID from the SSO IdP received during login.
|
||||
"""
|
||||
# If the account has been deactivated, do not proceed with the login
|
||||
# flow.
|
||||
# If the account has been deactivated, do not proceed with the login.
|
||||
#
|
||||
# This gets checked again when the token is submitted but this lets us
|
||||
# provide an HTML error page to the user (instead of issuing a token and
|
||||
# having it error later).
|
||||
deactivated = await self.store.get_user_deactivated_status(registered_user_id)
|
||||
if deactivated:
|
||||
respond_with_html(request, 403, self._sso_account_deactivated_template)
|
||||
|
|
|
@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
|
|||
from authlib.jose import JsonWebToken, JWTClaims
|
||||
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
|
||||
|
||||
from synapse.api.errors import Codes, LoginError, StoreError, UserDeactivatedError
|
||||
from synapse.api.errors import Codes, LoginError
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -26,7 +26,6 @@ if TYPE_CHECKING:
|
|||
class JwtHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self._main_store = hs.get_datastores().main
|
||||
|
||||
self.jwt_secret = hs.config.jwt.jwt_secret
|
||||
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
|
||||
|
@ -34,7 +33,7 @@ class JwtHandler:
|
|||
self.jwt_issuer = hs.config.jwt.jwt_issuer
|
||||
self.jwt_audiences = hs.config.jwt.jwt_audiences
|
||||
|
||||
async def validate_login(self, login_submission: JsonDict) -> str:
|
||||
def validate_login(self, login_submission: JsonDict) -> str:
|
||||
"""
|
||||
Authenticates the user for the /login API
|
||||
|
||||
|
@ -103,16 +102,4 @@ class JwtHandler:
|
|||
if user is None:
|
||||
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
|
||||
|
||||
user_id = UserID(user, self.hs.hostname).to_string()
|
||||
|
||||
# If the account has been deactivated, do not proceed with the login
|
||||
# flow.
|
||||
try:
|
||||
deactivated = await self._main_store.get_user_deactivated_status(user_id)
|
||||
except StoreError:
|
||||
# JWT lazily creates users, so they may not exist in the database yet.
|
||||
deactivated = False
|
||||
if deactivated:
|
||||
raise UserDeactivatedError("This account has been deactivated")
|
||||
|
||||
return user_id
|
||||
return UserID(user, self.hs.hostname).to_string()
|
||||
|
|
|
@ -35,6 +35,7 @@ from synapse.api.errors import (
|
|||
LoginError,
|
||||
NotApprovedError,
|
||||
SynapseError,
|
||||
UserDeactivatedError,
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.api.urls import CLIENT_API_PREFIX
|
||||
|
@ -84,6 +85,7 @@ class LoginRestServlet(RestServlet):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self._main_store = hs.get_datastores().main
|
||||
|
||||
# JWT configuration variables.
|
||||
self.jwt_enabled = hs.config.jwt.jwt_enabled
|
||||
|
@ -112,13 +114,13 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
self._well_known_builder = WellKnownBuilder(hs)
|
||||
self._address_ratelimiter = Ratelimiter(
|
||||
store=hs.get_datastores().main,
|
||||
store=self._main_store,
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second,
|
||||
burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count,
|
||||
)
|
||||
self._account_ratelimiter = Ratelimiter(
|
||||
store=hs.get_datastores().main,
|
||||
store=self._main_store,
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second,
|
||||
burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
|
||||
|
@ -280,6 +282,9 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission,
|
||||
ratelimit=appservice.is_rate_limited(),
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
# The user represented by an appservice's configured sender_localpart
|
||||
# is not actually created in Synapse.
|
||||
should_check_deactivated=qualified_user_id != appservice.sender,
|
||||
)
|
||||
|
||||
async def _do_other_login(
|
||||
|
@ -326,6 +331,7 @@ class LoginRestServlet(RestServlet):
|
|||
auth_provider_id: Optional[str] = None,
|
||||
should_issue_refresh_token: bool = False,
|
||||
auth_provider_session_id: Optional[str] = None,
|
||||
should_check_deactivated: bool = True,
|
||||
) -> LoginResponse:
|
||||
"""Called when we've successfully authed the user and now need to
|
||||
actually login them in (e.g. create devices). This gets called on
|
||||
|
@ -345,6 +351,11 @@ class LoginRestServlet(RestServlet):
|
|||
should_issue_refresh_token: True if this login should issue
|
||||
a refresh token alongside the access token.
|
||||
auth_provider_session_id: The session ID got during login from the SSO IdP.
|
||||
should_check_deactivated: True if the user should be checked for
|
||||
deactivation status before logging in.
|
||||
|
||||
This exists purely for appservice's configured sender_localpart
|
||||
which doesn't have an associated user in the database.
|
||||
|
||||
Returns:
|
||||
Dictionary of account information after successful login.
|
||||
|
@ -364,6 +375,12 @@ class LoginRestServlet(RestServlet):
|
|||
)
|
||||
user_id = canonical_uid
|
||||
|
||||
# If the account has been deactivated, do not proceed with the login.
|
||||
if should_check_deactivated:
|
||||
deactivated = await self._main_store.get_user_deactivated_status(user_id)
|
||||
if deactivated:
|
||||
raise UserDeactivatedError("This account has been deactivated")
|
||||
|
||||
device_id = login_submission.get("device_id")
|
||||
|
||||
# If device_id is present, check that device_id is not longer than a reasonable 512 characters
|
||||
|
@ -458,7 +475,7 @@ class LoginRestServlet(RestServlet):
|
|||
Returns:
|
||||
The body of the JSON response.
|
||||
"""
|
||||
user_id = await self.hs.get_jwt_handler().validate_login(login_submission)
|
||||
user_id = self.hs.get_jwt_handler().validate_login(login_submission)
|
||||
return await self._complete_login(
|
||||
user_id,
|
||||
login_submission,
|
||||
|
|
|
@ -18,13 +18,17 @@ from http import HTTPStatus
|
|||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.handlers.account import AccountHandler
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.rest.client import account, devices, login, logout, register
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel
|
||||
|
@ -162,10 +166,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
CALLBACK_USERNAME = "get_username_for_registration"
|
||||
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
|
||||
|
||||
def setUp(self) -> None:
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
# we use a global mock device, so make sure we are starting with a clean slate
|
||||
mock_password_provider.reset_mock()
|
||||
super().setUp()
|
||||
|
||||
# The mock password provider doesn't register the users, so ensure they
|
||||
# are registered first.
|
||||
self.register_user("u", "not-the-tested-password")
|
||||
self.register_user("user", "not-the-tested-password")
|
||||
|
||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||
def test_password_only_auth_progiver_login_legacy(self) -> None:
|
||||
|
@ -185,22 +195,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
mock_password_provider.reset_mock()
|
||||
|
||||
# login with mxid should work too
|
||||
channel = self._send_password_login("@u:bz", "p")
|
||||
channel = self._send_password_login("@u:test", "p")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual("@u:bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
|
||||
self.assertEqual("@u:test", channel.json_body["user_id"])
|
||||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# try a weird username / pass. Honestly it's unclear what we *expect* to happen
|
||||
# in these cases, but at least we can guard against the API changing
|
||||
# unexpectedly
|
||||
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
|
||||
mock_password_provider.check_password.assert_called_once_with(
|
||||
"@ USER🙂NAME :test", " pASS😢word "
|
||||
)
|
||||
|
||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||
def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
|
||||
self.password_only_auth_provider_ui_auth_test_body()
|
||||
|
@ -208,10 +208,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
def password_only_auth_provider_ui_auth_test_body(self) -> None:
|
||||
"""UI Auth should delegate correctly to the password provider"""
|
||||
|
||||
# create the user, otherwise access doesn't work
|
||||
module_api = self.hs.get_module_api()
|
||||
self.get_success(module_api.register_user("u"))
|
||||
|
||||
# log in twice, to get two devices
|
||||
mock_password_provider.check_password.return_value = make_awaitable(True)
|
||||
tok1 = self.login("u", "p")
|
||||
|
@ -401,29 +397,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
mock_password_provider.check_auth.assert_not_called()
|
||||
|
||||
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||
("@user:bz", None)
|
||||
("@user:test", None)
|
||||
)
|
||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||
self.assertEqual("@user:test", channel.json_body["user_id"])
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
"u", "test.login_type", {"test_field": "y"}
|
||||
)
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# try a weird username. Again, it's unclear what we *expect* to happen
|
||||
# in these cases, but at least we can guard against the API changing
|
||||
# unexpectedly
|
||||
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||
("@ MALFORMED! :bz", None)
|
||||
)
|
||||
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
|
||||
)
|
||||
|
||||
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||
def test_custom_auth_provider_ui_auth_legacy(self) -> None:
|
||||
self.custom_auth_provider_ui_auth_test_body()
|
||||
|
@ -465,7 +448,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
|
||||
# right params, but authing as the wrong user
|
||||
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||
("@user:bz", None)
|
||||
("@user:test", None)
|
||||
)
|
||||
body["auth"]["test_field"] = "foo"
|
||||
channel = self._delete_device(tok1, "dev2", body)
|
||||
|
@ -498,11 +481,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
callback = Mock(return_value=make_awaitable(None))
|
||||
|
||||
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||
("@user:bz", callback)
|
||||
("@user:test", callback)
|
||||
)
|
||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||
self.assertEqual("@user:test", channel.json_body["user_id"])
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
"u", "test.login_type", {"test_field": "y"}
|
||||
)
|
||||
|
@ -512,7 +495,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
call_args, call_kwargs = callback.call_args
|
||||
# should be one positional arg
|
||||
self.assertEqual(len(call_args), 1)
|
||||
self.assertEqual(call_args[0]["user_id"], "@user:bz")
|
||||
self.assertEqual(call_args[0]["user_id"], "@user:test")
|
||||
for p in ["user_id", "access_token", "device_id", "home_server"]:
|
||||
self.assertIn(p, call_args[0])
|
||||
|
||||
|
|
Loading…
Reference in New Issue