Add login spam checker API (#15838)
This commit is contained in:
parent
52d8131e87
commit
25c55a9d22
|
@ -0,0 +1 @@
|
|||
Add spam checker module API for logins.
|
|
@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th
|
|||
callback that does not return `False` will be used. If this happens, Synapse will not call
|
||||
any of the subsequent implementations of this callback.
|
||||
|
||||
|
||||
### `check_login_for_spam`
|
||||
|
||||
_First introduced in Synapse v1.87.0_
|
||||
|
||||
```python
|
||||
async def check_login_for_spam(
|
||||
user_id: str,
|
||||
device_id: Optional[str],
|
||||
initial_display_name: Optional[str],
|
||||
request_info: Collection[Tuple[Optional[str], str]],
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
|
||||
```
|
||||
|
||||
Called when a user logs in.
|
||||
|
||||
The arguments passed to this callback are:
|
||||
|
||||
* `user_id`: The user ID the user is logging in with
|
||||
* `device_id`: The device ID the user is re-logging into.
|
||||
* `initial_display_name`: The device display name, if any.
|
||||
* `request_info`: A collection of tuples, which first item is a user agent, and which
|
||||
second item is an IP address. These user agents and IP addresses are the ones that were
|
||||
used during the login process.
|
||||
* `auth_provider_id`: The identifier of the SSO authentication provider, if any.
|
||||
|
||||
If multiple modules implement this callback, they will be considered in order. If a
|
||||
callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
|
||||
The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
|
||||
be used. If this happens, Synapse will not call any of the subsequent implementations of
|
||||
this callback.
|
||||
|
||||
*Note:* This will not be called when a user registers.
|
||||
|
||||
|
||||
## Example
|
||||
|
||||
The example below is a module that implements the spam checker callback
|
||||
|
|
|
@ -521,6 +521,11 @@ class SynapseRequest(Request):
|
|||
else:
|
||||
return self.getClientAddress().host
|
||||
|
||||
def request_info(self) -> "RequestInfo":
|
||||
h = self.getHeader(b"User-Agent")
|
||||
user_agent = h.decode("ascii", "replace") if h else None
|
||||
return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())
|
||||
|
||||
|
||||
class XForwardedForRequest(SynapseRequest):
|
||||
"""Request object which honours proxy headers
|
||||
|
@ -661,3 +666,9 @@ class SynapseSite(Site):
|
|||
|
||||
def log(self, request: SynapseRequest) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||
class RequestInfo:
|
||||
user_agent: Optional[str]
|
||||
ip: str
|
||||
|
|
|
@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
|
|||
)
|
||||
from synapse.module_api.callbacks.spamchecker_callbacks import (
|
||||
CHECK_EVENT_FOR_SPAM_CALLBACK,
|
||||
CHECK_LOGIN_FOR_SPAM_CALLBACK,
|
||||
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
|
||||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
|
||||
CHECK_USERNAME_FOR_SPAM_CALLBACK,
|
||||
|
@ -302,6 +303,7 @@ class ModuleApi:
|
|||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
|
||||
] = None,
|
||||
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
|
||||
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
|
||||
) -> None:
|
||||
"""Registers callbacks for spam checking capabilities.
|
||||
|
||||
|
@ -319,6 +321,7 @@ class ModuleApi:
|
|||
check_username_for_spam=check_username_for_spam,
|
||||
check_registration_for_spam=check_registration_for_spam,
|
||||
check_media_file_for_spam=check_media_file_for_spam,
|
||||
check_login_for_spam=check_login_for_spam,
|
||||
)
|
||||
|
||||
def register_account_validity_callbacks(
|
||||
|
|
|
@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
|
|||
]
|
||||
],
|
||||
]
|
||||
CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
|
||||
[
|
||||
str,
|
||||
Optional[str],
|
||||
Optional[str],
|
||||
Collection[Tuple[Optional[str], str]],
|
||||
Optional[str],
|
||||
],
|
||||
Awaitable[
|
||||
Union[
|
||||
Literal["NOT_SPAM"],
|
||||
Codes,
|
||||
# Highly experimental, not officially part of the spamchecker API, may
|
||||
# disappear without warning depending on the results of ongoing
|
||||
# experiments.
|
||||
# Use this to return additional information as part of an error.
|
||||
Tuple[Codes, JsonDict],
|
||||
]
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
|
||||
|
@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks:
|
|||
self._check_media_file_for_spam_callbacks: List[
|
||||
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
|
||||
] = []
|
||||
self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []
|
||||
|
||||
def register_callbacks(
|
||||
self,
|
||||
|
@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks:
|
|||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
|
||||
] = None,
|
||||
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
|
||||
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
|
||||
) -> None:
|
||||
"""Register callbacks from module for each hook."""
|
||||
if check_event_for_spam is not None:
|
||||
|
@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks:
|
|||
if check_media_file_for_spam is not None:
|
||||
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
|
||||
|
||||
if check_login_for_spam is not None:
|
||||
self._check_login_for_spam_callbacks.append(check_login_for_spam)
|
||||
|
||||
@trace
|
||||
async def check_event_for_spam(
|
||||
self, event: "synapse.events.EventBase"
|
||||
|
@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks:
|
|||
return synapse.api.errors.Codes.FORBIDDEN, {}
|
||||
|
||||
return self.NOT_SPAM
|
||||
|
||||
async def check_login_for_spam(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: Optional[str],
|
||||
initial_display_name: Optional[str],
|
||||
request_info: Collection[Tuple[Optional[str], str]],
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
|
||||
"""Checks if we should allow the given registration request.
|
||||
|
||||
Args:
|
||||
user_id: The request user ID
|
||||
request_info: List of tuples of user agent and IP that
|
||||
were used during the registration process.
|
||||
auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
|
||||
"cas". If any. Note this does not include users registered
|
||||
via a password provider.
|
||||
|
||||
Returns:
|
||||
Enum for how the request should be handled
|
||||
"""
|
||||
|
||||
for callback in self._check_login_for_spam_callbacks:
|
||||
with Measure(
|
||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
||||
):
|
||||
res = await delay_cancellation(
|
||||
callback(
|
||||
user_id,
|
||||
device_id,
|
||||
initial_display_name,
|
||||
request_info,
|
||||
auth_provider_id,
|
||||
)
|
||||
)
|
||||
# Normalize return values to `Codes` or `"NOT_SPAM"`.
|
||||
if res is self.NOT_SPAM:
|
||||
continue
|
||||
elif isinstance(res, synapse.api.errors.Codes):
|
||||
return res, {}
|
||||
elif (
|
||||
isinstance(res, tuple)
|
||||
and len(res) == 2
|
||||
and isinstance(res[0], synapse.api.errors.Codes)
|
||||
and isinstance(res[1], dict)
|
||||
):
|
||||
return res
|
||||
else:
|
||||
logger.warning(
|
||||
"Module returned invalid value, rejecting login as spam"
|
||||
)
|
||||
return synapse.api.errors.Codes.FORBIDDEN, {}
|
||||
|
||||
return self.NOT_SPAM
|
||||
|
|
|
@ -50,7 +50,7 @@ from synapse.http.servlet import (
|
|||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.http.site import RequestInfo, SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
@ -114,6 +114,7 @@ class LoginRestServlet(RestServlet):
|
|||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
self._spam_checker = hs.get_module_api_callbacks().spam_checker
|
||||
|
||||
self._well_known_builder = WellKnownBuilder(hs)
|
||||
self._address_ratelimiter = Ratelimiter(
|
||||
|
@ -197,6 +198,8 @@ class LoginRestServlet(RestServlet):
|
|||
self._refresh_tokens_enabled and client_requested_refresh_token
|
||||
)
|
||||
|
||||
request_info = request.request_info()
|
||||
|
||||
try:
|
||||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
@ -216,6 +219,7 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission,
|
||||
appservice,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
elif (
|
||||
self.jwt_enabled
|
||||
|
@ -227,6 +231,7 @@ class LoginRestServlet(RestServlet):
|
|||
result = await self._do_jwt_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
|
@ -235,6 +240,7 @@ class LoginRestServlet(RestServlet):
|
|||
result = await self._do_token_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
else:
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
|
@ -243,6 +249,7 @@ class LoginRestServlet(RestServlet):
|
|||
result = await self._do_other_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Missing JSON keys.")
|
||||
|
@ -265,6 +272,8 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission: JsonDict,
|
||||
appservice: ApplicationService,
|
||||
should_issue_refresh_token: bool = False,
|
||||
*,
|
||||
request_info: RequestInfo,
|
||||
) -> LoginResponse:
|
||||
identifier = login_submission.get("identifier")
|
||||
logger.info("Got appservice login request with identifier: %r", identifier)
|
||||
|
@ -300,10 +309,15 @@ class LoginRestServlet(RestServlet):
|
|||
# The user represented by an appservice's configured sender_localpart
|
||||
# is not actually created in Synapse.
|
||||
should_check_deactivated=qualified_user_id != appservice.sender,
|
||||
request_info=request_info,
|
||||
)
|
||||
|
||||
async def _do_other_login(
|
||||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
||||
self,
|
||||
login_submission: JsonDict,
|
||||
should_issue_refresh_token: bool = False,
|
||||
*,
|
||||
request_info: RequestInfo,
|
||||
) -> LoginResponse:
|
||||
"""Handle non-token/saml/jwt logins
|
||||
|
||||
|
@ -333,6 +347,7 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission,
|
||||
callback,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
return result
|
||||
|
||||
|
@ -347,6 +362,8 @@ class LoginRestServlet(RestServlet):
|
|||
should_issue_refresh_token: bool = False,
|
||||
auth_provider_session_id: Optional[str] = None,
|
||||
should_check_deactivated: bool = True,
|
||||
*,
|
||||
request_info: RequestInfo,
|
||||
) -> LoginResponse:
|
||||
"""Called when we've successfully authed the user and now need to
|
||||
actually login them in (e.g. create devices). This gets called on
|
||||
|
@ -371,6 +388,7 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
This exists purely for appservice's configured sender_localpart
|
||||
which doesn't have an associated user in the database.
|
||||
request_info: The user agent/IP address of the user.
|
||||
|
||||
Returns:
|
||||
Dictionary of account information after successful login.
|
||||
|
@ -417,6 +435,22 @@ class LoginRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
spam_check = await self._spam_checker.check_login_for_spam(
|
||||
user_id,
|
||||
device_id=device_id,
|
||||
initial_display_name=initial_display_name,
|
||||
request_info=[(request_info.user_agent, request_info.ip)],
|
||||
auth_provider_id=auth_provider_id,
|
||||
)
|
||||
if spam_check != self._spam_checker.NOT_SPAM:
|
||||
logger.info("Blocking login due to spam checker")
|
||||
raise SynapseError(
|
||||
403,
|
||||
msg="Login was blocked by the server",
|
||||
errcode=spam_check[0],
|
||||
additional_fields=spam_check[1],
|
||||
)
|
||||
|
||||
(
|
||||
device_id,
|
||||
access_token,
|
||||
|
@ -451,7 +485,11 @@ class LoginRestServlet(RestServlet):
|
|||
return result
|
||||
|
||||
async def _do_token_login(
|
||||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
||||
self,
|
||||
login_submission: JsonDict,
|
||||
should_issue_refresh_token: bool = False,
|
||||
*,
|
||||
request_info: RequestInfo,
|
||||
) -> LoginResponse:
|
||||
"""
|
||||
Handle token login.
|
||||
|
@ -474,10 +512,15 @@ class LoginRestServlet(RestServlet):
|
|||
auth_provider_id=res.auth_provider_id,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
auth_provider_session_id=res.auth_provider_session_id,
|
||||
request_info=request_info,
|
||||
)
|
||||
|
||||
async def _do_jwt_login(
|
||||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
||||
self,
|
||||
login_submission: JsonDict,
|
||||
should_issue_refresh_token: bool = False,
|
||||
*,
|
||||
request_info: RequestInfo,
|
||||
) -> LoginResponse:
|
||||
"""
|
||||
Handle the custom JWT login.
|
||||
|
@ -496,6 +539,7 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission,
|
||||
create_non_existent_users=True,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -13,11 +13,12 @@
|
|||
# limitations under the License.
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Collection, Dict, List, Optional, Tuple, Union
|
||||
from unittest.mock import Mock
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import pymacaroons
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.web.resource import Resource
|
||||
|
@ -26,11 +27,12 @@ import synapse.rest.admin
|
|||
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.rest.client import devices, login, logout, register
|
||||
from synapse.rest.client.account import WhoamiRestServlet
|
||||
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import create_requester
|
||||
from synapse.types import JsonDict, create_requester
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
|
|||
]
|
||||
|
||||
|
||||
class TestSpamChecker:
|
||||
def __init__(self, config: None, api: ModuleApi):
|
||||
api.register_spam_checker_callbacks(
|
||||
check_login_for_spam=self.check_login_for_spam,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
return None
|
||||
|
||||
async def check_login_for_spam(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: Optional[str],
|
||||
initial_display_name: Optional[str],
|
||||
request_info: Collection[Tuple[Optional[str], str]],
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> Union[
|
||||
Literal["NOT_SPAM"],
|
||||
Tuple["synapse.module_api.errors.Codes", JsonDict],
|
||||
]:
|
||||
return "NOT_SPAM"
|
||||
|
||||
|
||||
class DenyAllSpamChecker:
|
||||
def __init__(self, config: None, api: ModuleApi):
|
||||
api.register_spam_checker_callbacks(
|
||||
check_login_for_spam=self.check_login_for_spam,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
return None
|
||||
|
||||
async def check_login_for_spam(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: Optional[str],
|
||||
initial_display_name: Optional[str],
|
||||
request_info: Collection[Tuple[Optional[str], str]],
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> Union[
|
||||
Literal["NOT_SPAM"],
|
||||
Tuple["synapse.module_api.errors.Codes", JsonDict],
|
||||
]:
|
||||
# Return an odd set of values to ensure that they get correctly passed
|
||||
# to the client.
|
||||
return Codes.LIMIT_EXCEEDED, {"extra": "value"}
|
||||
|
||||
|
||||
class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
|
@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|||
],
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"modules": [
|
||||
{
|
||||
"module": TestSpamChecker.__module__
|
||||
+ "."
|
||||
+ TestSpamChecker.__qualname__
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
def test_spam_checker_allow(self) -> None:
|
||||
"""Check that that adding a spam checker doesn't break login."""
|
||||
self.register_user("kermit", "monkey")
|
||||
|
||||
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"modules": [
|
||||
{
|
||||
"module": DenyAllSpamChecker.__module__
|
||||
+ "."
|
||||
+ DenyAllSpamChecker.__qualname__
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
def test_spam_checker_deny(self) -> None:
|
||||
"""Check that login"""
|
||||
|
||||
self.register_user("kermit", "monkey")
|
||||
|
||||
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
)
|
||||
self.assertEqual(channel.code, 403, channel.result)
|
||||
self.assertDictContainsSubset(
|
||||
{"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
|
||||
)
|
||||
|
||||
|
||||
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
|
||||
class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
|
|
Loading…
Reference in New Issue