Add missing type hints to tests.handlers. (#14680)

And do not allow untyped defs in tests.handlers.
This commit is contained in:
Patrick Cloke 2022-12-16 06:53:01 -05:00 committed by GitHub
parent 54c012c5a8
commit 652d1669c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 527 additions and 378 deletions

1
changelog.d/14680.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints.

View File

@ -95,10 +95,7 @@ disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client] [mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.handlers.test_sso] [mypy-tests.handlers.*]
disallow_untyped_defs = True
[mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.metrics.test_background_process_metrics] [mypy-tests.metrics.test_background_process_metrics]

View File

@ -2031,7 +2031,7 @@ class PasswordAuthProvider:
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters # Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {} self._supported_login_types: Dict[str, Tuple[str, ...]] = {}
# Mapping from login type to auth checker callbacks # Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}

View File

@ -31,7 +31,7 @@ from synapse.appservice import (
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import RoomStreamToken from synapse.types import JsonDict, RoomStreamToken
from synapse.util import Clock from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -44,7 +44,7 @@ from tests.utils import MockClock
class AppServiceHandlerTestCase(unittest.TestCase): class AppServiceHandlerTestCase(unittest.TestCase):
"""Tests the ApplicationServicesHandler.""" """Tests the ApplicationServicesHandler."""
def setUp(self): def setUp(self) -> None:
self.mock_store = Mock() self.mock_store = Mock()
self.mock_as_api = Mock() self.mock_as_api = Mock()
self.mock_scheduler = Mock() self.mock_scheduler = Mock()
@ -61,7 +61,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler = ApplicationServicesHandler(hs) self.handler = ApplicationServicesHandler(hs)
self.event_source = hs.get_event_sources() self.event_source = hs.get_event_sources()
def test_notify_interested_services(self): def test_notify_interested_services(self) -> None:
interested_service = self._mkservice(is_interested_in_event=True) interested_service = self._mkservice(is_interested_in_event=True)
services = [ services = [
self._mkservice(is_interested_in_event=False), self._mkservice(is_interested_in_event=False),
@ -90,7 +90,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, events=[event] interested_service, events=[event]
) )
def test_query_user_exists_unknown_user(self): def test_query_user_exists_unknown_user(self) -> None:
user_id = "@someone:anywhere" user_id = "@someone:anywhere"
services = [self._mkservice(is_interested_in_event=True)] services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True services[0].is_interested_in_user.return_value = True
@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
def test_query_user_exists_known_user(self): def test_query_user_exists_known_user(self) -> None:
user_id = "@someone:anywhere" user_id = "@someone:anywhere"
services = [self._mkservice(is_interested_in_event=True)] services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True services[0].is_interested_in_user.return_value = True
@ -127,7 +127,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"query_user called when it shouldn't have been.", "query_user called when it shouldn't have been.",
) )
def test_query_room_alias_exists(self): def test_query_room_alias_exists(self) -> None:
room_alias_str = "#foo:bar" room_alias_str = "#foo:bar"
room_alias = Mock() room_alias = Mock()
room_alias.to_string.return_value = room_alias_str room_alias.to_string.return_value = room_alias_str
@ -157,7 +157,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.assertEqual(result.room_id, room_id) self.assertEqual(result.room_id, room_id)
self.assertEqual(result.servers, servers) self.assertEqual(result.servers, servers)
def test_get_3pe_protocols_no_appservices(self): def test_get_3pe_protocols_no_appservices(self) -> None:
self.mock_store.get_app_services.return_value = [] self.mock_store.get_app_services.return_value = []
response = self.successResultOf( response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.get_3pe_protocol.assert_not_called() self.mock_as_api.get_3pe_protocol.assert_not_called()
self.assertEqual(response, {}) self.assertEqual(response, {})
def test_get_3pe_protocols_no_protocols(self): def test_get_3pe_protocols_no_protocols(self) -> None:
service = self._mkservice(False, []) service = self._mkservice(False, [])
self.mock_store.get_app_services.return_value = [service] self.mock_store.get_app_services.return_value = [service]
response = self.successResultOf( response = self.successResultOf(
@ -174,7 +174,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.get_3pe_protocol.assert_not_called() self.mock_as_api.get_3pe_protocol.assert_not_called()
self.assertEqual(response, {}) self.assertEqual(response, {})
def test_get_3pe_protocols_protocol_no_response(self): def test_get_3pe_protocols_protocol_no_response(self) -> None:
service = self._mkservice(False, ["my-protocol"]) service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service] self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
@ -186,7 +186,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(response, {}) self.assertEqual(response, {})
def test_get_3pe_protocols_select_one_protocol(self): def test_get_3pe_protocols_select_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"]) service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service] self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
@ -202,7 +202,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
) )
def test_get_3pe_protocols_one_protocol(self): def test_get_3pe_protocols_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"]) service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service] self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
@ -218,7 +218,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
) )
def test_get_3pe_protocols_multiple_protocol(self): def test_get_3pe_protocols_multiple_protocol(self) -> None:
service_one = self._mkservice(False, ["my-protocol"]) service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["other-protocol"]) service_two = self._mkservice(False, ["other-protocol"])
self.mock_store.get_app_services.return_value = [service_one, service_two] self.mock_store.get_app_services.return_value = [service_one, service_two]
@ -237,11 +237,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
}, },
) )
def test_get_3pe_protocols_multiple_info(self): def test_get_3pe_protocols_multiple_info(self) -> None:
service_one = self._mkservice(False, ["my-protocol"]) service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["my-protocol"]) service_two = self._mkservice(False, ["my-protocol"])
async def get_3pe_protocol(service, unusedProtocol): async def get_3pe_protocol(
service: ApplicationService, protocol: str
) -> Optional[JsonDict]:
if service == service_one: if service == service_one:
return { return {
"x-protocol-data": 42, "x-protocol-data": 42,
@ -276,7 +278,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
}, },
) )
def test_notify_interested_services_ephemeral(self): def test_notify_interested_services_ephemeral(self) -> None:
""" """
Test sending ephemeral events to the appservice handler are scheduled Test sending ephemeral events to the appservice handler are scheduled
to be pushed out to interested appservices, and that the stream ID is to be pushed out to interested appservices, and that the stream ID is
@ -306,7 +308,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
580, 580,
) )
def test_notify_interested_services_ephemeral_out_of_order(self): def test_notify_interested_services_ephemeral_out_of_order(self) -> None:
""" """
Test sending out of order ephemeral events to the appservice handler Test sending out of order ephemeral events to the appservice handler
are ignored. are ignored.
@ -390,7 +392,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipts.register_servlets, receipts.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events # we can track any outgoing ephemeral events
@ -417,7 +419,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id "exclusive_as_user", "password", self.exclusive_as_user_device_id
) )
def _notify_interested_services(self): def _notify_interested_services(self) -> None:
# This is normally set in `notify_interested_services` but we need to call the # This is normally set in `notify_interested_services` but we need to call the
# internal async version so the reactor gets pushed to completion. # internal async version so the reactor gets pushed to completion.
self.hs.get_application_service_handler().current_max += 1 self.hs.get_application_service_handler().current_max += 1
@ -443,7 +445,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
) )
def test_match_interesting_room_members( def test_match_interesting_room_members(
self, interesting_user: str, should_notify: bool self, interesting_user: str, should_notify: bool
): ) -> None:
""" """
Test to make sure that a interesting user (local or remote) in the room is Test to make sure that a interesting user (local or remote) in the room is
notified as expected when someone else in the room sends a message. notified as expected when someone else in the room sends a message.
@ -512,7 +514,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
else: else:
self.send_mock.assert_not_called() self.send_mock.assert_not_called()
def test_application_services_receive_events_sent_by_interesting_local_user(self): def test_application_services_receive_events_sent_by_interesting_local_user(
self,
) -> None:
""" """
Test to make sure that a messages sent from a local user can be interesting and Test to make sure that a messages sent from a local user can be interesting and
picked up by the appservice. picked up by the appservice.
@ -568,7 +572,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0]["type"], "m.room.message") self.assertEqual(events[0]["type"], "m.room.message")
self.assertEqual(events[0]["sender"], alice) self.assertEqual(events[0]["sender"], alice)
def test_sending_read_receipt_batches_to_application_services(self): def test_sending_read_receipt_batches_to_application_services(self) -> None:
"""Tests that a large batch of read receipts are sent correctly to """Tests that a large batch of read receipts are sent correctly to
interested application services. interested application services.
""" """
@ -644,7 +648,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
@unittest.override_config( @unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}} {"experimental_features": {"msc2409_to_device_messages_enabled": True}}
) )
def test_application_services_receive_local_to_device(self): def test_application_services_receive_local_to_device(self) -> None:
""" """
Test that when a user sends a to-device message to another user Test that when a user sends a to-device message to another user
that is an application service's user namespace, the that is an application service's user namespace, the
@ -722,7 +726,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
@unittest.override_config( @unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}} {"experimental_features": {"msc2409_to_device_messages_enabled": True}}
) )
def test_application_services_receive_bursts_of_to_device(self): def test_application_services_receive_bursts_of_to_device(self) -> None:
""" """
Test that when a user sends >100 to-device messages at once, any Test that when a user sends >100 to-device messages at once, any
interested AS's will receive them in separate transactions. interested AS's will receive them in separate transactions.
@ -913,7 +917,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
experimental_feature_enabled: bool, experimental_feature_enabled: bool,
as_supports_txn_extensions: bool, as_supports_txn_extensions: bool,
as_should_receive_device_list_updates: bool, as_should_receive_device_list_updates: bool,
): ) -> None:
""" """
Tests that an application service receives notice of changed device Tests that an application service receives notice of changed device
lists for a user, when a user changes their device lists. lists for a user, when a user changes their device lists.
@ -1070,7 +1074,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
and a room for the users to talk in. and a room for the users to talk in.
""" """
async def preparation(): async def preparation() -> None:
await self._add_otks_for_device(self._sender_user, self._sender_device, 42) await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
await self._add_fallback_key_for_device( await self._add_fallback_key_for_device(
self._sender_user, self._sender_device, used=True self._sender_user, self._sender_device, used=True

View File

@ -199,7 +199,7 @@ class CasHandlerTestCase(HomeserverTestCase):
) )
def _mock_request(): def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest""" """Returns a mock which will stand in as a SynapseRequest"""
mock = Mock( mock = Mock(
spec=[ spec=[

View File

@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.api.errors import synapse.api.errors
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.rest.client import directory, login, room from synapse.rest.client import directory, login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, create_requester from synapse.types import JsonDict, RoomAlias, create_requester
@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass") self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
def _create_alias(self, user) -> None: def _create_alias(self, user: str) -> None:
# Create a new alias to this room. # Create a new alias to this room.
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
) )
return room_alias return room_alias
def _set_canonical_alias(self, content) -> None: def _set_canonical_alias(self, content: JsonDict) -> None:
"""Configure the canonical alias state on the room.""" """Configure the canonical alias state on the room."""
self.helper.send_state( self.helper.send_state(
self.room_id, self.room_id,
@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
tok=self.admin_user_tok, tok=self.admin_user_tok,
) )
def _get_canonical_alias(self): def _get_canonical_alias(self) -> EventBase:
"""Get the canonical alias state of the room.""" """Get the canonical alias state of the room."""
return self.get_success( result = self.get_success(
self._storage_controllers.state.get_current_state_event( self._storage_controllers.state.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, "" self.room_id, EventTypes.CanonicalAlias, ""
) )
) )
assert result is not None
return result
def test_remove_alias(self) -> None: def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too.""" """Removing an alias that is the canonical alias should remove it there too."""
@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
) )
data = self._get_canonical_alias() data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data.content["alias"], self.test_alias)
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) self.assertEqual(data.content["alt_aliases"], [self.test_alias])
# Finally, delete the alias. # Finally, delete the alias.
self.get_success( self.get_success(
@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
) )
data = self._get_canonical_alias() data = self._get_canonical_alias()
self.assertNotIn("alias", data["content"]) self.assertNotIn("alias", data.content)
self.assertNotIn("alt_aliases", data["content"]) self.assertNotIn("alt_aliases", data.content)
def test_remove_other_alias(self) -> None: def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too.""" """Removing an alias listed as in alt_aliases should remove it there too."""
@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
) )
data = self._get_canonical_alias() data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data.content["alias"], self.test_alias)
self.assertEqual( self.assertEqual(
data["content"]["alt_aliases"], [self.test_alias, other_test_alias] data.content["alt_aliases"], [self.test_alias, other_test_alias]
) )
# Delete the second alias. # Delete the second alias.
@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
) )
data = self._get_canonical_alias() data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data.content["alias"], self.test_alias)
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) self.assertEqual(data.content["alt_aliases"], [self.test_alias])
class TestCreateAliasACL(unittest.HomeserverTestCase): class TestCreateAliasACL(unittest.HomeserverTestCase):

View File

@ -17,7 +17,11 @@
import copy import copy
from unittest import mock from unittest import mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -39,14 +43,14 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(replication_layer=mock.Mock()) return self.setup_test_homeserver(replication_layer=mock.Mock())
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_room_keys_handler() self.handler = hs.get_e2e_room_keys_handler()
self.local_user = "@boris:" + hs.hostname self.local_user = "@boris:" + hs.hostname
def test_get_missing_current_version_info(self): def test_get_missing_current_version_info(self) -> None:
"""Check that we get a 404 if we ask for info about the current version """Check that we get a 404 if we ask for info about the current version
if there is no version. if there is no version.
""" """
@ -56,7 +60,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
def test_get_missing_version_info(self): def test_get_missing_version_info(self) -> None:
"""Check that we get a 404 if we ask for info about a specific version """Check that we get a 404 if we ask for info about a specific version
if it doesn't exist. if it doesn't exist.
""" """
@ -67,9 +71,9 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
def test_create_version(self): def test_create_version(self) -> None:
"""Check that we can create and then retrieve versions.""" """Check that we can create and then retrieve versions."""
res = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -78,7 +82,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
) )
self.assertEqual(res, "1") self.assertEqual(version, "1")
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = self.get_success(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
@ -110,7 +114,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
# upload a new one... # upload a new one...
res = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -119,7 +123,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
) )
self.assertEqual(res, "2") self.assertEqual(version, "2")
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = self.get_success(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
@ -134,7 +138,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_update_version(self): def test_update_version(self) -> None:
"""Check that we can update versions.""" """Check that we can update versions."""
version = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(
@ -173,7 +177,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_update_missing_version(self): def test_update_missing_version(self) -> None:
"""Check that we get a 404 on updating nonexistent versions""" """Check that we get a 404 on updating nonexistent versions"""
e = self.get_failure( e = self.get_failure(
self.handler.update_version( self.handler.update_version(
@ -190,7 +194,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
def test_update_omitted_version(self): def test_update_omitted_version(self) -> None:
"""Check that the update succeeds if the version is missing from the body""" """Check that the update succeeds if the version is missing from the body"""
version = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(
@ -227,7 +231,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_update_bad_version(self): def test_update_bad_version(self) -> None:
"""Check that we get a 400 if the version in the body doesn't match""" """Check that we get a 400 if the version in the body doesn't match"""
version = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(
@ -255,7 +259,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 400) self.assertEqual(res, 400)
def test_delete_missing_version(self): def test_delete_missing_version(self) -> None:
"""Check that we get a 404 on deleting nonexistent versions""" """Check that we get a 404 on deleting nonexistent versions"""
e = self.get_failure( e = self.get_failure(
self.handler.delete_version(self.local_user, "1"), SynapseError self.handler.delete_version(self.local_user, "1"), SynapseError
@ -263,15 +267,15 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
def test_delete_missing_current_version(self): def test_delete_missing_current_version(self) -> None:
"""Check that we get a 404 on deleting nonexistent current version""" """Check that we get a 404 on deleting nonexistent current version"""
e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError) e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
res = e.value.code res = e.value.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
def test_delete_version(self): def test_delete_version(self) -> None:
"""Check that we can create and then delete versions.""" """Check that we can create and then delete versions."""
res = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -280,7 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
) )
self.assertEqual(res, "1") self.assertEqual(version, "1")
# check we can delete it # check we can delete it
self.get_success(self.handler.delete_version(self.local_user, "1")) self.get_success(self.handler.delete_version(self.local_user, "1"))
@ -292,7 +296,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
def test_get_missing_backup(self): def test_get_missing_backup(self) -> None:
"""Check that we get a 404 on querying missing backup""" """Check that we get a 404 on querying missing backup"""
e = self.get_failure( e = self.get_failure(
self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
@ -300,7 +304,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
def test_get_missing_room_keys(self): def test_get_missing_room_keys(self) -> None:
"""Check we get an empty response from an empty backup""" """Check we get an empty response from an empty backup"""
version = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(
@ -319,7 +323,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
# TODO: test the locking semantics when uploading room_keys, # TODO: test the locking semantics when uploading room_keys,
# although this is probably best done in sytest # although this is probably best done in sytest
def test_upload_room_keys_no_versions(self): def test_upload_room_keys_no_versions(self) -> None:
"""Check that we get a 404 on uploading keys when no versions are defined""" """Check that we get a 404 on uploading keys when no versions are defined"""
e = self.get_failure( e = self.get_failure(
self.handler.upload_room_keys(self.local_user, "no_version", room_keys), self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
@ -328,7 +332,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
def test_upload_room_keys_bogus_version(self): def test_upload_room_keys_bogus_version(self) -> None:
"""Check that we get a 404 on uploading keys when an nonexistent version """Check that we get a 404 on uploading keys when an nonexistent version
is specified is specified
""" """
@ -350,7 +354,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
def test_upload_room_keys_wrong_version(self): def test_upload_room_keys_wrong_version(self) -> None:
"""Check that we get a 403 on uploading keys for an old version""" """Check that we get a 403 on uploading keys for an old version"""
version = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(
@ -380,7 +384,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 403) self.assertEqual(res, 403)
def test_upload_room_keys_insert(self): def test_upload_room_keys_insert(self) -> None:
"""Check that we can insert and retrieve keys for a session""" """Check that we can insert and retrieve keys for a session"""
version = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(
@ -416,7 +420,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
self.assertDictEqual(res, room_keys) self.assertDictEqual(res, room_keys)
def test_upload_room_keys_merge(self): def test_upload_room_keys_merge(self) -> None:
"""Check that we can upload a new room_key for an existing session and """Check that we can upload a new room_key for an existing session and
have it correctly merged""" have it correctly merged"""
version = self.get_success( version = self.get_success(
@ -449,9 +453,11 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys) self.handler.upload_room_keys(self.local_user, version, new_room_keys)
) )
res = self.get_success(self.handler.get_room_keys(self.local_user, version)) res_keys = self.get_success(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual( self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK", "SSBBTSBBIEZJU0gK",
) )
@ -465,9 +471,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys) self.handler.upload_room_keys(self.local_user, version, new_room_keys)
) )
res = self.get_success(self.handler.get_room_keys(self.local_user, version)) res_keys = self.get_success(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual( self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"new",
) )
# the etag should NOT be equal now, since the key changed # the etag should NOT be equal now, since the key changed
@ -483,9 +492,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys) self.handler.upload_room_keys(self.local_user, version, new_room_keys)
) )
res = self.get_success(self.handler.get_room_keys(self.local_user, version)) res_keys = self.get_success(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual( self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"new",
) )
# the etag should be the same since the session did not change # the etag should be the same since the session did not change
@ -494,7 +506,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
# TODO: check edge cases as well as the common variations here # TODO: check edge cases as well as the common variations here
def test_delete_room_keys(self): def test_delete_room_keys(self) -> None:
"""Check that we can insert and delete keys for a session""" """Check that we can insert and delete keys for a session"""
version = self.get_success( version = self.get_success(
self.handler.create_version( self.handler.create_version(

View File

@ -439,7 +439,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
def create_invite(): def create_invite() -> EventBase:
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
room_version = self.get_success(self.store.get_room_version(room_id)) room_version = self.get_success(self.store.get_room_version(room_id))
return event_from_pdu_json( return event_from_pdu_json(

View File

@ -14,6 +14,8 @@
from typing import Optional from typing import Optional
from unittest import mock from unittest import mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import AuthError, StoreError from synapse.api.errors import AuthError, StoreError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
from synapse.event_auth import ( from synapse.event_auth import (
@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import event_injection, make_awaitable from tests.test_utils import event_injection, make_awaitable
@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# mock out the federation transport client # mock out the federation transport client
self.mock_federation_transport_client = mock.Mock( self.mock_federation_transport_client = mock.Mock(
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) )
else: else:
async def get_event(destination: str, event_id: str, timeout=None): async def get_event(
destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, prev_event.event_id) self.assertEqual(event_id, prev_event.event_id)
return {"pdus": [prev_event.get_pdu_json()]} return {"pdus": [prev_event.get_pdu_json()]}

View File

@ -14,12 +14,16 @@
import logging import logging
from typing import Tuple from typing import Tuple
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import create_requester from synapse.types import create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
@ -35,7 +39,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler() self.handler = self.hs.get_event_creation_handler()
self._persist_event_storage_controller = ( self._persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence self.hs.get_storage_controllers().persistence
@ -94,7 +98,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
) )
) )
def test_duplicated_txn_id(self): def test_duplicated_txn_id(self) -> None:
"""Test that attempting to handle/persist an event with a transaction ID """Test that attempting to handle/persist an event with a transaction ID
that has already been persisted correctly returns the old event and does that has already been persisted correctly returns the old event and does
*not* produce duplicate messages. *not* produce duplicate messages.
@ -161,7 +165,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# rather than the new one. # rather than the new one.
self.assertEqual(ret_event1.event_id, ret_event4.event_id) self.assertEqual(ret_event1.event_id, ret_event4.event_id)
def test_duplicated_txn_id_one_call(self): def test_duplicated_txn_id_one_call(self) -> None:
"""Test that we correctly handle duplicates that we try and persist at """Test that we correctly handle duplicates that we try and persist at
the same time. the same time.
""" """
@ -185,7 +189,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(events), 2) self.assertEqual(len(events), 2)
self.assertEqual(events[0].event_id, events[1].event_id) self.assertEqual(events[0].event_id, events[1].event_id)
def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self): def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(
self,
) -> None:
"""When we set allow_no_prev_events=True, should be able to create a """When we set allow_no_prev_events=True, should be able to create a
event without any prev_events (only auth_events). event without any prev_events (only auth_events).
""" """
@ -214,7 +220,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events( def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
self, self,
): ) -> None:
"""When we set allow_no_prev_events=False, shouldn't be able to create a """When we set allow_no_prev_events=False, shouldn't be able to create a
event without any prev_events even if it has auth_events. Expect an event without any prev_events even if it has auth_events. Expect an
exception to be raised. exception to be raised.
@ -245,7 +251,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events( def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
self, self,
): ) -> None:
"""When we set allow_no_prev_events=True, should be able to create a """When we set allow_no_prev_events=True, should be able to create a
event without any prev_events or auth_events. Expect an exception to be event without any prev_events or auth_events. Expect an exception to be
raised. raised.
@ -277,12 +283,12 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("tester", "foobar") self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar") self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
def test_allow_server_acl(self): def test_allow_server_acl(self) -> None:
"""Test that sending an ACL that blocks everyone but ourselves works.""" """Test that sending an ACL that blocks everyone but ourselves works."""
self.helper.send_state( self.helper.send_state(
@ -293,7 +299,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
expect_code=200, expect_code=200,
) )
def test_deny_server_acl_block_outselves(self): def test_deny_server_acl_block_outselves(self) -> None:
"""Test that sending an ACL that blocks ourselves does not work.""" """Test that sending an ACL that blocks ourselves does not work."""
self.helper.send_state( self.helper.send_state(
self.room_id, self.room_id,
@ -303,7 +309,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
expect_code=400, expect_code=400,
) )
def test_deny_redact_server_acl(self): def test_deny_redact_server_acl(self) -> None:
"""Test that attempting to redact an ACL is blocked.""" """Test that attempting to redact an ACL is blocked."""
body = self.helper.send_state( body = self.helper.send_state(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
from typing import Any, Dict, Tuple from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
from unittest.mock import ANY, Mock, patch from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import JsonDict, UserID
from synapse.util import Clock from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config
try: try:
import authlib # noqa: F401 import authlib # noqa: F401
from authlib.oidc.core import UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata
from synapse.handlers.oidc import Token, UserAttributeDict
HAS_OIDC = True HAS_OIDC = True
except ImportError: except ImportError:
@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = {
class TestMappingProvider: class TestMappingProvider:
@staticmethod @staticmethod
def parse_config(config): def parse_config(config: JsonDict) -> None:
return return None
def __init__(self, config): def __init__(self, config: None):
pass pass
def get_remote_user_id(self, userinfo): def get_remote_user_id(self, userinfo: "UserInfo") -> str:
return userinfo["sub"] return userinfo["sub"]
async def map_user_attributes(self, userinfo, token): async def map_user_attributes(
return {"localpart": userinfo["username"], "display_name": None} self, userinfo: "UserInfo", token: "Token"
) -> "UserAttributeDict":
# This is testing not providing the full map.
return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item]
# Do not include get_extra_attributes to test backwards compatibility paths. # Do not include get_extra_attributes to test backwards compatibility paths.
class TestMappingProviderExtra(TestMappingProvider): class TestMappingProviderExtra(TestMappingProvider):
async def get_extra_attributes(self, userinfo, token): async def get_extra_attributes(
self, userinfo: "UserInfo", token: "Token"
) -> JsonDict:
return {"phone": userinfo["phone"]} return {"phone": userinfo["phone"]}
class TestMappingProviderFailures(TestMappingProvider): class TestMappingProviderFailures(TestMappingProvider):
async def map_user_attributes(self, userinfo, token, failures): # Superclass is testing the legacy interface for map_user_attributes.
return { async def map_user_attributes( # type: ignore[override]
self, userinfo: "UserInfo", token: "Token", failures: int
) -> "UserAttributeDict":
return { # type: ignore[typeddict-item]
"localpart": userinfo["username"] + (str(failures) if failures else ""), "localpart": userinfo["username"] + (str(failures) if failures else ""),
"display_name": None, "display_name": None,
} }
@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.hs_patcher.stop() self.hs_patcher.stop()
return super().tearDown() return super().tearDown()
def reset_mocks(self): def reset_mocks(self) -> None:
"""Reset all the Mocks.""" """Reset all the Mocks."""
self.fake_server.reset_mocks() self.fake_server.reset_mocks()
self.render_error.reset_mock() self.render_error.reset_mock()
self.complete_sso_login.reset_mock() self.complete_sso_login.reset_mock()
def metadata_edit(self, values): def metadata_edit(self, values: dict) -> ContextManager[Mock]:
"""Modify the result that will be returned by the well-known query""" """Modify the result that will be returned by the well-known query"""
metadata = self.fake_server.get_metadata() metadata = self.fake_server.get_metadata()
@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
session = self._generate_oidc_session_token(state, nonce, client_redirect_url) session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
return _build_callback_request(code, state, session), grant return _build_callback_request(code, state, session), grant
def assertRenderedError(self, error, error_description=None): def assertRenderedError(
self, error: str, error_description: Optional[str] = None
) -> Tuple[Any, ...]:
self.render_error.assert_called_once() self.render_error.assert_called_once()
args = self.render_error.call_args[0] args = self.render_error.call_args[0]
self.assertEqual(args[1], error) self.assertEqual(args[1], error)
@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadatas are extensively validated.""" """Provider metadatas are extensively validated."""
h = self.provider h = self.provider
def force_load_metadata(): def force_load_metadata() -> Awaitable[None]:
async def force_load(): async def force_load() -> "OpenIDProviderMetadata":
return await h.load_metadata(force=True) return await h.load_metadata(force=True)
return get_awaitable_result(force_load()) return get_awaitable_result(force_load())
@ -1198,7 +1212,7 @@ def _build_callback_request(
state: str, state: str,
session: str, session: str,
ip_address: str = "10.0.0.1", ip_address: str = "10.0.0.1",
): ) -> Mock:
"""Builds a fake SynapseRequest to mock the browser callback """Builds a fake SynapseRequest to mock the browser callback
Returns a Mock object which looks like the SynapseRequest we get from a browser Returns a Mock object which looks like the SynapseRequest we get from a browser

View File

@ -15,12 +15,13 @@
"""Tests for the password_auth_provider interface""" """Tests for the password_auth_provider interface"""
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Type, Union from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock from unittest.mock import Mock
import synapse import synapse
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.handlers.account import AccountHandler
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider:
"""A legacy password_provider which only implements `check_password`.""" """A legacy password_provider which only implements `check_password`."""
@staticmethod @staticmethod
def parse_config(self): def parse_config(config: JsonDict) -> None:
pass pass
def __init__(self, config, account_handler): def __init__(self, config: None, account_handler: AccountHandler):
pass pass
def check_password(self, *args): def check_password(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args) return mock_password_provider.check_password(*args)
@ -58,16 +59,16 @@ class LegacyCustomAuthProvider:
"""A legacy password_provider which implements a custom login type.""" """A legacy password_provider which implements a custom login type."""
@staticmethod @staticmethod
def parse_config(self): def parse_config(config: JsonDict) -> None:
pass pass
def __init__(self, config, account_handler): def __init__(self, config: None, account_handler: AccountHandler):
pass pass
def get_supported_login_types(self): def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"test.login_type": ["test_field"]} return {"test.login_type": ["test_field"]}
def check_auth(self, *args): def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args) return mock_password_provider.check_auth(*args)
@ -75,15 +76,15 @@ class CustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type.""" """A module which registers password_auth_provider callbacks for a custom login type."""
@staticmethod @staticmethod
def parse_config(self): def parse_config(config: JsonDict) -> None:
pass pass
def __init__(self, config, api: ModuleApi): def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks( api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth} auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
) )
def check_auth(self, *args): def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args) return mock_password_provider.check_auth(*args)
@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider:
as a custom type.""" as a custom type."""
@staticmethod @staticmethod
def parse_config(self): def parse_config(config: JsonDict) -> None:
pass pass
def __init__(self, config, account_handler): def __init__(self, config: None, account_handler: AccountHandler):
pass pass
def get_supported_login_types(self): def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"m.login.password": ["password"], "test.login_type": ["test_field"]} return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
def check_auth(self, *args): def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args) return mock_password_provider.check_auth(*args)
@ -110,10 +111,10 @@ class PasswordCustomAuthProvider:
as well as a password login""" as well as a password login"""
@staticmethod @staticmethod
def parse_config(self): def parse_config(config: JsonDict) -> None:
pass pass
def __init__(self, config, api: ModuleApi): def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks( api.register_password_auth_provider_callbacks(
auth_checkers={ auth_checkers={
("test.login_type", ("test_field",)): self.check_auth, ("test.login_type", ("test_field",)): self.check_auth,
@ -121,10 +122,10 @@ class PasswordCustomAuthProvider:
} }
) )
def check_auth(self, *args): def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args) return mock_password_provider.check_auth(*args)
def check_pass(self, *args): def check_pass(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args) return mock_password_provider.check_password(*args)
@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
CALLBACK_USERNAME = "get_username_for_registration" CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration" CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
def setUp(self): def setUp(self) -> None:
# we use a global mock device, so make sure we are starting with a clean slate # we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
super().setUp() super().setUp()
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self): def test_password_only_auth_progiver_login_legacy(self) -> None:
self.password_only_auth_provider_login_test_body() self.password_only_auth_provider_login_test_body()
def password_only_auth_provider_login_test_body(self): def password_only_auth_provider_login_test_body(self) -> None:
# login flows should only have m.login.password # login flows should only have m.login.password
flows = self._get_login_flows() flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
) )
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth_legacy(self): def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
self.password_only_auth_provider_ui_auth_test_body() self.password_only_auth_provider_ui_auth_test_body()
def password_only_auth_provider_ui_auth_test_body(self): def password_only_auth_provider_ui_auth_test_body(self) -> None:
"""UI Auth should delegate correctly to the password provider""" """UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work # create the user, otherwise access doesn't work
@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_called_once_with("@u:test", "p") mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_login_legacy(self): def test_local_user_fallback_login_legacy(self) -> None:
self.local_user_fallback_login_test_body() self.local_user_fallback_login_test_body()
def local_user_fallback_login_test_body(self): def local_user_fallback_login_test_body(self) -> None:
"""rejected login should fall back to local db""" """rejected login should fall back to local db"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual("@localuser:test", channel.json_body["user_id"]) self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_ui_auth_legacy(self): def test_local_user_fallback_ui_auth_legacy(self) -> None:
self.local_user_fallback_ui_auth_test_body() self.local_user_fallback_ui_auth_test_body()
def local_user_fallback_ui_auth_test_body(self): def local_user_fallback_ui_auth_test_body(self) -> None:
"""rejected login should fall back to local db""" """rejected login should fall back to local db"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False}, "password_config": {"localdb_enabled": False},
} }
) )
def test_no_local_user_fallback_login_legacy(self): def test_no_local_user_fallback_login_legacy(self) -> None:
self.no_local_user_fallback_login_test_body() self.no_local_user_fallback_login_test_body()
def no_local_user_fallback_login_test_body(self): def no_local_user_fallback_login_test_body(self) -> None:
"""localdb_enabled can block login with the local password""" """localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False}, "password_config": {"localdb_enabled": False},
} }
) )
def test_no_local_user_fallback_ui_auth_legacy(self): def test_no_local_user_fallback_ui_auth_legacy(self) -> None:
self.no_local_user_fallback_ui_auth_test_body() self.no_local_user_fallback_ui_auth_test_body()
def no_local_user_fallback_ui_auth_test_body(self): def no_local_user_fallback_ui_auth_test_body(self) -> None:
"""localdb_enabled can block ui auth with the local password""" """localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False}, "password_config": {"enabled": False},
} }
) )
def test_password_auth_disabled_legacy(self): def test_password_auth_disabled_legacy(self) -> None:
self.password_auth_disabled_test_body() self.password_auth_disabled_test_body()
def password_auth_disabled_test_body(self): def password_auth_disabled_test_body(self) -> None:
"""password auth doesn't work if it's disabled across the board""" """password auth doesn't work if it's disabled across the board"""
# login flows should be empty # login flows should be empty
flows = self._get_login_flows() flows = self._get_login_flows()
@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called() mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) @override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_login_legacy(self): def test_custom_auth_provider_login_legacy(self) -> None:
self.custom_auth_provider_login_test_body() self.custom_auth_provider_login_test_body()
@override_config(providers_config(CustomAuthProvider)) @override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self): def test_custom_auth_provider_login(self) -> None:
self.custom_auth_provider_login_test_body() self.custom_auth_provider_login_test_body()
def custom_auth_provider_login_test_body(self): def custom_auth_provider_login_test_body(self) -> None:
# login flows should have the custom flow and m.login.password, since we # login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup. # haven't disabled local password lookup.
# (password must come first, because reasons) # (password must come first, because reasons)
@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
) )
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) @override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_ui_auth_legacy(self): def test_custom_auth_provider_ui_auth_legacy(self) -> None:
self.custom_auth_provider_ui_auth_test_body() self.custom_auth_provider_ui_auth_test_body()
@override_config(providers_config(CustomAuthProvider)) @override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self): def test_custom_auth_provider_ui_auth(self) -> None:
self.custom_auth_provider_ui_auth_test_body() self.custom_auth_provider_ui_auth_test_body()
def custom_auth_provider_ui_auth_test_body(self): def custom_auth_provider_ui_auth_test_body(self) -> None:
# register the user and log in twice, to get two devices # register the user and log in twice, to get two devices
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass") tok1 = self.login("localuser", "localpass")
@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
) )
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) @override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_callback_legacy(self): def test_custom_auth_provider_callback_legacy(self) -> None:
self.custom_auth_provider_callback_test_body() self.custom_auth_provider_callback_test_body()
@override_config(providers_config(CustomAuthProvider)) @override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self): def test_custom_auth_provider_callback(self) -> None:
self.custom_auth_provider_callback_test_body() self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self): def custom_auth_provider_callback_test_body(self) -> None:
callback = Mock(return_value=make_awaitable(None)) callback = Mock(return_value=make_awaitable(None))
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth.return_value = make_awaitable(
@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False}, "password_config": {"enabled": False},
} }
) )
def test_custom_auth_password_disabled_legacy(self): def test_custom_auth_password_disabled_legacy(self) -> None:
self.custom_auth_password_disabled_test_body() self.custom_auth_password_disabled_test_body()
@override_config( @override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
) )
def test_custom_auth_password_disabled(self): def test_custom_auth_password_disabled(self) -> None:
self.custom_auth_password_disabled_test_body() self.custom_auth_password_disabled_test_body()
def custom_auth_password_disabled_test_body(self): def custom_auth_password_disabled_test_body(self) -> None:
"""Test login with a custom auth provider where password login is disabled""" """Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False}, "password_config": {"enabled": False, "localdb_enabled": False},
} }
) )
def test_custom_auth_password_disabled_localdb_enabled_legacy(self): def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body() self.custom_auth_password_disabled_localdb_enabled_test_body()
@override_config( @override_config(
@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False}, "password_config": {"enabled": False, "localdb_enabled": False},
} }
) )
def test_custom_auth_password_disabled_localdb_enabled(self): def test_custom_auth_password_disabled_localdb_enabled(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body() self.custom_auth_password_disabled_localdb_enabled_test_body()
def custom_auth_password_disabled_localdb_enabled_test_body(self): def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
"""Check the localdb_enabled == enabled == False """Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False}, "password_config": {"enabled": False},
} }
) )
def test_password_custom_auth_password_disabled_login_legacy(self): def test_password_custom_auth_password_disabled_login_legacy(self) -> None:
self.password_custom_auth_password_disabled_login_test_body() self.password_custom_auth_password_disabled_login_test_body()
@override_config( @override_config(
@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False}, "password_config": {"enabled": False},
} }
) )
def test_password_custom_auth_password_disabled_login(self): def test_password_custom_auth_password_disabled_login(self) -> None:
self.password_custom_auth_password_disabled_login_test_body() self.password_custom_auth_password_disabled_login_test_body()
def password_custom_auth_password_disabled_login_test_body(self): def password_custom_auth_password_disabled_login_test_body(self) -> None:
"""log in with a custom auth provider which implements password, but password """log in with a custom auth provider which implements password, but password
login is disabled""" login is disabled"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False}, "password_config": {"enabled": False},
} }
) )
def test_password_custom_auth_password_disabled_ui_auth_legacy(self): def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body() self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config( @override_config(
@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False}, "password_config": {"enabled": False},
} }
) )
def test_password_custom_auth_password_disabled_ui_auth(self): def test_password_custom_auth_password_disabled_ui_auth(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body() self.password_custom_auth_password_disabled_ui_auth_test_body()
def password_custom_auth_password_disabled_ui_auth_test_body(self): def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None:
"""UI Auth with a custom auth provider which implements password, but password """UI Auth with a custom auth provider which implements password, but password
login is disabled""" login is disabled"""
# register the user and log in twice via the test login type to get two devices, # register the user and log in twice via the test login type to get two devices,
@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False}, "password_config": {"localdb_enabled": False},
} }
) )
def test_custom_auth_no_local_user_fallback_legacy(self): def test_custom_auth_no_local_user_fallback_legacy(self) -> None:
self.custom_auth_no_local_user_fallback_test_body() self.custom_auth_no_local_user_fallback_test_body()
@override_config( @override_config(
@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False}, "password_config": {"localdb_enabled": False},
} }
) )
def test_custom_auth_no_local_user_fallback(self): def test_custom_auth_no_local_user_fallback(self) -> None:
self.custom_auth_no_local_user_fallback_test_body() self.custom_auth_no_local_user_fallback_test_body()
def custom_auth_no_local_user_fallback_test_body(self): def custom_auth_no_local_user_fallback_test_body(self) -> None:
"""Test login with a custom auth provider where the local db is disabled""" """Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass") channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_on_logged_out(self): def test_on_logged_out(self) -> None:
"""Tests that the on_logged_out callback is called when the user logs out.""" """Tests that the on_logged_out callback is called when the user logs out."""
self.register_user("rin", "password") self.register_user("rin", "password")
tok = self.login("rin", "password") tok = self.login("rin", "password")
self.called = False self.called = False
async def on_logged_out(user_id, device_id, access_token): async def on_logged_out(
user_id: str, device_id: Optional[str], access_token: str
) -> None:
self.called = True self.called = True
on_logged_out = Mock(side_effect=on_logged_out) on_logged_out = Mock(side_effect=on_logged_out)
@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
on_logged_out.assert_called_once() on_logged_out.assert_called_once()
self.assertTrue(self.called) self.assertTrue(self.called)
def test_username(self): def test_username(self) -> None:
"""Tests that the get_username_for_registration callback can define the username """Tests that the get_username_for_registration callback can define the username
of a user when registering. of a user when registering.
""" """
@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mxid = channel.json_body["user_id"] mxid = channel.json_body["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
def test_username_uia(self): def test_username_uia(self) -> None:
"""Tests that the get_username_for_registration callback is only called at the """Tests that the get_username_for_registration callback is only called at the
end of the UIA flow. end of the UIA flow.
""" """
@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Set some email configuration so the test doesn't fail because of its absence. # Set some email configuration so the test doesn't fail because of its absence.
@override_config({"email": {"notif_from": "noreply@test"}}) @override_config({"email": {"notif_from": "noreply@test"}})
def test_3pid_allowed(self): def test_3pid_allowed(self) -> None:
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
the 3PID. Also checks that the module is passed a boolean indicating whether the the 3PID. Also checks that the module is passed a boolean indicating whether the
@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self._test_3pid_allowed("rin", False) self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True) self._test_3pid_allowed("kitay", True)
def test_displayname(self): def test_displayname(self) -> None:
"""Tests that the get_displayname_for_registration callback can define the """Tests that the get_displayname_for_registration callback can define the
display name of a user when registering. display name of a user when registering.
""" """
@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(display_name, username + "-foo") self.assertEqual(display_name, username + "-foo")
def test_displayname_uia(self): def test_displayname_uia(self) -> None:
"""Tests that the get_displayname_for_registration callback is only called at the """Tests that the get_displayname_for_registration callback is only called at the
end of the UIA flow. end of the UIA flow.
""" """
@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called. # Check that the callback has been called.
m.assert_called_once() m.assert_called_once()
def _test_3pid_allowed(self, username: str, registration: bool): def _test_3pid_allowed(self, username: str, registration: bool) -> None:
"""Tests that the "is_3pid_allowed" module callback is called correctly, using """Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments. either /register or /account URLs depending on the arguments.
@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
client is trying to register. client is trying to register.
""" """
async def callback(uia_results, params): async def callback(uia_results: JsonDict, params: JsonDict) -> str:
self.assertIn(LoginType.DUMMY, uia_results) self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"] username = params["username"]
return username + "-foo" return username + "-foo"
@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def _send_password_login(self, user: str, password: str) -> FakeChannel: def _send_password_login(self, user: str, password: str) -> FakeChannel:
return self._send_login(type="m.login.password", user=user, password=password) return self._send_login(type="m.login.password", user=user, password=password)
def _send_login(self, type, user, **params) -> FakeChannel: def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
params.update(extra_params)
channel = self.make_request("POST", "/_matrix/client/r0/login", params) channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel return channel
def _start_delete_device_session(self, access_token, device_id) -> str: def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
"""Make an initial delete device request, and return the UI Auth session ID""" """Make an initial delete device request, and return the UI Auth session ID"""
channel = self._delete_device(access_token, device_id) channel = self._delete_device(access_token, device_id)
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)

View File

@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional, cast
from unittest.mock import Mock, call from unittest.mock import Mock, call
from parameterized import parameterized from parameterized import parameterized
from signedjson.key import generate_signing_key from signedjson.key import generate_signing_key
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@ -35,7 +37,9 @@ from synapse.handlers.presence import (
) )
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import room from synapse.rest.client import room
from synapse.types import UserID, get_domain_from_id from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
class PresenceUpdateTestCase(unittest.HomeserverTestCase): class PresenceUpdateTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets] servlets = [admin.register_servlets]
def prepare(self, reactor, clock, homeserver): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
def test_offline_to_online(self): def test_offline_to_online(self) -> None:
wheel_timer = Mock() wheel_timer = Mock()
user_id = "@foo:bar" user_id = "@foo:bar"
now = 5000000 now = 5000000
@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True, any_order=True,
) )
def test_online_to_online(self): def test_online_to_online(self) -> None:
wheel_timer = Mock() wheel_timer = Mock()
user_id = "@foo:bar" user_id = "@foo:bar"
now = 5000000 now = 5000000
@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True, any_order=True,
) )
def test_online_to_online_last_active_noop(self): def test_online_to_online_last_active_noop(self) -> None:
wheel_timer = Mock() wheel_timer = Mock()
user_id = "@foo:bar" user_id = "@foo:bar"
now = 5000000 now = 5000000
@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True, any_order=True,
) )
def test_online_to_online_last_active(self): def test_online_to_online_last_active(self) -> None:
wheel_timer = Mock() wheel_timer = Mock()
user_id = "@foo:bar" user_id = "@foo:bar"
now = 5000000 now = 5000000
@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True, any_order=True,
) )
def test_remote_ping_timer(self): def test_remote_ping_timer(self) -> None:
wheel_timer = Mock() wheel_timer = Mock()
user_id = "@foo:bar" user_id = "@foo:bar"
now = 5000000 now = 5000000
@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True, any_order=True,
) )
def test_online_to_offline(self): def test_online_to_offline(self) -> None:
wheel_timer = Mock() wheel_timer = Mock()
user_id = "@foo:bar" user_id = "@foo:bar"
now = 5000000 now = 5000000
@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
self.assertEqual(wheel_timer.insert.call_count, 0) self.assertEqual(wheel_timer.insert.call_count, 0)
def test_online_to_idle(self): def test_online_to_idle(self) -> None:
wheel_timer = Mock() wheel_timer = Mock()
user_id = "@foo:bar" user_id = "@foo:bar"
now = 5000000 now = 5000000
@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True, any_order=True,
) )
def test_persisting_presence_updates(self): def test_persisting_presence_updates(self) -> None:
"""Tests that the latest presence state for each user is persisted correctly""" """Tests that the latest presence state for each user is persisted correctly"""
# Create some test users and presence states for them # Create some test users and presence states for them
presence_states = [] presence_states = []
@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.update_presence(presence_states)) self.get_success(self.store.update_presence(presence_states))
# Check that each update is present in the database # Check that each update is present in the database
db_presence_states = self.get_success( db_presence_states_raw = self.get_success(
self.store.get_all_presence_updates( self.store.get_all_presence_updates(
instance_name="master", instance_name="master",
last_id=0, last_id=0,
@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
) )
# Extract presence update user ID and state information into lists of tuples # Extract presence update user ID and state information into lists of tuples
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]] db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]]
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states] presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
# Compare what we put into the storage with what we got out. # Compare what we put into the storage with what we got out.
@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
class PresenceTimeoutTestCase(unittest.TestCase): class PresenceTimeoutTestCase(unittest.TestCase):
"""Tests different timers and that the timer does not change `status_msg` of user.""" """Tests different timers and that the timer does not change `status_msg` of user."""
def test_idle_timer(self): def test_idle_timer(self) -> None:
user_id = "@foo:bar" user_id = "@foo:bar"
status_msg = "I'm here!" status_msg = "I'm here!"
now = 5000000 now = 5000000
@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
def test_busy_no_idle(self): def test_busy_no_idle(self) -> None:
""" """
Tests that a user setting their presence to busy but idling doesn't turn their Tests that a user setting their presence to busy but idling doesn't turn their
presence state into unavailable. presence state into unavailable.
@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.BUSY) self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
def test_sync_timeout(self): def test_sync_timeout(self) -> None:
user_id = "@foo:bar" user_id = "@foo:bar"
status_msg = "I'm here!" status_msg = "I'm here!"
now = 5000000 now = 5000000
@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
def test_sync_online(self): def test_sync_online(self) -> None:
user_id = "@foo:bar" user_id = "@foo:bar"
status_msg = "I'm here!" status_msg = "I'm here!"
now = 5000000 now = 5000000
@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.ONLINE) self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
def test_federation_ping(self): def test_federation_ping(self) -> None:
user_id = "@foo:bar" user_id = "@foo:bar"
status_msg = "I'm here!" status_msg = "I'm here!"
now = 5000000 now = 5000000
@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
self.assertEqual(state, new_state) self.assertEqual(state, new_state)
def test_no_timeout(self): def test_no_timeout(self) -> None:
user_id = "@foo:bar" user_id = "@foo:bar"
now = 5000000 now = 5000000
@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNone(new_state) self.assertIsNone(new_state)
def test_federation_timeout(self): def test_federation_timeout(self) -> None:
user_id = "@foo:bar" user_id = "@foo:bar"
status_msg = "I'm here!" status_msg = "I'm here!"
now = 5000000 now = 5000000
@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
def test_last_active(self): def test_last_active(self) -> None:
user_id = "@foo:bar" user_id = "@foo:bar"
status_msg = "I'm here!" status_msg = "I'm here!"
now = 5000000 now = 5000000
@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase):
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
def test_external_process_timeout(self): def test_external_process_timeout(self) -> None:
"""Test that if an external process doesn't update the records for a while """Test that if an external process doesn't update the records for a while
we time out their syncing users presence. we time out their syncing users presence.
""" """
process_id = 1 process_id = "1"
user_id = "@test:server" user_id = "@test:server"
# Notify handler that a user is now syncing. # Notify handler that a user is now syncing.
@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
) )
self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.state, PresenceState.OFFLINE)
def test_user_goes_offline_by_timeout_status_msg_remain(self): def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
"""Test that if a user doesn't update the records for a while """Test that if a user doesn't update the records for a while
users presence goes `OFFLINE` because of timeout and `status_msg` remains. users presence goes `OFFLINE` because of timeout and `status_msg` remains.
""" """
@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, status_msg) self.assertEqual(state.status_msg, status_msg)
def test_user_goes_offline_manually_with_no_status_msg(self): def test_user_goes_offline_manually_with_no_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE` """Test that if a user change presence manually to `OFFLINE`
and no status is set, that `status_msg` is `None`. and no status is set, that `status_msg` is `None`.
""" """
@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, None) self.assertEqual(state.status_msg, None)
def test_user_goes_offline_manually_with_status_msg(self): def test_user_goes_offline_manually_with_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE` """Test that if a user change presence manually to `OFFLINE`
and a status is set, that `status_msg` appears. and a status is set, that `status_msg` appears.
""" """
@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
user_id, PresenceState.OFFLINE, "And now here." user_id, PresenceState.OFFLINE, "And now here."
) )
def test_user_reset_online_with_no_status(self): def test_user_reset_online_with_no_status(self) -> None:
"""Test that if a user set again the presence manually """Test that if a user set again the presence manually
and no status is set, that `status_msg` is `None`. and no status is set, that `status_msg` is `None`.
""" """
@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.ONLINE) self.assertEqual(state.state, PresenceState.ONLINE)
self.assertEqual(state.status_msg, None) self.assertEqual(state.status_msg, None)
def test_set_presence_with_status_msg_none(self): def test_set_presence_with_status_msg_none(self) -> None:
"""Test that if a user set again the presence manually """Test that if a user set again the presence manually
and status is `None`, that `status_msg` is `None`. and status is `None`, that `status_msg` is `None`.
""" """
@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# Mark user as online and `status_msg = None` # Mark user as online and `status_msg = None`
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
def test_set_presence_from_syncing_not_set(self): def test_set_presence_from_syncing_not_set(self) -> None:
"""Test that presence is not set by syncing if affect_presence is false""" """Test that presence is not set by syncing if affect_presence is false"""
user_id = "@test:server" user_id = "@test:server"
status_msg = "I'm here!" status_msg = "I'm here!"
@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# and status message should still be the same # and status message should still be the same
self.assertEqual(state.status_msg, status_msg) self.assertEqual(state.status_msg, status_msg)
def test_set_presence_from_syncing_is_set(self): def test_set_presence_from_syncing_is_set(self) -> None:
"""Test that presence is set by syncing if affect_presence is true""" """Test that presence is set by syncing if affect_presence is true"""
user_id = "@test:server" user_id = "@test:server"
status_msg = "I'm here!" status_msg = "I'm here!"
@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# we should now be online # we should now be online
self.assertEqual(state.state, PresenceState.ONLINE) self.assertEqual(state.state, PresenceState.ONLINE)
def test_set_presence_from_syncing_keeps_status(self): def test_set_presence_from_syncing_keeps_status(self) -> None:
"""Test that presence set by syncing retains status message""" """Test that presence set by syncing retains status message"""
user_id = "@test:server" user_id = "@test:server"
status_msg = "I'm here!" status_msg = "I'm here!"
@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
}, },
} }
) )
def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool): def test_set_presence_from_syncing_keeps_busy(
self, test_with_workers: bool
) -> None:
"""Test that presence set by syncing doesn't affect busy status """Test that presence set by syncing doesn't affect busy status
Args: Args:
@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def _set_presencestate_with_status_msg( def _set_presencestate_with_status_msg(
self, user_id: str, state: str, status_msg: Optional[str] self, user_id: str, state: str, status_msg: Optional[str]
): ) -> None:
"""Set a PresenceState and status_msg and check the result. """Set a PresenceState and status_msg and check the result.
Args: Args:
@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.instance_name = hs.get_instance_name() self.instance_name = hs.get_instance_name()
self.queue = self.presence_handler.get_federation_queue() self.queue = self.presence_handler.get_federation_queue()
def test_send_and_get(self): def test_send_and_get(self) -> None:
state1 = UserPresenceState.default("@user1:test") state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test") state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test") state3 = UserPresenceState.default("@user3:test")
@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertFalse(limited) self.assertFalse(limited)
self.assertCountEqual(rows, []) self.assertCountEqual(rows, [])
def test_send_and_get_split(self): def test_send_and_get_split(self) -> None:
state1 = UserPresenceState.default("@user1:test") state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test") state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test") state3 = UserPresenceState.default("@user3:test")
@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(rows, expected_rows) self.assertCountEqual(rows, expected_rows)
def test_clear_queue_all(self): def test_clear_queue_all(self) -> None:
state1 = UserPresenceState.default("@user1:test") state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test") state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test") state3 = UserPresenceState.default("@user3:test")
@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(rows, expected_rows) self.assertCountEqual(rows, expected_rows)
def test_partially_clear_queue(self): def test_partially_clear_queue(self) -> None:
state1 = UserPresenceState.default("@user1:test") state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test") state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test") state3 = UserPresenceState.default("@user3:test")
@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
servlets = [room.register_servlets] servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"server", "server",
federation_http_client=None, federation_http_client=None,
@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
) )
return hs return hs
def default_config(self): def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
# Enable federation sending on the main process. # Enable federation sending on the main process.
config["federation_sender_instances"] = None config["federation_sender_instances"] = None
return config return config
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.federation_sender = hs.get_federation_sender() self.federation_sender = cast(Mock, hs.get_federation_sender())
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.federation_event_handler = hs.get_federation_event_handler() self.federation_event_handler = hs.get_federation_event_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# random key to use. # random key to use.
self.random_signing_key = generate_signing_key("ver") self.random_signing_key = generate_signing_key("ver")
def test_remote_joins(self): def test_remote_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special # We advance time to something that isn't 0, as we use 0 as a special
# value. # value.
self.reactor.advance(1000000000000) self.reactor.advance(1000000000000)
@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server3"}, states=[expected_state] destinations={"server3"}, states=[expected_state]
) )
def test_remote_gets_presence_when_local_user_joins(self): def test_remote_gets_presence_when_local_user_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special # We advance time to something that isn't 0, as we use 0 as a special
# value. # value.
self.reactor.advance(1000000000000) self.reactor.advance(1000000000000)
@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server2", "server3"}, states=[expected_state] destinations={"server2", "server3"}, states=[expected_state]
) )
def _add_new_user(self, room_id, user_id): def _add_new_user(self, room_id: str, user_id: str) -> None:
"""Add new user to the room by creating an event and poking the federation API.""" """Add new user to the room by creating an event and poking the federation API."""
hostname = get_domain_from_id(user_id) hostname = get_domain_from_id(user_id)

View File

@ -332,7 +332,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
@unittest.override_config( @unittest.override_config(
{"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]} {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
) )
def test_avatar_constraint_on_local_server_with_port(self): def test_avatar_constraint_on_local_server_with_port(self) -> None:
"""Test that avatar metadata is correctly fetched when the media is on a local """Test that avatar metadata is correctly fetched when the media is on a local
server and the server has an explicit port. server and the server has an explicit port.
@ -376,7 +376,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc)) self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
) )
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database. """Stores metadata about files in the database.
Args: Args:

View File

@ -15,14 +15,18 @@
from copy import deepcopy from copy import deepcopy
from typing import List from typing import List
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, ReceiptTypes from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
class ReceiptsTestCase(unittest.HomeserverTestCase): class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.event_source = hs.get_event_sources().sources.receipt self.event_source = hs.get_event_sources().sources.receipt
def test_filters_out_private_receipt(self) -> None: def test_filters_out_private_receipt(self) -> None:

View File

@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Collection, List, Optional, Tuple
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import ( from synapse.api.errors import (
@ -22,8 +25,18 @@ from synapse.api.errors import (
ResourceLimitError, ResourceLimitError,
SynapseError, SynapseError,
) )
from synapse.module_api import ModuleApi
from synapse.server import HomeServer
from synapse.spam_checker_api import RegistrationBehaviour from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.types import (
JsonDict,
Requester,
RoomAlias,
RoomID,
UserID,
create_requester,
)
from synapse.util import Clock
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
@ -33,94 +46,98 @@ from .. import unittest
class TestSpamChecker: class TestSpamChecker:
def __init__(self, config, api): def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks( api.register_spam_checker_callbacks(
check_registration_for_spam=self.check_registration_for_spam, check_registration_for_spam=self.check_registration_for_spam,
) )
@staticmethod @staticmethod
def parse_config(config): def parse_config(config: JsonDict) -> None:
return config return None
async def check_registration_for_spam( async def check_registration_for_spam(
self, self,
email_threepid, email_threepid: Optional[dict],
username, username: Optional[str],
request_info, request_info: Collection[Tuple[str, str]],
auth_provider_id, auth_provider_id: Optional[str],
): ) -> RegistrationBehaviour:
pass pass
class DenyAll(TestSpamChecker): class DenyAll(TestSpamChecker):
async def check_registration_for_spam( async def check_registration_for_spam(
self, self,
email_threepid, email_threepid: Optional[dict],
username, username: Optional[str],
request_info, request_info: Collection[Tuple[str, str]],
auth_provider_id, auth_provider_id: Optional[str],
): ) -> RegistrationBehaviour:
return RegistrationBehaviour.DENY return RegistrationBehaviour.DENY
class BanAll(TestSpamChecker): class BanAll(TestSpamChecker):
async def check_registration_for_spam( async def check_registration_for_spam(
self, self,
email_threepid, email_threepid: Optional[dict],
username, username: Optional[str],
request_info, request_info: Collection[Tuple[str, str]],
auth_provider_id, auth_provider_id: Optional[str],
): ) -> RegistrationBehaviour:
return RegistrationBehaviour.SHADOW_BAN return RegistrationBehaviour.SHADOW_BAN
class BanBadIdPUser(TestSpamChecker): class BanBadIdPUser(TestSpamChecker):
async def check_registration_for_spam( async def check_registration_for_spam(
self, email_threepid, username, request_info, auth_provider_id=None self,
): email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str] = None,
) -> RegistrationBehaviour:
# Reject any user coming from CAS and whose username contains profanity # Reject any user coming from CAS and whose username contains profanity
if auth_provider_id == "cas" and "flimflob" in username: if auth_provider_id == "cas" and username and "flimflob" in username:
return RegistrationBehaviour.DENY return RegistrationBehaviour.DENY
return RegistrationBehaviour.ALLOW return RegistrationBehaviour.ALLOW
class TestLegacyRegistrationSpamChecker: class TestLegacyRegistrationSpamChecker:
def __init__(self, config, api): def __init__(self, config: None, api: ModuleApi):
pass pass
async def check_registration_for_spam( async def check_registration_for_spam(
self, self,
email_threepid, email_threepid: Optional[dict],
username, username: Optional[str],
request_info, request_info: Collection[Tuple[str, str]],
): ) -> RegistrationBehaviour:
pass pass
class LegacyAllowAll(TestLegacyRegistrationSpamChecker): class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
async def check_registration_for_spam( async def check_registration_for_spam(
self, self,
email_threepid, email_threepid: Optional[dict],
username, username: Optional[str],
request_info, request_info: Collection[Tuple[str, str]],
): ) -> RegistrationBehaviour:
return RegistrationBehaviour.ALLOW return RegistrationBehaviour.ALLOW
class LegacyDenyAll(TestLegacyRegistrationSpamChecker): class LegacyDenyAll(TestLegacyRegistrationSpamChecker):
async def check_registration_for_spam( async def check_registration_for_spam(
self, self,
email_threepid, email_threepid: Optional[dict],
username, username: Optional[str],
request_info, request_info: Collection[Tuple[str, str]],
): ) -> RegistrationBehaviour:
return RegistrationBehaviour.DENY return RegistrationBehaviour.DENY
class RegistrationTestCase(unittest.HomeserverTestCase): class RegistrationTestCase(unittest.HomeserverTestCase):
"""Tests the RegistrationHandler.""" """Tests the RegistrationHandler."""
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs_config = self.default_config() hs_config = self.default_config()
# some of the tests rely on us having a user consent version # some of the tests rely on us having a user consent version
@ -145,7 +162,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_registration_handler() self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
self.lots_of_users = 100 self.lots_of_users = 100
@ -153,7 +170,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.requester = create_requester("@requester:test") self.requester = create_requester("@requester:test")
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self) -> None:
frank = UserID.from_string("@frank:test") frank = UserID.from_string("@frank:test")
user_id = frank.to_string() user_id = frank.to_string()
requester = create_requester(user_id) requester = create_requester(user_id)
@ -164,7 +181,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(result_token, str) self.assertIsInstance(result_token, str)
self.assertGreater(len(result_token), 20) self.assertGreater(len(result_token), 20)
def test_if_user_exists(self): def test_if_user_exists(self) -> None:
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
frank = UserID.from_string("@frank:test") frank = UserID.from_string("@frank:test")
self.get_success( self.get_success(
@ -180,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(result_token is not None) self.assertTrue(result_token is not None)
@override_config({"limit_usage_by_mau": False}) @override_config({"limit_usage_by_mau": False})
def test_mau_limits_when_disabled(self): def test_mau_limits_when_disabled(self) -> None:
# Ensure does not throw exception # Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
@override_config({"limit_usage_by_mau": True}) @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self): def test_get_or_create_user_mau_not_blocked(self) -> None:
self.store.count_monthly_users = Mock( self.store.count_monthly_users = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
) )
@ -193,7 +210,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.get_success(self.get_or_create_user(self.requester, "c", "User")) self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@override_config({"limit_usage_by_mau": True}) @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self): def test_get_or_create_user_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users) return_value=make_awaitable(self.lots_of_users)
) )
@ -211,7 +228,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
) )
@override_config({"limit_usage_by_mau": True}) @override_config({"limit_usage_by_mau": True})
def test_register_mau_blocked(self): def test_register_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users) return_value=make_awaitable(self.lots_of_users)
) )
@ -229,7 +246,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config( @override_config(
{"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False} {"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False}
) )
def test_auto_join_rooms_for_guests(self): def test_auto_join_rooms_for_guests(self) -> None:
user_id = self.get_success( user_id = self.get_success(
self.handler.register_user(localpart="jeff", make_guest=True), self.handler.register_user(localpart="jeff", make_guest=True),
) )
@ -237,7 +254,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@override_config({"auto_join_rooms": ["#room:test"]}) @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms(self): def test_auto_create_auto_join_rooms(self) -> None:
room_alias_str = "#room:test" room_alias_str = "#room:test"
user_id = self.get_success(self.handler.register_user(localpart="jeff")) user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@ -249,7 +266,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 1) self.assertEqual(len(rooms), 1)
@override_config({"auto_join_rooms": []}) @override_config({"auto_join_rooms": []})
def test_auto_create_auto_join_rooms_with_no_rooms(self): def test_auto_create_auto_join_rooms_with_no_rooms(self) -> None:
frank = UserID.from_string("@frank:test") frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart)) user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string()) self.assertEqual(user_id, frank.to_string())
@ -257,7 +274,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@override_config({"auto_join_rooms": ["#room:another"]}) @override_config({"auto_join_rooms": ["#room:another"]})
def test_auto_create_auto_join_where_room_is_another_domain(self): def test_auto_create_auto_join_where_room_is_another_domain(self) -> None:
frank = UserID.from_string("@frank:test") frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart)) user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string()) self.assertEqual(user_id, frank.to_string())
@ -267,13 +284,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config( @override_config(
{"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False} {"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False}
) )
def test_auto_create_auto_join_where_auto_create_is_false(self): def test_auto_create_auto_join_where_auto_create_is_false(self) -> None:
user_id = self.get_success(self.handler.register_user(localpart="jeff")) user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@override_config({"auto_join_rooms": ["#room:test"]}) @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.store.is_real_user = Mock(return_value=make_awaitable(False)) self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support")) user_id = self.get_success(self.handler.register_user(localpart="support"))
@ -284,7 +301,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.get_failure(directory_handler.get_association(room_alias), SynapseError) self.get_failure(directory_handler.get_association(room_alias), SynapseError)
@override_config({"auto_join_rooms": ["#room:test"]}) @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.store.count_real_users = Mock(return_value=make_awaitable(1)) self.store.count_real_users = Mock(return_value=make_awaitable(1))
@ -299,7 +316,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 1) self.assertEqual(len(rooms), 1)
@override_config({"auto_join_rooms": ["#room:test"]}) @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self,
) -> None:
self.store.count_real_users = Mock(return_value=make_awaitable(2)) self.store.count_real_users = Mock(return_value=make_awaitable(2))
self.store.is_real_user = Mock(return_value=make_awaitable(True)) self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
@ -312,7 +331,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"autocreate_auto_join_rooms_federated": False, "autocreate_auto_join_rooms_federated": False,
} }
) )
def test_auto_create_auto_join_rooms_federated(self): def test_auto_create_auto_join_rooms_federated(self) -> None:
""" """
Auto-created rooms that are private require an invite to go to the user Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it). (instead of directly joining it).
@ -339,7 +358,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config( @override_config(
{"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"} {"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
) )
def test_auto_join_mxid_localpart(self): def test_auto_join_mxid_localpart(self) -> None:
""" """
Ensure the user still needs up in the room created by a different user. Ensure the user still needs up in the room created by a different user.
""" """
@ -376,7 +395,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support", "auto_join_mxid_localpart": "support",
} }
) )
def test_auto_create_auto_join_room_preset(self): def test_auto_create_auto_join_room_preset(self) -> None:
""" """
Auto-created rooms that are private require an invite to go to the user Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it). (instead of directly joining it).
@ -416,7 +435,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support", "auto_join_mxid_localpart": "support",
} }
) )
def test_auto_create_auto_join_room_preset_guest(self): def test_auto_create_auto_join_room_preset_guest(self) -> None:
""" """
Auto-created rooms that are private require an invite to go to the user Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it). (instead of directly joining it).
@ -454,7 +473,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support", "auto_join_mxid_localpart": "support",
} }
) )
def test_auto_create_auto_join_room_preset_invalid_permissions(self): def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None:
""" """
Auto-created rooms that are private require an invite, check that Auto-created rooms that are private require an invite, check that
registration doesn't completely break if the inviter doesn't have proper registration doesn't completely break if the inviter doesn't have proper
@ -525,7 +544,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_rooms": ["#room:test"], "auto_join_rooms": ["#room:test"],
}, },
) )
def test_auto_create_auto_join_where_no_consent(self): def test_auto_create_auto_join_where_no_consent(self) -> None:
"""Test to ensure that the first user is not auto-joined to a room if """Test to ensure that the first user is not auto-joined to a room if
they have not given general consent. they have not given general consent.
""" """
@ -550,19 +569,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 1) self.assertEqual(len(rooms), 1)
def test_register_support_user(self): def test_register_support_user(self) -> None:
user_id = self.get_success( user_id = self.get_success(
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT) self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
) )
d = self.store.is_support_user(user_id) d = self.store.is_support_user(user_id)
self.assertTrue(self.get_success(d)) self.assertTrue(self.get_success(d))
def test_register_not_support_user(self): def test_register_not_support_user(self) -> None:
user_id = self.get_success(self.handler.register_user(localpart="user")) user_id = self.get_success(self.handler.register_user(localpart="user"))
d = self.store.is_support_user(user_id) d = self.store.is_support_user(user_id)
self.assertFalse(self.get_success(d)) self.assertFalse(self.get_success(d))
def test_invalid_user_id_length(self): def test_invalid_user_id_length(self) -> None:
invalid_user_id = "x" * 256 invalid_user_id = "x" * 256
self.get_failure( self.get_failure(
self.handler.register_user(localpart=invalid_user_id), SynapseError self.handler.register_user(localpart=invalid_user_id), SynapseError
@ -577,7 +596,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
] ]
} }
) )
def test_spam_checker_deny(self): def test_spam_checker_deny(self) -> None:
"""A spam checker can deny registration, which results in an error.""" """A spam checker can deny registration, which results in an error."""
self.get_failure(self.handler.register_user(localpart="user"), SynapseError) self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
@ -590,7 +609,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
] ]
} }
) )
def test_spam_checker_legacy_allow(self): def test_spam_checker_legacy_allow(self) -> None:
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the """Tests that a legacy spam checker implementing the legacy 3-arg version of the
check_registration_for_spam callback is correctly called. check_registration_for_spam callback is correctly called.
@ -610,7 +629,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
] ]
} }
) )
def test_spam_checker_legacy_deny(self): def test_spam_checker_legacy_deny(self) -> None:
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the """Tests that a legacy spam checker implementing the legacy 3-arg version of the
check_registration_for_spam callback is correctly called. check_registration_for_spam callback is correctly called.
@ -630,7 +649,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
] ]
} }
) )
def test_spam_checker_shadow_ban(self): def test_spam_checker_shadow_ban(self) -> None:
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
user_id = self.get_success(self.handler.register_user(localpart="user")) user_id = self.get_success(self.handler.register_user(localpart="user"))
@ -660,7 +679,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
] ]
} }
) )
def test_spam_checker_receives_sso_type(self): def test_spam_checker_receives_sso_type(self) -> None:
"""Test rejecting registration based on SSO type""" """Test rejecting registration based on SSO type"""
f = self.get_failure( f = self.get_failure(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"), self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
@ -678,8 +697,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
) )
async def get_or_create_user( async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None self,
): requester: Requester,
localpart: str,
displayname: Optional[str],
password_hash: Optional[str] = None,
) -> Tuple[str, str]:
"""Creates a new user if the user does not exist, """Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one. else revokes all previous access tokens and generates a new one.
@ -734,13 +757,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
"""Tests auto-join on remote rooms.""" """Tests auto-join on remote rooms."""
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.room_id = "!roomid:remotetest" self.room_id = "!roomid:remotetest"
async def update_membership(*args, **kwargs): async def update_membership(*args: Any, **kwargs: Any) -> None:
pass pass
async def lookup_room_alias(*args, **kwargs): async def lookup_room_alias(
*args: Any, **kwargs: Any
) -> Tuple[RoomID, List[str]]:
return RoomID.from_string(self.room_id), ["remotetest"] return RoomID.from_string(self.room_id), ["remotetest"]
self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"]) self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"])
@ -750,12 +775,12 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler) hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler)
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_registration_handler() self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
@override_config({"auto_join_rooms": ["#room:remotetest"]}) @override_config({"auto_join_rooms": ["#room:remotetest"]})
def test_auto_create_auto_join_remote_room(self): def test_auto_create_auto_join_remote_room(self) -> None:
"""Tests that we don't attempt to create remote rooms, and that we don't attempt """Tests that we don't attempt to create remote rooms, and that we don't attempt
to invite ourselves to rooms we're not in.""" to invite ourselves to rooms we're not in."""

View File

@ -14,7 +14,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
] ]
@override_config({"encryption_enabled_by_default_for_room_type": "all"}) @override_config({"encryption_enabled_by_default_for_room_type": "all"})
def test_encrypted_by_default_config_option_all(self): def test_encrypted_by_default_config_option_all(self) -> None:
"""Tests that invite-only and non-invite-only rooms have encryption enabled by """Tests that invite-only and non-invite-only rooms have encryption enabled by
default when the config option encryption_enabled_by_default_for_room_type is "all". default when the config option encryption_enabled_by_default_for_room_type is "all".
""" """
@ -45,7 +45,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
@override_config({"encryption_enabled_by_default_for_room_type": "invite"}) @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
def test_encrypted_by_default_config_option_invite(self): def test_encrypted_by_default_config_option_invite(self) -> None:
"""Tests that only new, invite-only rooms have encryption enabled by default when """Tests that only new, invite-only rooms have encryption enabled by default when
the config option encryption_enabled_by_default_for_room_type is "invite". the config option encryption_enabled_by_default_for_room_type is "invite".
""" """
@ -76,7 +76,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
) )
@override_config({"encryption_enabled_by_default_for_room_type": "off"}) @override_config({"encryption_enabled_by_default_for_room_type": "off"})
def test_encrypted_by_default_config_option_off(self): def test_encrypted_by_default_config_option_off(self) -> None:
"""Tests that neither new invite-only nor non-invite-only rooms have encryption """Tests that neither new invite-only nor non-invite-only rooms have encryption
enabled by default when the config option enabled by default when the config option
encryption_enabled_by_default_for_room_type is "off". encryption_enabled_by_default_for_room_type is "off".

View File

@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from unittest import mock from unittest import mock
from twisted.internet.defer import ensureDeferred from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import ( from synapse.api.constants import (
EventContentFields, EventContentFields,
@ -34,11 +35,14 @@ from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest from tests import unittest
def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0): def _create_event(
room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0
) -> mock.Mock:
result = mock.Mock(name=room_id) result = mock.Mock(name=room_id)
result.room_id = room_id result.room_id = room_id
result.content = {} result.content = {}
@ -48,40 +52,40 @@ def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: i
return result return result
def _order(*events): def _order(*events: mock.Mock) -> List[mock.Mock]:
return sorted(events, key=_child_events_comparison_key) return sorted(events, key=_child_events_comparison_key)
class TestSpaceSummarySort(unittest.TestCase): class TestSpaceSummarySort(unittest.TestCase):
def test_no_order_last(self): def test_no_order_last(self) -> None:
"""An event with no ordering is placed behind those with an ordering.""" """An event with no ordering is placed behind those with an ordering."""
ev1 = _create_event("!abc:test") ev1 = _create_event("!abc:test")
ev2 = _create_event("!xyz:test", "xyz") ev2 = _create_event("!xyz:test", "xyz")
self.assertEqual([ev2, ev1], _order(ev1, ev2)) self.assertEqual([ev2, ev1], _order(ev1, ev2))
def test_order(self): def test_order(self) -> None:
"""The ordering should be used.""" """The ordering should be used."""
ev1 = _create_event("!abc:test", "xyz") ev1 = _create_event("!abc:test", "xyz")
ev2 = _create_event("!xyz:test", "abc") ev2 = _create_event("!xyz:test", "abc")
self.assertEqual([ev2, ev1], _order(ev1, ev2)) self.assertEqual([ev2, ev1], _order(ev1, ev2))
def test_order_origin_server_ts(self): def test_order_origin_server_ts(self) -> None:
"""Origin server is a tie-breaker for ordering.""" """Origin server is a tie-breaker for ordering."""
ev1 = _create_event("!abc:test", origin_server_ts=10) ev1 = _create_event("!abc:test", origin_server_ts=10)
ev2 = _create_event("!xyz:test", origin_server_ts=30) ev2 = _create_event("!xyz:test", origin_server_ts=30)
self.assertEqual([ev1, ev2], _order(ev1, ev2)) self.assertEqual([ev1, ev2], _order(ev1, ev2))
def test_order_room_id(self): def test_order_room_id(self) -> None:
"""Room ID is a final tie-breaker for ordering.""" """Room ID is a final tie-breaker for ordering."""
ev1 = _create_event("!abc:test") ev1 = _create_event("!abc:test")
ev2 = _create_event("!xyz:test") ev2 = _create_event("!xyz:test")
self.assertEqual([ev1, ev2], _order(ev1, ev2)) self.assertEqual([ev1, ev2], _order(ev1, ev2))
def test_invalid_ordering_type(self): def test_invalid_ordering_type(self) -> None:
"""Invalid orderings are considered the same as missing.""" """Invalid orderings are considered the same as missing."""
ev1 = _create_event("!abc:test", 1) ev1 = _create_event("!abc:test", 1)
ev2 = _create_event("!xyz:test", "xyz") ev2 = _create_event("!xyz:test", "xyz")
@ -97,7 +101,7 @@ class TestSpaceSummarySort(unittest.TestCase):
ev1 = _create_event("!abc:test", True) ev1 = _create_event("!abc:test", True)
self.assertEqual([ev2, ev1], _order(ev1, ev2)) self.assertEqual([ev2, ev1], _order(ev1, ev2))
def test_invalid_ordering_value(self): def test_invalid_ordering_value(self) -> None:
"""Invalid orderings are considered the same as missing.""" """Invalid orderings are considered the same as missing."""
ev1 = _create_event("!abc:test", "foo\n") ev1 = _create_event("!abc:test", "foo\n")
ev2 = _create_event("!xyz:test", "xyz") ev2 = _create_event("!xyz:test", "xyz")
@ -115,7 +119,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs self.hs = hs
self.handler = self.hs.get_room_summary_handler() self.handler = self.hs.get_room_summary_handler()
@ -223,7 +227,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
) )
def test_simple_space(self): def test_simple_space(self) -> None:
"""Test a simple space with a single room.""" """Test a simple space with a single room."""
# The result should have the space and the room in it, along with a link # The result should have the space and the room in it, along with a link
# from space -> room. # from space -> room.
@ -234,7 +238,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
) )
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
def test_large_space(self): def test_large_space(self) -> None:
"""Test a space with a large number of rooms.""" """Test a space with a large number of rooms."""
rooms = [self.room] rooms = [self.room]
# Make at least 51 rooms that are part of the space. # Make at least 51 rooms that are part of the space.
@ -260,7 +264,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result["rooms"] += result2["rooms"] result["rooms"] += result2["rooms"]
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
def test_visibility(self): def test_visibility(self) -> None:
"""A user not in a space cannot inspect it.""" """A user not in a space cannot inspect it."""
user2 = self.register_user("user2", "pass") user2 = self.register_user("user2", "pass")
token2 = self.login("user2", "pass") token2 = self.login("user2", "pass")
@ -380,7 +384,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_hierarchy(result2, [(self.space, [self.room])]) self._assert_hierarchy(result2, [(self.space, [self.room])])
def _create_room_with_join_rule( def _create_room_with_join_rule(
self, join_rule: str, room_version: Optional[str] = None, **extra_content self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any
) -> str: ) -> str:
"""Create a room with the given join rule and add it to the space.""" """Create a room with the given join rule and add it to the space."""
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(
@ -403,7 +407,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._add_child(self.space, room_id, self.token) self._add_child(self.space, room_id, self.token)
return room_id return room_id
def test_filtering(self): def test_filtering(self) -> None:
""" """
Rooms should be properly filtered to only include rooms the user has access to. Rooms should be properly filtered to only include rooms the user has access to.
""" """
@ -476,7 +480,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
) )
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
def test_complex_space(self): def test_complex_space(self) -> None:
""" """
Create a "complex" space to see how it handles things like loops and subspaces. Create a "complex" space to see how it handles things like loops and subspaces.
""" """
@ -516,7 +520,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
) )
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
def test_pagination(self): def test_pagination(self) -> None:
"""Test simple pagination works.""" """Test simple pagination works."""
room_ids = [] room_ids = []
for i in range(1, 10): for i in range(1, 10):
@ -553,7 +557,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
self.assertNotIn("next_batch", result) self.assertNotIn("next_batch", result)
def test_invalid_pagination_token(self): def test_invalid_pagination_token(self) -> None:
"""An invalid pagination token, or changing other parameters, shoudl be rejected.""" """An invalid pagination token, or changing other parameters, shoudl be rejected."""
room_ids = [] room_ids = []
for i in range(1, 10): for i in range(1, 10):
@ -604,7 +608,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
SynapseError, SynapseError,
) )
def test_max_depth(self): def test_max_depth(self) -> None:
"""Create a deep tree to test the max depth against.""" """Create a deep tree to test the max depth against."""
spaces = [self.space] spaces = [self.space]
rooms = [self.room] rooms = [self.room]
@ -659,7 +663,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
] ]
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
def test_unknown_room_version(self): def test_unknown_room_version(self) -> None:
""" """
If a room with an unknown room version is encountered it should not cause If a room with an unknown room version is encountered it should not cause
the entire summary to skip. the entire summary to skip.
@ -685,7 +689,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
) )
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
def test_fed_complex(self): def test_fed_complex(self) -> None:
""" """
Return data over federation and ensure that it is handled properly. Return data over federation and ensure that it is handled properly.
""" """
@ -722,7 +726,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
"world_readable": True, "world_readable": True,
} }
async def summarize_remote_room_hierarchy(_self, room, suggested_only): async def summarize_remote_room_hierarchy(
_self: Any, room: Any, suggested_only: bool
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return requested_room_entry, {subroom: child_room}, set() return requested_room_entry, {subroom: child_room}, set()
# Add a room to the space which is on another server. # Add a room to the space which is on another server.
@ -744,7 +750,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
) )
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
def test_fed_filtering(self): def test_fed_filtering(self) -> None:
""" """
Rooms returned over federation should be properly filtered to only include Rooms returned over federation should be properly filtered to only include
rooms the user has access to. rooms the user has access to.
@ -853,7 +859,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
], ],
) )
async def summarize_remote_room_hierarchy(_self, room, suggested_only): async def summarize_remote_room_hierarchy(
_self: Any, room: Any, suggested_only: bool
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return subspace_room_entry, dict(children_rooms), set() return subspace_room_entry, dict(children_rooms), set()
# Add a room to the space which is on another server. # Add a room to the space which is on another server.
@ -892,7 +900,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
) )
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
def test_fed_invited(self): def test_fed_invited(self) -> None:
""" """
A room which the user was invited to should be included in the response. A room which the user was invited to should be included in the response.
@ -915,7 +923,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
}, },
) )
async def summarize_remote_room_hierarchy(_self, room, suggested_only): async def summarize_remote_room_hierarchy(
_self: Any, room: Any, suggested_only: bool
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return fed_room_entry, {}, set() return fed_room_entry, {}, set()
# Add a room to the space which is on another server. # Add a room to the space which is on another server.
@ -936,7 +946,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
) )
self._assert_hierarchy(result, expected) self._assert_hierarchy(result, expected)
def test_fed_caching(self): def test_fed_caching(self) -> None:
""" """
Federation `/hierarchy` responses should be cached. Federation `/hierarchy` responses should be cached.
""" """
@ -1023,7 +1033,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs self.hs = hs
self.handler = self.hs.get_room_summary_handler() self.handler = self.hs.get_room_summary_handler()
@ -1040,12 +1050,12 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token, tok=self.token,
) )
def test_own_room(self): def test_own_room(self) -> None:
"""Test a simple room created by the requester.""" """Test a simple room created by the requester."""
result = self.get_success(self.handler.get_room_summary(self.user, self.room)) result = self.get_success(self.handler.get_room_summary(self.user, self.room))
self.assertEqual(result.get("room_id"), self.room) self.assertEqual(result.get("room_id"), self.room)
def test_visibility(self): def test_visibility(self) -> None:
"""A user not in a private room cannot get its summary.""" """A user not in a private room cannot get its summary."""
user2 = self.register_user("user2", "pass") user2 = self.register_user("user2", "pass")
token2 = self.login("user2", "pass") token2 = self.login("user2", "pass")
@ -1093,7 +1103,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_room_summary(user2, self.room)) result = self.get_success(self.handler.get_room_summary(user2, self.room))
self.assertEqual(result.get("room_id"), self.room) self.assertEqual(result.get("room_id"), self.room)
def test_fed(self): def test_fed(self) -> None:
""" """
Return data over federation and ensure that it is handled properly. Return data over federation and ensure that it is handled properly.
""" """
@ -1105,7 +1115,9 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
{"room_id": fed_room, "world_readable": True}, {"room_id": fed_room, "world_readable": True},
) )
async def summarize_remote_room_hierarchy(_self, room, suggested_only): async def summarize_remote_room_hierarchy(
_self: Any, room: Any, suggested_only: bool
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return requested_room_entry, {}, set() return requested_room_entry, {}, set()
with mock.patch( with mock.patch(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Set, Tuple
from unittest.mock import Mock from unittest.mock import Mock
import attr import attr
@ -20,7 +20,9 @@ import attr
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import RedirectException from synapse.api.errors import RedirectException
from synapse.module_api import ModuleApi
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils import simple_async_mock from tests.test_utils import simple_async_mock
@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests. # Check if we have the dependencies to run the tests.
try: try:
import saml2.config import saml2.config
import saml2.response
from saml2.sigver import SigverError from saml2.sigver import SigverError
has_saml2 = True has_saml2 = True
@ -56,31 +59,39 @@ class FakeAuthnResponse:
class TestMappingProvider: class TestMappingProvider:
def __init__(self, config, module): def __init__(self, config: None, module: ModuleApi):
pass pass
@staticmethod @staticmethod
def parse_config(config): def parse_config(config: JsonDict) -> None:
return return None
@staticmethod @staticmethod
def get_saml_attributes(config): def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]:
return {"uid"}, {"displayName"} return {"uid"}, {"displayName"}
def get_remote_user_id(self, saml_response, client_redirect_url): def get_remote_user_id(
self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str
) -> str:
return saml_response.ava["uid"] return saml_response.ava["uid"]
def saml_response_to_user_attributes( def saml_response_to_user_attributes(
self, saml_response, failures, client_redirect_url self,
): saml_response: "saml2.response.AuthnResponse",
failures: int,
client_redirect_url: str,
) -> dict:
localpart = saml_response.ava["username"] + (str(failures) if failures else "") localpart = saml_response.ava["username"] + (str(failures) if failures else "")
return {"mxid_localpart": localpart, "displayname": None} return {"mxid_localpart": localpart, "displayname": None}
class TestRedirectMappingProvider(TestMappingProvider): class TestRedirectMappingProvider(TestMappingProvider):
def saml_response_to_user_attributes( def saml_response_to_user_attributes(
self, saml_response, failures, client_redirect_url self,
): saml_response: "saml2.response.AuthnResponse",
failures: int,
client_redirect_url: str,
) -> dict:
raise RedirectException(b"https://custom-saml-redirect/") raise RedirectException(b"https://custom-saml-redirect/")
@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
) )
def _mock_request(): def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest""" """Returns a mock which will stand in as a SynapseRequest"""
mock = Mock( mock = Mock(
spec=[ spec=[

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import List, Tuple from typing import Callable, List, Tuple
from zope.interface import implementer from zope.interface import implementer
@ -28,20 +28,27 @@ from tests.unittest import HomeserverTestCase, override_config
@implementer(interfaces.IMessageDelivery) @implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery: class _DummyMessageDelivery:
def __init__(self): def __init__(self) -> None:
# (recipient, message) tuples # (recipient, message) tuples
self.messages: List[Tuple[smtp.Address, bytes]] = [] self.messages: List[Tuple[smtp.Address, bytes]] = []
def receivedHeader(self, helo, origin, recipients): def receivedHeader(
self,
helo: Tuple[bytes, bytes],
origin: smtp.Address,
recipients: List[smtp.User],
) -> None:
return None return None
def validateFrom(self, helo, origin): def validateFrom(
self, helo: Tuple[bytes, bytes], origin: smtp.Address
) -> smtp.Address:
return origin return origin
def record_message(self, recipient: smtp.Address, message: bytes): def record_message(self, recipient: smtp.Address, message: bytes) -> None:
self.messages.append((recipient, message)) self.messages.append((recipient, message))
def validateTo(self, user: smtp.User): def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
return lambda: _DummyMessage(self, user) return lambda: _DummyMessage(self, user)
@ -56,20 +63,20 @@ class _DummyMessage:
self._user = user self._user = user
self._buffer: List[bytes] = [] self._buffer: List[bytes] = []
def lineReceived(self, line): def lineReceived(self, line: bytes) -> None:
self._buffer.append(line) self._buffer.append(line)
def eomReceived(self): def eomReceived(self) -> "defer.Deferred[bytes]":
message = b"\n".join(self._buffer) + b"\n" message = b"\n".join(self._buffer) + b"\n"
self._delivery.record_message(self._user.dest, message) self._delivery.record_message(self._user.dest, message)
return defer.succeed(b"saved") return defer.succeed(b"saved")
def connectionLost(self): def connectionLost(self) -> None:
pass pass
class SendEmailHandlerTestCase(HomeserverTestCase): class SendEmailHandlerTestCase(HomeserverTestCase):
def test_send_email(self): def test_send_email(self) -> None:
"""Happy-path test that we can send email to a non-TLS server.""" """Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler() h = self.hs.get_send_email_handler()
d = ensureDeferred( d = ensureDeferred(
@ -119,7 +126,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
}, },
} }
) )
def test_send_email_force_tls(self): def test_send_email_force_tls(self) -> None:
"""Happy-path test that we can send email to an Implicit TLS server.""" """Happy-path test that we can send email to an Implicit TLS server."""
h = self.hs.get_send_email_handler() h = self.hs.get_send_email_handler()
d = ensureDeferred( d = ensureDeferred(

View File

@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, List, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main import stats from synapse.storage.databases.main import stats
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -32,11 +38,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.handler = self.hs.get_stats_handler() self.handler = self.hs.get_stats_handler()
def _add_background_updates(self): def _add_background_updates(self) -> None:
""" """
Add the background updates we need to run. Add the background updates we need to run.
""" """
@ -63,12 +69,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
) )
) )
async def get_all_room_state(self): async def get_all_room_state(self) -> List[Dict[str, Any]]:
return await self.store.db_pool.simple_select_list( return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias") "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
) )
def _get_current_stats(self, stats_type, stat_id): def _get_current_stats(
self, stats_type: str, stat_id: str
) -> Optional[Dict[str, Any]]:
table, id_col = stats.TYPE_TO_TABLE[stats_type] table, id_col = stats.TYPE_TO_TABLE[stats_type]
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
@ -82,13 +90,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
) )
) )
def _perform_background_initial_update(self): def _perform_background_initial_update(self) -> None:
# Do the initial population of the stats via the background update # Do the initial population of the stats via the background update
self._add_background_updates() self._add_background_updates()
self.wait_for_background_updates() self.wait_for_background_updates()
def test_initial_room(self): def test_initial_room(self) -> None:
""" """
The background updates will build the table from scratch. The background updates will build the table from scratch.
""" """
@ -125,7 +133,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo") self.assertEqual(r[0]["topic"], "foo")
def test_create_user(self): def test_create_user(self) -> None:
""" """
When we create a user, it should have statistics already ready. When we create a user, it should have statistics already ready.
""" """
@ -134,12 +142,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u1stats = self._get_current_stats("user", u1) u1stats = self._get_current_stats("user", u1)
self.assertIsNotNone(u1stats) assert u1stats is not None
# not in any rooms by default # not in any rooms by default
self.assertEqual(u1stats["joined_rooms"], 0) self.assertEqual(u1stats["joined_rooms"], 0)
def test_create_room(self): def test_create_room(self) -> None:
""" """
When we create a room, it should have statistics already ready. When we create a room, it should have statistics already ready.
""" """
@ -153,8 +161,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False) r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
r2stats = self._get_current_stats("room", r2) r2stats = self._get_current_stats("room", r2)
self.assertIsNotNone(r1stats) assert r1stats is not None
self.assertIsNotNone(r2stats) assert r2stats is not None
self.assertEqual( self.assertEqual(
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
@ -171,7 +179,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(r2stats["invited_members"], 0) self.assertEqual(r2stats["invited_members"], 0)
self.assertEqual(r2stats["banned_members"], 0) self.assertEqual(r2stats["banned_members"], 0)
def test_updating_profile_information_does_not_increase_joined_members_count(self): def test_updating_profile_information_does_not_increase_joined_members_count(
self,
) -> None:
""" """
Check that the joined_members count does not increase when a user changes their Check that the joined_members count does not increase when a user changes their
profile information (which is done by sending another join membership event into profile information (which is done by sending another join membership event into
@ -186,6 +196,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Get the current room stats # Get the current room stats
r1stats_ante = self._get_current_stats("room", r1) r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None
# Send a profile update into the room # Send a profile update into the room
new_profile = {"displayname": "bob"} new_profile = {"displayname": "bob"}
@ -195,6 +206,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Get the new room stats # Get the new room stats
r1stats_post = self._get_current_stats("room", r1) r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None
# Ensure that the user count did not changed # Ensure that the user count did not changed
self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"]) self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
@ -202,7 +214,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"] r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
) )
def test_send_state_event_nonoverwriting(self): def test_send_state_event_nonoverwriting(self) -> None:
""" """
When we send a non-overwriting state event, it increments current_state_events When we send a non-overwriting state event, it increments current_state_events
""" """
@ -218,19 +230,21 @@ class StatsRoomTests(unittest.HomeserverTestCase):
) )
r1stats_ante = self._get_current_stats("room", r1) r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None
self.helper.send_state( self.helper.send_state(
r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy" r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy"
) )
r1stats_post = self._get_current_stats("room", r1) r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None
self.assertEqual( self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
1, 1,
) )
def test_join_first_time(self): def test_join_first_time(self) -> None:
""" """
When a user joins a room for the first time, current_state_events and When a user joins a room for the first time, current_state_events and
joined_members should increase by exactly 1. joined_members should increase by exactly 1.
@ -246,10 +260,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u2token = self.login("u2", "pass") u2token = self.login("u2", "pass")
r1stats_ante = self._get_current_stats("room", r1) r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None
self.helper.join(r1, u2, tok=u2token) self.helper.join(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1) r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None
self.assertEqual( self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@ -259,7 +275,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1 r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1
) )
def test_join_after_leave(self): def test_join_after_leave(self) -> None:
""" """
When a user joins a room after being previously left, When a user joins a room after being previously left,
joined_members should increase by exactly 1. joined_members should increase by exactly 1.
@ -280,10 +296,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.leave(r1, u2, tok=u2token) self.helper.leave(r1, u2, tok=u2token)
r1stats_ante = self._get_current_stats("room", r1) r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None
self.helper.join(r1, u2, tok=u2token) self.helper.join(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1) r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None
self.assertEqual( self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@ -296,7 +314,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["left_members"] - r1stats_ante["left_members"], -1 r1stats_post["left_members"] - r1stats_ante["left_members"], -1
) )
def test_invited(self): def test_invited(self) -> None:
""" """
When a user invites another user, current_state_events and When a user invites another user, current_state_events and
invited_members should increase by exactly 1. invited_members should increase by exactly 1.
@ -311,10 +329,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u2 = self.register_user("u2", "pass") u2 = self.register_user("u2", "pass")
r1stats_ante = self._get_current_stats("room", r1) r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None
self.helper.invite(r1, u1, u2, tok=u1token) self.helper.invite(r1, u1, u2, tok=u1token)
r1stats_post = self._get_current_stats("room", r1) r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None
self.assertEqual( self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@ -324,7 +344,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1 r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1
) )
def test_join_after_invite(self): def test_join_after_invite(self) -> None:
""" """
When a user joins a room after being invited and When a user joins a room after being invited and
joined_members should increase by exactly 1. joined_members should increase by exactly 1.
@ -344,10 +364,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.invite(r1, u1, u2, tok=u1token) self.helper.invite(r1, u1, u2, tok=u1token)
r1stats_ante = self._get_current_stats("room", r1) r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None
self.helper.join(r1, u2, tok=u2token) self.helper.join(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1) r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None
self.assertEqual( self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@ -360,7 +382,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1 r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1
) )
def test_left(self): def test_left(self) -> None:
""" """
When a user leaves a room after joining and When a user leaves a room after joining and
left_members should increase by exactly 1. left_members should increase by exactly 1.
@ -380,10 +402,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.join(r1, u2, tok=u2token) self.helper.join(r1, u2, tok=u2token)
r1stats_ante = self._get_current_stats("room", r1) r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None
self.helper.leave(r1, u2, tok=u2token) self.helper.leave(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1) r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None
self.assertEqual( self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@ -396,7 +420,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
) )
def test_banned(self): def test_banned(self) -> None:
""" """
When a user is banned from a room after joining and When a user is banned from a room after joining and
left_members should increase by exactly 1. left_members should increase by exactly 1.
@ -416,10 +440,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.join(r1, u2, tok=u2token) self.helper.join(r1, u2, tok=u2token)
r1stats_ante = self._get_current_stats("room", r1) r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None
self.helper.change_membership(r1, u1, u2, "ban", tok=u1token) self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)
r1stats_post = self._get_current_stats("room", r1) r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None
self.assertEqual( self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@ -432,7 +458,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
) )
def test_initial_background_update(self): def test_initial_background_update(self) -> None:
""" """
Test that statistics can be generated by the initial background update Test that statistics can be generated by the initial background update
handler. handler.
@ -462,6 +488,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats = self._get_current_stats("room", r1) r1stats = self._get_current_stats("room", r1)
u1stats = self._get_current_stats("user", u1) u1stats = self._get_current_stats("user", u1)
assert r1stats is not None
assert u1stats is not None
self.assertEqual(r1stats["joined_members"], 1) self.assertEqual(r1stats["joined_members"], 1)
self.assertEqual( self.assertEqual(
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
@ -469,7 +498,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(u1stats["joined_rooms"], 1) self.assertEqual(u1stats["joined_rooms"], 1)
def test_incomplete_stats(self): def test_incomplete_stats(self) -> None:
""" """
This tests that we track incomplete statistics. This tests that we track incomplete statistics.
@ -533,8 +562,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.wait_for_background_updates() self.wait_for_background_updates()
r1stats_complete = self._get_current_stats("room", r1) r1stats_complete = self._get_current_stats("room", r1)
assert r1stats_complete is not None
u1stats_complete = self._get_current_stats("user", u1) u1stats_complete = self._get_current_stats("user", u1)
assert u1stats_complete is not None
u2stats_complete = self._get_current_stats("user", u2) u2stats_complete = self._get_current_stats("user", u2)
assert u2stats_complete is not None
# now we make our assertions # now we make our assertions

View File

@ -14,6 +14,8 @@
from typing import Optional from typing import Optional
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, JoinRules from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering from synapse.api.filtering import Filtering
@ -23,6 +25,7 @@ from synapse.rest import admin
from synapse.rest.client import knock, login, room from synapse.rest.client import knock, login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util import Clock
import tests.unittest import tests.unittest
import tests.utils import tests.utils
@ -39,7 +42,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_handler = self.hs.get_sync_handler() self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
@ -47,7 +50,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# modify its config instead of the hs' # modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth_blocking() self.auth_blocking = self.hs.get_auth_blocking()
def test_wait_for_sync_for_user_auth_blocking(self): def test_wait_for_sync_for_user_auth_blocking(self) -> None:
user_id1 = "@user1:test" user_id1 = "@user1:test"
user_id2 = "@user2:test" user_id2 = "@user2:test"
sync_config = generate_sync_config(user_id1) sync_config = generate_sync_config(user_id1)
@ -82,7 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
) )
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_unknown_room_version(self): def test_unknown_room_version(self) -> None:
""" """
A room with an unknown room version should not break sync (and should be excluded). A room with an unknown room version should not break sync (and should be excluded).
""" """
@ -186,7 +189,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertNotIn(invite_room, [r.room_id for r in result.invited]) self.assertNotIn(invite_room, [r.room_id for r in result.invited])
self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
def test_ban_wins_race_with_join(self): def test_ban_wins_race_with_join(self) -> None:
"""Rooms shouldn't appear under "joined" if a join loses a race to a ban. """Rooms shouldn't appear under "joined" if a join loses a race to a ban.
A complicated edge case. Imagine the following scenario: A complicated edge case. Imagine the following scenario: