Add missing type hints to tests.handlers. (#14680)
And do not allow untyped defs in tests.handlers.
This commit is contained in:
parent
54c012c5a8
commit
652d1669c5
|
@ -0,0 +1 @@
|
||||||
|
Add missing type hints.
|
5
mypy.ini
5
mypy.ini
|
@ -95,10 +95,7 @@ disallow_untyped_defs = True
|
||||||
[mypy-tests.federation.transport.test_client]
|
[mypy-tests.federation.transport.test_client]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.handlers.test_sso]
|
[mypy-tests.handlers.*]
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-tests.handlers.test_user_directory]
|
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.metrics.test_background_process_metrics]
|
[mypy-tests.metrics.test_background_process_metrics]
|
||||||
|
|
|
@ -2031,7 +2031,7 @@ class PasswordAuthProvider:
|
||||||
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
|
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
|
||||||
|
|
||||||
# Mapping from login type to login parameters
|
# Mapping from login type to login parameters
|
||||||
self._supported_login_types: Dict[str, Iterable[str]] = {}
|
self._supported_login_types: Dict[str, Tuple[str, ...]] = {}
|
||||||
|
|
||||||
# Mapping from login type to auth checker callbacks
|
# Mapping from login type to auth checker callbacks
|
||||||
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
|
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
|
||||||
|
|
|
@ -31,7 +31,7 @@ from synapse.appservice import (
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.rest.client import login, receipts, register, room, sendtodevice
|
from synapse.rest.client import login, receipts, register, room, sendtodevice
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import JsonDict, RoomStreamToken
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ from tests.utils import MockClock
|
||||||
class AppServiceHandlerTestCase(unittest.TestCase):
|
class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
"""Tests the ApplicationServicesHandler."""
|
"""Tests the ApplicationServicesHandler."""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
self.mock_store = Mock()
|
self.mock_store = Mock()
|
||||||
self.mock_as_api = Mock()
|
self.mock_as_api = Mock()
|
||||||
self.mock_scheduler = Mock()
|
self.mock_scheduler = Mock()
|
||||||
|
@ -61,7 +61,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
self.handler = ApplicationServicesHandler(hs)
|
self.handler = ApplicationServicesHandler(hs)
|
||||||
self.event_source = hs.get_event_sources()
|
self.event_source = hs.get_event_sources()
|
||||||
|
|
||||||
def test_notify_interested_services(self):
|
def test_notify_interested_services(self) -> None:
|
||||||
interested_service = self._mkservice(is_interested_in_event=True)
|
interested_service = self._mkservice(is_interested_in_event=True)
|
||||||
services = [
|
services = [
|
||||||
self._mkservice(is_interested_in_event=False),
|
self._mkservice(is_interested_in_event=False),
|
||||||
|
@ -90,7 +90,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
interested_service, events=[event]
|
interested_service, events=[event]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_query_user_exists_unknown_user(self):
|
def test_query_user_exists_unknown_user(self) -> None:
|
||||||
user_id = "@someone:anywhere"
|
user_id = "@someone:anywhere"
|
||||||
services = [self._mkservice(is_interested_in_event=True)]
|
services = [self._mkservice(is_interested_in_event=True)]
|
||||||
services[0].is_interested_in_user.return_value = True
|
services[0].is_interested_in_user.return_value = True
|
||||||
|
@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
|
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
|
||||||
|
|
||||||
def test_query_user_exists_known_user(self):
|
def test_query_user_exists_known_user(self) -> None:
|
||||||
user_id = "@someone:anywhere"
|
user_id = "@someone:anywhere"
|
||||||
services = [self._mkservice(is_interested_in_event=True)]
|
services = [self._mkservice(is_interested_in_event=True)]
|
||||||
services[0].is_interested_in_user.return_value = True
|
services[0].is_interested_in_user.return_value = True
|
||||||
|
@ -127,7 +127,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
"query_user called when it shouldn't have been.",
|
"query_user called when it shouldn't have been.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_query_room_alias_exists(self):
|
def test_query_room_alias_exists(self) -> None:
|
||||||
room_alias_str = "#foo:bar"
|
room_alias_str = "#foo:bar"
|
||||||
room_alias = Mock()
|
room_alias = Mock()
|
||||||
room_alias.to_string.return_value = room_alias_str
|
room_alias.to_string.return_value = room_alias_str
|
||||||
|
@ -157,7 +157,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
self.assertEqual(result.room_id, room_id)
|
self.assertEqual(result.room_id, room_id)
|
||||||
self.assertEqual(result.servers, servers)
|
self.assertEqual(result.servers, servers)
|
||||||
|
|
||||||
def test_get_3pe_protocols_no_appservices(self):
|
def test_get_3pe_protocols_no_appservices(self) -> None:
|
||||||
self.mock_store.get_app_services.return_value = []
|
self.mock_store.get_app_services.return_value = []
|
||||||
response = self.successResultOf(
|
response = self.successResultOf(
|
||||||
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
|
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
|
||||||
|
@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
self.mock_as_api.get_3pe_protocol.assert_not_called()
|
self.mock_as_api.get_3pe_protocol.assert_not_called()
|
||||||
self.assertEqual(response, {})
|
self.assertEqual(response, {})
|
||||||
|
|
||||||
def test_get_3pe_protocols_no_protocols(self):
|
def test_get_3pe_protocols_no_protocols(self) -> None:
|
||||||
service = self._mkservice(False, [])
|
service = self._mkservice(False, [])
|
||||||
self.mock_store.get_app_services.return_value = [service]
|
self.mock_store.get_app_services.return_value = [service]
|
||||||
response = self.successResultOf(
|
response = self.successResultOf(
|
||||||
|
@ -174,7 +174,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
self.mock_as_api.get_3pe_protocol.assert_not_called()
|
self.mock_as_api.get_3pe_protocol.assert_not_called()
|
||||||
self.assertEqual(response, {})
|
self.assertEqual(response, {})
|
||||||
|
|
||||||
def test_get_3pe_protocols_protocol_no_response(self):
|
def test_get_3pe_protocols_protocol_no_response(self) -> None:
|
||||||
service = self._mkservice(False, ["my-protocol"])
|
service = self._mkservice(False, ["my-protocol"])
|
||||||
self.mock_store.get_app_services.return_value = [service]
|
self.mock_store.get_app_services.return_value = [service]
|
||||||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
|
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
|
||||||
|
@ -186,7 +186,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(response, {})
|
self.assertEqual(response, {})
|
||||||
|
|
||||||
def test_get_3pe_protocols_select_one_protocol(self):
|
def test_get_3pe_protocols_select_one_protocol(self) -> None:
|
||||||
service = self._mkservice(False, ["my-protocol"])
|
service = self._mkservice(False, ["my-protocol"])
|
||||||
self.mock_store.get_app_services.return_value = [service]
|
self.mock_store.get_app_services.return_value = [service]
|
||||||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
|
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
|
||||||
|
@ -202,7 +202,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
|
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_3pe_protocols_one_protocol(self):
|
def test_get_3pe_protocols_one_protocol(self) -> None:
|
||||||
service = self._mkservice(False, ["my-protocol"])
|
service = self._mkservice(False, ["my-protocol"])
|
||||||
self.mock_store.get_app_services.return_value = [service]
|
self.mock_store.get_app_services.return_value = [service]
|
||||||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
|
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
|
||||||
|
@ -218,7 +218,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
|
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_3pe_protocols_multiple_protocol(self):
|
def test_get_3pe_protocols_multiple_protocol(self) -> None:
|
||||||
service_one = self._mkservice(False, ["my-protocol"])
|
service_one = self._mkservice(False, ["my-protocol"])
|
||||||
service_two = self._mkservice(False, ["other-protocol"])
|
service_two = self._mkservice(False, ["other-protocol"])
|
||||||
self.mock_store.get_app_services.return_value = [service_one, service_two]
|
self.mock_store.get_app_services.return_value = [service_one, service_two]
|
||||||
|
@ -237,11 +237,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_3pe_protocols_multiple_info(self):
|
def test_get_3pe_protocols_multiple_info(self) -> None:
|
||||||
service_one = self._mkservice(False, ["my-protocol"])
|
service_one = self._mkservice(False, ["my-protocol"])
|
||||||
service_two = self._mkservice(False, ["my-protocol"])
|
service_two = self._mkservice(False, ["my-protocol"])
|
||||||
|
|
||||||
async def get_3pe_protocol(service, unusedProtocol):
|
async def get_3pe_protocol(
|
||||||
|
service: ApplicationService, protocol: str
|
||||||
|
) -> Optional[JsonDict]:
|
||||||
if service == service_one:
|
if service == service_one:
|
||||||
return {
|
return {
|
||||||
"x-protocol-data": 42,
|
"x-protocol-data": 42,
|
||||||
|
@ -276,7 +278,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_notify_interested_services_ephemeral(self):
|
def test_notify_interested_services_ephemeral(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test sending ephemeral events to the appservice handler are scheduled
|
Test sending ephemeral events to the appservice handler are scheduled
|
||||||
to be pushed out to interested appservices, and that the stream ID is
|
to be pushed out to interested appservices, and that the stream ID is
|
||||||
|
@ -306,7 +308,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
580,
|
580,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_notify_interested_services_ephemeral_out_of_order(self):
|
def test_notify_interested_services_ephemeral_out_of_order(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test sending out of order ephemeral events to the appservice handler
|
Test sending out of order ephemeral events to the appservice handler
|
||||||
are ignored.
|
are ignored.
|
||||||
|
@ -390,7 +392,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
receipts.register_servlets,
|
receipts.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
|
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
|
||||||
# we can track any outgoing ephemeral events
|
# we can track any outgoing ephemeral events
|
||||||
|
@ -417,7 +419,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
"exclusive_as_user", "password", self.exclusive_as_user_device_id
|
"exclusive_as_user", "password", self.exclusive_as_user_device_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def _notify_interested_services(self):
|
def _notify_interested_services(self) -> None:
|
||||||
# This is normally set in `notify_interested_services` but we need to call the
|
# This is normally set in `notify_interested_services` but we need to call the
|
||||||
# internal async version so the reactor gets pushed to completion.
|
# internal async version so the reactor gets pushed to completion.
|
||||||
self.hs.get_application_service_handler().current_max += 1
|
self.hs.get_application_service_handler().current_max += 1
|
||||||
|
@ -443,7 +445,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
def test_match_interesting_room_members(
|
def test_match_interesting_room_members(
|
||||||
self, interesting_user: str, should_notify: bool
|
self, interesting_user: str, should_notify: bool
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure that a interesting user (local or remote) in the room is
|
Test to make sure that a interesting user (local or remote) in the room is
|
||||||
notified as expected when someone else in the room sends a message.
|
notified as expected when someone else in the room sends a message.
|
||||||
|
@ -512,7 +514,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
else:
|
else:
|
||||||
self.send_mock.assert_not_called()
|
self.send_mock.assert_not_called()
|
||||||
|
|
||||||
def test_application_services_receive_events_sent_by_interesting_local_user(self):
|
def test_application_services_receive_events_sent_by_interesting_local_user(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure that a messages sent from a local user can be interesting and
|
Test to make sure that a messages sent from a local user can be interesting and
|
||||||
picked up by the appservice.
|
picked up by the appservice.
|
||||||
|
@ -568,7 +572,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(events[0]["type"], "m.room.message")
|
self.assertEqual(events[0]["type"], "m.room.message")
|
||||||
self.assertEqual(events[0]["sender"], alice)
|
self.assertEqual(events[0]["sender"], alice)
|
||||||
|
|
||||||
def test_sending_read_receipt_batches_to_application_services(self):
|
def test_sending_read_receipt_batches_to_application_services(self) -> None:
|
||||||
"""Tests that a large batch of read receipts are sent correctly to
|
"""Tests that a large batch of read receipts are sent correctly to
|
||||||
interested application services.
|
interested application services.
|
||||||
"""
|
"""
|
||||||
|
@ -644,7 +648,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
|
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
|
||||||
)
|
)
|
||||||
def test_application_services_receive_local_to_device(self):
|
def test_application_services_receive_local_to_device(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that when a user sends a to-device message to another user
|
Test that when a user sends a to-device message to another user
|
||||||
that is an application service's user namespace, the
|
that is an application service's user namespace, the
|
||||||
|
@ -722,7 +726,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
|
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
|
||||||
)
|
)
|
||||||
def test_application_services_receive_bursts_of_to_device(self):
|
def test_application_services_receive_bursts_of_to_device(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that when a user sends >100 to-device messages at once, any
|
Test that when a user sends >100 to-device messages at once, any
|
||||||
interested AS's will receive them in separate transactions.
|
interested AS's will receive them in separate transactions.
|
||||||
|
@ -913,7 +917,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
|
||||||
experimental_feature_enabled: bool,
|
experimental_feature_enabled: bool,
|
||||||
as_supports_txn_extensions: bool,
|
as_supports_txn_extensions: bool,
|
||||||
as_should_receive_device_list_updates: bool,
|
as_should_receive_device_list_updates: bool,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that an application service receives notice of changed device
|
Tests that an application service receives notice of changed device
|
||||||
lists for a user, when a user changes their device lists.
|
lists for a user, when a user changes their device lists.
|
||||||
|
@ -1070,7 +1074,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
|
||||||
and a room for the users to talk in.
|
and a room for the users to talk in.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def preparation():
|
async def preparation() -> None:
|
||||||
await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
|
await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
|
||||||
await self._add_fallback_key_for_device(
|
await self._add_fallback_key_for_device(
|
||||||
self._sender_user, self._sender_device, used=True
|
self._sender_user, self._sender_device, used=True
|
||||||
|
|
|
@ -199,7 +199,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _mock_request():
|
def _mock_request() -> Mock:
|
||||||
"""Returns a mock which will stand in as a SynapseRequest"""
|
"""Returns a mock which will stand in as a SynapseRequest"""
|
||||||
mock = Mock(
|
mock = Mock(
|
||||||
spec=[
|
spec=[
|
||||||
|
|
|
@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||||
import synapse.api.errors
|
import synapse.api.errors
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.rest.client import directory, login, room
|
from synapse.rest.client import directory, login, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, RoomAlias, create_requester
|
from synapse.types import JsonDict, RoomAlias, create_requester
|
||||||
|
@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
self.test_user_tok = self.login("user", "pass")
|
self.test_user_tok = self.login("user", "pass")
|
||||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||||
|
|
||||||
def _create_alias(self, user) -> None:
|
def _create_alias(self, user: str) -> None:
|
||||||
# Create a new alias to this room.
|
# Create a new alias to this room.
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
|
@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
return room_alias
|
return room_alias
|
||||||
|
|
||||||
def _set_canonical_alias(self, content) -> None:
|
def _set_canonical_alias(self, content: JsonDict) -> None:
|
||||||
"""Configure the canonical alias state on the room."""
|
"""Configure the canonical alias state on the room."""
|
||||||
self.helper.send_state(
|
self.helper.send_state(
|
||||||
self.room_id,
|
self.room_id,
|
||||||
|
@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
tok=self.admin_user_tok,
|
tok=self.admin_user_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_canonical_alias(self):
|
def _get_canonical_alias(self) -> EventBase:
|
||||||
"""Get the canonical alias state of the room."""
|
"""Get the canonical alias state of the room."""
|
||||||
return self.get_success(
|
result = self.get_success(
|
||||||
self._storage_controllers.state.get_current_state_event(
|
self._storage_controllers.state.get_current_state_event(
|
||||||
self.room_id, EventTypes.CanonicalAlias, ""
|
self.room_id, EventTypes.CanonicalAlias, ""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert result is not None
|
||||||
|
return result
|
||||||
|
|
||||||
def test_remove_alias(self) -> None:
|
def test_remove_alias(self) -> None:
|
||||||
"""Removing an alias that is the canonical alias should remove it there too."""
|
"""Removing an alias that is the canonical alias should remove it there too."""
|
||||||
|
@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
data = self._get_canonical_alias()
|
data = self._get_canonical_alias()
|
||||||
self.assertEqual(data["content"]["alias"], self.test_alias)
|
self.assertEqual(data.content["alias"], self.test_alias)
|
||||||
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
|
self.assertEqual(data.content["alt_aliases"], [self.test_alias])
|
||||||
|
|
||||||
# Finally, delete the alias.
|
# Finally, delete the alias.
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
data = self._get_canonical_alias()
|
data = self._get_canonical_alias()
|
||||||
self.assertNotIn("alias", data["content"])
|
self.assertNotIn("alias", data.content)
|
||||||
self.assertNotIn("alt_aliases", data["content"])
|
self.assertNotIn("alt_aliases", data.content)
|
||||||
|
|
||||||
def test_remove_other_alias(self) -> None:
|
def test_remove_other_alias(self) -> None:
|
||||||
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
||||||
|
@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
data = self._get_canonical_alias()
|
data = self._get_canonical_alias()
|
||||||
self.assertEqual(data["content"]["alias"], self.test_alias)
|
self.assertEqual(data.content["alias"], self.test_alias)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
|
data.content["alt_aliases"], [self.test_alias, other_test_alias]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete the second alias.
|
# Delete the second alias.
|
||||||
|
@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
data = self._get_canonical_alias()
|
data = self._get_canonical_alias()
|
||||||
self.assertEqual(data["content"]["alias"], self.test_alias)
|
self.assertEqual(data.content["alias"], self.test_alias)
|
||||||
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
|
self.assertEqual(data.content["alt_aliases"], [self.test_alias])
|
||||||
|
|
||||||
|
|
||||||
class TestCreateAliasACL(unittest.HomeserverTestCase):
|
class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
|
|
|
@ -17,7 +17,11 @@
|
||||||
import copy
|
import copy
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
@ -39,14 +43,14 @@ room_keys = {
|
||||||
|
|
||||||
|
|
||||||
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
return self.setup_test_homeserver(replication_layer=mock.Mock())
|
return self.setup_test_homeserver(replication_layer=mock.Mock())
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = hs.get_e2e_room_keys_handler()
|
self.handler = hs.get_e2e_room_keys_handler()
|
||||||
self.local_user = "@boris:" + hs.hostname
|
self.local_user = "@boris:" + hs.hostname
|
||||||
|
|
||||||
def test_get_missing_current_version_info(self):
|
def test_get_missing_current_version_info(self) -> None:
|
||||||
"""Check that we get a 404 if we ask for info about the current version
|
"""Check that we get a 404 if we ask for info about the current version
|
||||||
if there is no version.
|
if there is no version.
|
||||||
"""
|
"""
|
||||||
|
@ -56,7 +60,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
|
||||||
def test_get_missing_version_info(self):
|
def test_get_missing_version_info(self) -> None:
|
||||||
"""Check that we get a 404 if we ask for info about a specific version
|
"""Check that we get a 404 if we ask for info about a specific version
|
||||||
if it doesn't exist.
|
if it doesn't exist.
|
||||||
"""
|
"""
|
||||||
|
@ -67,9 +71,9 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
|
||||||
def test_create_version(self):
|
def test_create_version(self) -> None:
|
||||||
"""Check that we can create and then retrieve versions."""
|
"""Check that we can create and then retrieve versions."""
|
||||||
res = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
self.local_user,
|
self.local_user,
|
||||||
{
|
{
|
||||||
|
@ -78,7 +82,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(res, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
# check we can retrieve it as the current version
|
# check we can retrieve it as the current version
|
||||||
res = self.get_success(self.handler.get_version_info(self.local_user))
|
res = self.get_success(self.handler.get_version_info(self.local_user))
|
||||||
|
@ -110,7 +114,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# upload a new one...
|
# upload a new one...
|
||||||
res = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
self.local_user,
|
self.local_user,
|
||||||
{
|
{
|
||||||
|
@ -119,7 +123,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(res, "2")
|
self.assertEqual(version, "2")
|
||||||
|
|
||||||
# check we can retrieve it as the current version
|
# check we can retrieve it as the current version
|
||||||
res = self.get_success(self.handler.get_version_info(self.local_user))
|
res = self.get_success(self.handler.get_version_info(self.local_user))
|
||||||
|
@ -134,7 +138,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_update_version(self):
|
def test_update_version(self) -> None:
|
||||||
"""Check that we can update versions."""
|
"""Check that we can update versions."""
|
||||||
version = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
|
@ -173,7 +177,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_update_missing_version(self):
|
def test_update_missing_version(self) -> None:
|
||||||
"""Check that we get a 404 on updating nonexistent versions"""
|
"""Check that we get a 404 on updating nonexistent versions"""
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
self.handler.update_version(
|
self.handler.update_version(
|
||||||
|
@ -190,7 +194,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
|
||||||
def test_update_omitted_version(self):
|
def test_update_omitted_version(self) -> None:
|
||||||
"""Check that the update succeeds if the version is missing from the body"""
|
"""Check that the update succeeds if the version is missing from the body"""
|
||||||
version = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
|
@ -227,7 +231,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_update_bad_version(self):
|
def test_update_bad_version(self) -> None:
|
||||||
"""Check that we get a 400 if the version in the body doesn't match"""
|
"""Check that we get a 400 if the version in the body doesn't match"""
|
||||||
version = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
|
@ -255,7 +259,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 400)
|
self.assertEqual(res, 400)
|
||||||
|
|
||||||
def test_delete_missing_version(self):
|
def test_delete_missing_version(self) -> None:
|
||||||
"""Check that we get a 404 on deleting nonexistent versions"""
|
"""Check that we get a 404 on deleting nonexistent versions"""
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
self.handler.delete_version(self.local_user, "1"), SynapseError
|
self.handler.delete_version(self.local_user, "1"), SynapseError
|
||||||
|
@ -263,15 +267,15 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
|
||||||
def test_delete_missing_current_version(self):
|
def test_delete_missing_current_version(self) -> None:
|
||||||
"""Check that we get a 404 on deleting nonexistent current version"""
|
"""Check that we get a 404 on deleting nonexistent current version"""
|
||||||
e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
|
e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
|
||||||
def test_delete_version(self):
|
def test_delete_version(self) -> None:
|
||||||
"""Check that we can create and then delete versions."""
|
"""Check that we can create and then delete versions."""
|
||||||
res = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
self.local_user,
|
self.local_user,
|
||||||
{
|
{
|
||||||
|
@ -280,7 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(res, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
# check we can delete it
|
# check we can delete it
|
||||||
self.get_success(self.handler.delete_version(self.local_user, "1"))
|
self.get_success(self.handler.delete_version(self.local_user, "1"))
|
||||||
|
@ -292,7 +296,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
|
||||||
def test_get_missing_backup(self):
|
def test_get_missing_backup(self) -> None:
|
||||||
"""Check that we get a 404 on querying missing backup"""
|
"""Check that we get a 404 on querying missing backup"""
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
|
self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
|
||||||
|
@ -300,7 +304,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
|
||||||
def test_get_missing_room_keys(self):
|
def test_get_missing_room_keys(self) -> None:
|
||||||
"""Check we get an empty response from an empty backup"""
|
"""Check we get an empty response from an empty backup"""
|
||||||
version = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
|
@ -319,7 +323,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
# TODO: test the locking semantics when uploading room_keys,
|
# TODO: test the locking semantics when uploading room_keys,
|
||||||
# although this is probably best done in sytest
|
# although this is probably best done in sytest
|
||||||
|
|
||||||
def test_upload_room_keys_no_versions(self):
|
def test_upload_room_keys_no_versions(self) -> None:
|
||||||
"""Check that we get a 404 on uploading keys when no versions are defined"""
|
"""Check that we get a 404 on uploading keys when no versions are defined"""
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
|
self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
|
||||||
|
@ -328,7 +332,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
|
||||||
def test_upload_room_keys_bogus_version(self):
|
def test_upload_room_keys_bogus_version(self) -> None:
|
||||||
"""Check that we get a 404 on uploading keys when an nonexistent version
|
"""Check that we get a 404 on uploading keys when an nonexistent version
|
||||||
is specified
|
is specified
|
||||||
"""
|
"""
|
||||||
|
@ -350,7 +354,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
|
||||||
def test_upload_room_keys_wrong_version(self):
|
def test_upload_room_keys_wrong_version(self) -> None:
|
||||||
"""Check that we get a 403 on uploading keys for an old version"""
|
"""Check that we get a 403 on uploading keys for an old version"""
|
||||||
version = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
|
@ -380,7 +384,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 403)
|
self.assertEqual(res, 403)
|
||||||
|
|
||||||
def test_upload_room_keys_insert(self):
|
def test_upload_room_keys_insert(self) -> None:
|
||||||
"""Check that we can insert and retrieve keys for a session"""
|
"""Check that we can insert and retrieve keys for a session"""
|
||||||
version = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
|
@ -416,7 +420,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, room_keys)
|
self.assertDictEqual(res, room_keys)
|
||||||
|
|
||||||
def test_upload_room_keys_merge(self):
|
def test_upload_room_keys_merge(self) -> None:
|
||||||
"""Check that we can upload a new room_key for an existing session and
|
"""Check that we can upload a new room_key for an existing session and
|
||||||
have it correctly merged"""
|
have it correctly merged"""
|
||||||
version = self.get_success(
|
version = self.get_success(
|
||||||
|
@ -449,9 +453,11 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(self.handler.get_room_keys(self.local_user, version))
|
res_keys = self.get_success(
|
||||||
|
self.handler.get_room_keys(self.local_user, version)
|
||||||
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
|
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
|
||||||
"SSBBTSBBIEZJU0gK",
|
"SSBBTSBBIEZJU0gK",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -465,9 +471,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(self.handler.get_room_keys(self.local_user, version))
|
res_keys = self.get_success(
|
||||||
|
self.handler.get_room_keys(self.local_user, version)
|
||||||
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
|
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
|
||||||
|
"new",
|
||||||
)
|
)
|
||||||
|
|
||||||
# the etag should NOT be equal now, since the key changed
|
# the etag should NOT be equal now, since the key changed
|
||||||
|
@ -483,9 +492,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(self.handler.get_room_keys(self.local_user, version))
|
res_keys = self.get_success(
|
||||||
|
self.handler.get_room_keys(self.local_user, version)
|
||||||
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
|
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
|
||||||
|
"new",
|
||||||
)
|
)
|
||||||
|
|
||||||
# the etag should be the same since the session did not change
|
# the etag should be the same since the session did not change
|
||||||
|
@ -494,7 +506,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# TODO: check edge cases as well as the common variations here
|
# TODO: check edge cases as well as the common variations here
|
||||||
|
|
||||||
def test_delete_room_keys(self):
|
def test_delete_room_keys(self) -> None:
|
||||||
"""Check that we can insert and delete keys for a session"""
|
"""Check that we can insert and delete keys for a session"""
|
||||||
version = self.get_success(
|
version = self.get_success(
|
||||||
self.handler.create_version(
|
self.handler.create_version(
|
||||||
|
|
|
@ -439,7 +439,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
user_id = self.register_user("kermit", "test")
|
user_id = self.register_user("kermit", "test")
|
||||||
tok = self.login("kermit", "test")
|
tok = self.login("kermit", "test")
|
||||||
|
|
||||||
def create_invite():
|
def create_invite() -> EventBase:
|
||||||
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
|
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
|
||||||
room_version = self.get_success(self.store.get_room_version(room_id))
|
room_version = self.get_success(self.store.get_room_version(room_id))
|
||||||
return event_from_pdu_json(
|
return event_from_pdu_json(
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, StoreError
|
from synapse.api.errors import AuthError, StoreError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.event_auth import (
|
from synapse.event_auth import (
|
||||||
|
@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse
|
||||||
from synapse.logging.context import LoggingContext
|
from synapse.logging.context import LoggingContext
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
|
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import event_injection, make_awaitable
|
from tests.test_utils import event_injection, make_awaitable
|
||||||
|
@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
# mock out the federation transport client
|
# mock out the federation transport client
|
||||||
self.mock_federation_transport_client = mock.Mock(
|
self.mock_federation_transport_client = mock.Mock(
|
||||||
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
|
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
|
||||||
|
@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
async def get_event(destination: str, event_id: str, timeout=None):
|
async def get_event(
|
||||||
|
destination: str, event_id: str, timeout: Optional[int] = None
|
||||||
|
) -> JsonDict:
|
||||||
self.assertEqual(destination, self.OTHER_SERVER_NAME)
|
self.assertEqual(destination, self.OTHER_SERVER_NAME)
|
||||||
self.assertEqual(event_id, prev_event.event_id)
|
self.assertEqual(event_id, prev_event.event_id)
|
||||||
return {"pdus": [prev_event.get_pdu_json()]}
|
return {"pdus": [prev_event.get_pdu_json()]}
|
||||||
|
|
|
@ -14,12 +14,16 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import create_requester
|
from synapse.types import create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -35,7 +39,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = self.hs.get_event_creation_handler()
|
self.handler = self.hs.get_event_creation_handler()
|
||||||
self._persist_event_storage_controller = (
|
self._persist_event_storage_controller = (
|
||||||
self.hs.get_storage_controllers().persistence
|
self.hs.get_storage_controllers().persistence
|
||||||
|
@ -94,7 +98,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_duplicated_txn_id(self):
|
def test_duplicated_txn_id(self) -> None:
|
||||||
"""Test that attempting to handle/persist an event with a transaction ID
|
"""Test that attempting to handle/persist an event with a transaction ID
|
||||||
that has already been persisted correctly returns the old event and does
|
that has already been persisted correctly returns the old event and does
|
||||||
*not* produce duplicate messages.
|
*not* produce duplicate messages.
|
||||||
|
@ -161,7 +165,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
# rather than the new one.
|
# rather than the new one.
|
||||||
self.assertEqual(ret_event1.event_id, ret_event4.event_id)
|
self.assertEqual(ret_event1.event_id, ret_event4.event_id)
|
||||||
|
|
||||||
def test_duplicated_txn_id_one_call(self):
|
def test_duplicated_txn_id_one_call(self) -> None:
|
||||||
"""Test that we correctly handle duplicates that we try and persist at
|
"""Test that we correctly handle duplicates that we try and persist at
|
||||||
the same time.
|
the same time.
|
||||||
"""
|
"""
|
||||||
|
@ -185,7 +189,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(events), 2)
|
self.assertEqual(len(events), 2)
|
||||||
self.assertEqual(events[0].event_id, events[1].event_id)
|
self.assertEqual(events[0].event_id, events[1].event_id)
|
||||||
|
|
||||||
def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self):
|
def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
"""When we set allow_no_prev_events=True, should be able to create a
|
"""When we set allow_no_prev_events=True, should be able to create a
|
||||||
event without any prev_events (only auth_events).
|
event without any prev_events (only auth_events).
|
||||||
"""
|
"""
|
||||||
|
@ -214,7 +220,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
|
def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
|
||||||
self,
|
self,
|
||||||
):
|
) -> None:
|
||||||
"""When we set allow_no_prev_events=False, shouldn't be able to create a
|
"""When we set allow_no_prev_events=False, shouldn't be able to create a
|
||||||
event without any prev_events even if it has auth_events. Expect an
|
event without any prev_events even if it has auth_events. Expect an
|
||||||
exception to be raised.
|
exception to be raised.
|
||||||
|
@ -245,7 +251,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
|
def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
|
||||||
self,
|
self,
|
||||||
):
|
) -> None:
|
||||||
"""When we set allow_no_prev_events=True, should be able to create a
|
"""When we set allow_no_prev_events=True, should be able to create a
|
||||||
event without any prev_events or auth_events. Expect an exception to be
|
event without any prev_events or auth_events. Expect an exception to be
|
||||||
raised.
|
raised.
|
||||||
|
@ -277,12 +283,12 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user_id = self.register_user("tester", "foobar")
|
self.user_id = self.register_user("tester", "foobar")
|
||||||
self.access_token = self.login("tester", "foobar")
|
self.access_token = self.login("tester", "foobar")
|
||||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||||
|
|
||||||
def test_allow_server_acl(self):
|
def test_allow_server_acl(self) -> None:
|
||||||
"""Test that sending an ACL that blocks everyone but ourselves works."""
|
"""Test that sending an ACL that blocks everyone but ourselves works."""
|
||||||
|
|
||||||
self.helper.send_state(
|
self.helper.send_state(
|
||||||
|
@ -293,7 +299,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
|
||||||
expect_code=200,
|
expect_code=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_deny_server_acl_block_outselves(self):
|
def test_deny_server_acl_block_outselves(self) -> None:
|
||||||
"""Test that sending an ACL that blocks ourselves does not work."""
|
"""Test that sending an ACL that blocks ourselves does not work."""
|
||||||
self.helper.send_state(
|
self.helper.send_state(
|
||||||
self.room_id,
|
self.room_id,
|
||||||
|
@ -303,7 +309,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
|
||||||
expect_code=400,
|
expect_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_deny_redact_server_acl(self):
|
def test_deny_redact_server_acl(self) -> None:
|
||||||
"""Test that attempting to redact an ACL is blocked."""
|
"""Test that attempting to redact an ACL is blocked."""
|
||||||
|
|
||||||
body = self.helper.send_state(
|
body = self.helper.send_state(
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Tuple
|
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
|
||||||
from unittest.mock import ANY, Mock, patch
|
from unittest.mock import ANY, Mock, patch
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID
|
from synapse.types import JsonDict, UserID
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.macaroons import get_value_from_macaroon
|
from synapse.util.macaroons import get_value_from_macaroon
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import authlib # noqa: F401
|
import authlib # noqa: F401
|
||||||
|
from authlib.oidc.core import UserInfo
|
||||||
|
from authlib.oidc.discovery import OpenIDProviderMetadata
|
||||||
|
|
||||||
|
from synapse.handlers.oidc import Token, UserAttributeDict
|
||||||
|
|
||||||
HAS_OIDC = True
|
HAS_OIDC = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = {
|
||||||
|
|
||||||
class TestMappingProvider:
|
class TestMappingProvider:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(config):
|
def parse_config(config: JsonDict) -> None:
|
||||||
return
|
return None
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config: None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_remote_user_id(self, userinfo):
|
def get_remote_user_id(self, userinfo: "UserInfo") -> str:
|
||||||
return userinfo["sub"]
|
return userinfo["sub"]
|
||||||
|
|
||||||
async def map_user_attributes(self, userinfo, token):
|
async def map_user_attributes(
|
||||||
return {"localpart": userinfo["username"], "display_name": None}
|
self, userinfo: "UserInfo", token: "Token"
|
||||||
|
) -> "UserAttributeDict":
|
||||||
|
# This is testing not providing the full map.
|
||||||
|
return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item]
|
||||||
|
|
||||||
# Do not include get_extra_attributes to test backwards compatibility paths.
|
# Do not include get_extra_attributes to test backwards compatibility paths.
|
||||||
|
|
||||||
|
|
||||||
class TestMappingProviderExtra(TestMappingProvider):
|
class TestMappingProviderExtra(TestMappingProvider):
|
||||||
async def get_extra_attributes(self, userinfo, token):
|
async def get_extra_attributes(
|
||||||
|
self, userinfo: "UserInfo", token: "Token"
|
||||||
|
) -> JsonDict:
|
||||||
return {"phone": userinfo["phone"]}
|
return {"phone": userinfo["phone"]}
|
||||||
|
|
||||||
|
|
||||||
class TestMappingProviderFailures(TestMappingProvider):
|
class TestMappingProviderFailures(TestMappingProvider):
|
||||||
async def map_user_attributes(self, userinfo, token, failures):
|
# Superclass is testing the legacy interface for map_user_attributes.
|
||||||
return {
|
async def map_user_attributes( # type: ignore[override]
|
||||||
|
self, userinfo: "UserInfo", token: "Token", failures: int
|
||||||
|
) -> "UserAttributeDict":
|
||||||
|
return { # type: ignore[typeddict-item]
|
||||||
"localpart": userinfo["username"] + (str(failures) if failures else ""),
|
"localpart": userinfo["username"] + (str(failures) if failures else ""),
|
||||||
"display_name": None,
|
"display_name": None,
|
||||||
}
|
}
|
||||||
|
@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.hs_patcher.stop()
|
self.hs_patcher.stop()
|
||||||
return super().tearDown()
|
return super().tearDown()
|
||||||
|
|
||||||
def reset_mocks(self):
|
def reset_mocks(self) -> None:
|
||||||
"""Reset all the Mocks."""
|
"""Reset all the Mocks."""
|
||||||
self.fake_server.reset_mocks()
|
self.fake_server.reset_mocks()
|
||||||
self.render_error.reset_mock()
|
self.render_error.reset_mock()
|
||||||
self.complete_sso_login.reset_mock()
|
self.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
def metadata_edit(self, values):
|
def metadata_edit(self, values: dict) -> ContextManager[Mock]:
|
||||||
"""Modify the result that will be returned by the well-known query"""
|
"""Modify the result that will be returned by the well-known query"""
|
||||||
|
|
||||||
metadata = self.fake_server.get_metadata()
|
metadata = self.fake_server.get_metadata()
|
||||||
|
@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
|
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
|
||||||
return _build_callback_request(code, state, session), grant
|
return _build_callback_request(code, state, session), grant
|
||||||
|
|
||||||
def assertRenderedError(self, error, error_description=None):
|
def assertRenderedError(
|
||||||
|
self, error: str, error_description: Optional[str] = None
|
||||||
|
) -> Tuple[Any, ...]:
|
||||||
self.render_error.assert_called_once()
|
self.render_error.assert_called_once()
|
||||||
args = self.render_error.call_args[0]
|
args = self.render_error.call_args[0]
|
||||||
self.assertEqual(args[1], error)
|
self.assertEqual(args[1], error)
|
||||||
|
@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"""Provider metadatas are extensively validated."""
|
"""Provider metadatas are extensively validated."""
|
||||||
h = self.provider
|
h = self.provider
|
||||||
|
|
||||||
def force_load_metadata():
|
def force_load_metadata() -> Awaitable[None]:
|
||||||
async def force_load():
|
async def force_load() -> "OpenIDProviderMetadata":
|
||||||
return await h.load_metadata(force=True)
|
return await h.load_metadata(force=True)
|
||||||
|
|
||||||
return get_awaitable_result(force_load())
|
return get_awaitable_result(force_load())
|
||||||
|
@ -1198,7 +1212,7 @@ def _build_callback_request(
|
||||||
state: str,
|
state: str,
|
||||||
session: str,
|
session: str,
|
||||||
ip_address: str = "10.0.0.1",
|
ip_address: str = "10.0.0.1",
|
||||||
):
|
) -> Mock:
|
||||||
"""Builds a fake SynapseRequest to mock the browser callback
|
"""Builds a fake SynapseRequest to mock the browser callback
|
||||||
|
|
||||||
Returns a Mock object which looks like the SynapseRequest we get from a browser
|
Returns a Mock object which looks like the SynapseRequest we get from a browser
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
"""Tests for the password_auth_provider interface"""
|
"""Tests for the password_auth_provider interface"""
|
||||||
|
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Type, Union
|
from typing import Any, Dict, List, Optional, Type, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes
|
from synapse.api.errors import Codes
|
||||||
|
from synapse.handlers.account import AccountHandler
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.rest.client import account, devices, login, logout, register
|
from synapse.rest.client import account, devices, login, logout, register
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider:
|
||||||
"""A legacy password_provider which only implements `check_password`."""
|
"""A legacy password_provider which only implements `check_password`."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(self):
|
def parse_config(config: JsonDict) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self, config, account_handler):
|
def __init__(self, config: None, account_handler: AccountHandler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def check_password(self, *args):
|
def check_password(self, *args: str) -> Mock:
|
||||||
return mock_password_provider.check_password(*args)
|
return mock_password_provider.check_password(*args)
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,16 +59,16 @@ class LegacyCustomAuthProvider:
|
||||||
"""A legacy password_provider which implements a custom login type."""
|
"""A legacy password_provider which implements a custom login type."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(self):
|
def parse_config(config: JsonDict) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self, config, account_handler):
|
def __init__(self, config: None, account_handler: AccountHandler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_supported_login_types(self):
|
def get_supported_login_types(self) -> Dict[str, List[str]]:
|
||||||
return {"test.login_type": ["test_field"]}
|
return {"test.login_type": ["test_field"]}
|
||||||
|
|
||||||
def check_auth(self, *args):
|
def check_auth(self, *args: str) -> Mock:
|
||||||
return mock_password_provider.check_auth(*args)
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,15 +76,15 @@ class CustomAuthProvider:
|
||||||
"""A module which registers password_auth_provider callbacks for a custom login type."""
|
"""A module which registers password_auth_provider callbacks for a custom login type."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(self):
|
def parse_config(config: JsonDict) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self, config, api: ModuleApi):
|
def __init__(self, config: None, api: ModuleApi):
|
||||||
api.register_password_auth_provider_callbacks(
|
api.register_password_auth_provider_callbacks(
|
||||||
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
|
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_auth(self, *args):
|
def check_auth(self, *args: Any) -> Mock:
|
||||||
return mock_password_provider.check_auth(*args)
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider:
|
||||||
as a custom type."""
|
as a custom type."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(self):
|
def parse_config(config: JsonDict) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self, config, account_handler):
|
def __init__(self, config: None, account_handler: AccountHandler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_supported_login_types(self):
|
def get_supported_login_types(self) -> Dict[str, List[str]]:
|
||||||
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
|
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
|
||||||
|
|
||||||
def check_auth(self, *args):
|
def check_auth(self, *args: str) -> Mock:
|
||||||
return mock_password_provider.check_auth(*args)
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,10 +111,10 @@ class PasswordCustomAuthProvider:
|
||||||
as well as a password login"""
|
as well as a password login"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(self):
|
def parse_config(config: JsonDict) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self, config, api: ModuleApi):
|
def __init__(self, config: None, api: ModuleApi):
|
||||||
api.register_password_auth_provider_callbacks(
|
api.register_password_auth_provider_callbacks(
|
||||||
auth_checkers={
|
auth_checkers={
|
||||||
("test.login_type", ("test_field",)): self.check_auth,
|
("test.login_type", ("test_field",)): self.check_auth,
|
||||||
|
@ -121,10 +122,10 @@ class PasswordCustomAuthProvider:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_auth(self, *args):
|
def check_auth(self, *args: Any) -> Mock:
|
||||||
return mock_password_provider.check_auth(*args)
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
def check_pass(self, *args):
|
def check_pass(self, *args: str) -> Mock:
|
||||||
return mock_password_provider.check_password(*args)
|
return mock_password_provider.check_password(*args)
|
||||||
|
|
||||||
|
|
||||||
|
@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
CALLBACK_USERNAME = "get_username_for_registration"
|
CALLBACK_USERNAME = "get_username_for_registration"
|
||||||
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
|
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
# we use a global mock device, so make sure we are starting with a clean slate
|
# we use a global mock device, so make sure we are starting with a clean slate
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||||
def test_password_only_auth_progiver_login_legacy(self):
|
def test_password_only_auth_progiver_login_legacy(self) -> None:
|
||||||
self.password_only_auth_provider_login_test_body()
|
self.password_only_auth_provider_login_test_body()
|
||||||
|
|
||||||
def password_only_auth_provider_login_test_body(self):
|
def password_only_auth_provider_login_test_body(self) -> None:
|
||||||
# login flows should only have m.login.password
|
# login flows should only have m.login.password
|
||||||
flows = self._get_login_flows()
|
flows = self._get_login_flows()
|
||||||
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
|
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||||
|
@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||||
def test_password_only_auth_provider_ui_auth_legacy(self):
|
def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
|
||||||
self.password_only_auth_provider_ui_auth_test_body()
|
self.password_only_auth_provider_ui_auth_test_body()
|
||||||
|
|
||||||
def password_only_auth_provider_ui_auth_test_body(self):
|
def password_only_auth_provider_ui_auth_test_body(self) -> None:
|
||||||
"""UI Auth should delegate correctly to the password provider"""
|
"""UI Auth should delegate correctly to the password provider"""
|
||||||
|
|
||||||
# create the user, otherwise access doesn't work
|
# create the user, otherwise access doesn't work
|
||||||
|
@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||||
|
|
||||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||||
def test_local_user_fallback_login_legacy(self):
|
def test_local_user_fallback_login_legacy(self) -> None:
|
||||||
self.local_user_fallback_login_test_body()
|
self.local_user_fallback_login_test_body()
|
||||||
|
|
||||||
def local_user_fallback_login_test_body(self):
|
def local_user_fallback_login_test_body(self) -> None:
|
||||||
"""rejected login should fall back to local db"""
|
"""rejected login should fall back to local db"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
||||||
|
|
||||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||||
def test_local_user_fallback_ui_auth_legacy(self):
|
def test_local_user_fallback_ui_auth_legacy(self) -> None:
|
||||||
self.local_user_fallback_ui_auth_test_body()
|
self.local_user_fallback_ui_auth_test_body()
|
||||||
|
|
||||||
def local_user_fallback_ui_auth_test_body(self):
|
def local_user_fallback_ui_auth_test_body(self) -> None:
|
||||||
"""rejected login should fall back to local db"""
|
"""rejected login should fall back to local db"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"localdb_enabled": False},
|
"password_config": {"localdb_enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_no_local_user_fallback_login_legacy(self):
|
def test_no_local_user_fallback_login_legacy(self) -> None:
|
||||||
self.no_local_user_fallback_login_test_body()
|
self.no_local_user_fallback_login_test_body()
|
||||||
|
|
||||||
def no_local_user_fallback_login_test_body(self):
|
def no_local_user_fallback_login_test_body(self) -> None:
|
||||||
"""localdb_enabled can block login with the local password"""
|
"""localdb_enabled can block login with the local password"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"localdb_enabled": False},
|
"password_config": {"localdb_enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_no_local_user_fallback_ui_auth_legacy(self):
|
def test_no_local_user_fallback_ui_auth_legacy(self) -> None:
|
||||||
self.no_local_user_fallback_ui_auth_test_body()
|
self.no_local_user_fallback_ui_auth_test_body()
|
||||||
|
|
||||||
def no_local_user_fallback_ui_auth_test_body(self):
|
def no_local_user_fallback_ui_auth_test_body(self) -> None:
|
||||||
"""localdb_enabled can block ui auth with the local password"""
|
"""localdb_enabled can block ui auth with the local password"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"enabled": False},
|
"password_config": {"enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_password_auth_disabled_legacy(self):
|
def test_password_auth_disabled_legacy(self) -> None:
|
||||||
self.password_auth_disabled_test_body()
|
self.password_auth_disabled_test_body()
|
||||||
|
|
||||||
def password_auth_disabled_test_body(self):
|
def password_auth_disabled_test_body(self) -> None:
|
||||||
"""password auth doesn't work if it's disabled across the board"""
|
"""password auth doesn't work if it's disabled across the board"""
|
||||||
# login flows should be empty
|
# login flows should be empty
|
||||||
flows = self._get_login_flows()
|
flows = self._get_login_flows()
|
||||||
|
@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mock_password_provider.check_password.assert_not_called()
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||||
def test_custom_auth_provider_login_legacy(self):
|
def test_custom_auth_provider_login_legacy(self) -> None:
|
||||||
self.custom_auth_provider_login_test_body()
|
self.custom_auth_provider_login_test_body()
|
||||||
|
|
||||||
@override_config(providers_config(CustomAuthProvider))
|
@override_config(providers_config(CustomAuthProvider))
|
||||||
def test_custom_auth_provider_login(self):
|
def test_custom_auth_provider_login(self) -> None:
|
||||||
self.custom_auth_provider_login_test_body()
|
self.custom_auth_provider_login_test_body()
|
||||||
|
|
||||||
def custom_auth_provider_login_test_body(self):
|
def custom_auth_provider_login_test_body(self) -> None:
|
||||||
# login flows should have the custom flow and m.login.password, since we
|
# login flows should have the custom flow and m.login.password, since we
|
||||||
# haven't disabled local password lookup.
|
# haven't disabled local password lookup.
|
||||||
# (password must come first, because reasons)
|
# (password must come first, because reasons)
|
||||||
|
@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||||
def test_custom_auth_provider_ui_auth_legacy(self):
|
def test_custom_auth_provider_ui_auth_legacy(self) -> None:
|
||||||
self.custom_auth_provider_ui_auth_test_body()
|
self.custom_auth_provider_ui_auth_test_body()
|
||||||
|
|
||||||
@override_config(providers_config(CustomAuthProvider))
|
@override_config(providers_config(CustomAuthProvider))
|
||||||
def test_custom_auth_provider_ui_auth(self):
|
def test_custom_auth_provider_ui_auth(self) -> None:
|
||||||
self.custom_auth_provider_ui_auth_test_body()
|
self.custom_auth_provider_ui_auth_test_body()
|
||||||
|
|
||||||
def custom_auth_provider_ui_auth_test_body(self):
|
def custom_auth_provider_ui_auth_test_body(self) -> None:
|
||||||
# register the user and log in twice, to get two devices
|
# register the user and log in twice, to get two devices
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
tok1 = self.login("localuser", "localpass")
|
tok1 = self.login("localuser", "localpass")
|
||||||
|
@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||||
def test_custom_auth_provider_callback_legacy(self):
|
def test_custom_auth_provider_callback_legacy(self) -> None:
|
||||||
self.custom_auth_provider_callback_test_body()
|
self.custom_auth_provider_callback_test_body()
|
||||||
|
|
||||||
@override_config(providers_config(CustomAuthProvider))
|
@override_config(providers_config(CustomAuthProvider))
|
||||||
def test_custom_auth_provider_callback(self):
|
def test_custom_auth_provider_callback(self) -> None:
|
||||||
self.custom_auth_provider_callback_test_body()
|
self.custom_auth_provider_callback_test_body()
|
||||||
|
|
||||||
def custom_auth_provider_callback_test_body(self):
|
def custom_auth_provider_callback_test_body(self) -> None:
|
||||||
callback = Mock(return_value=make_awaitable(None))
|
callback = Mock(return_value=make_awaitable(None))
|
||||||
|
|
||||||
mock_password_provider.check_auth.return_value = make_awaitable(
|
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||||
|
@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"enabled": False},
|
"password_config": {"enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_custom_auth_password_disabled_legacy(self):
|
def test_custom_auth_password_disabled_legacy(self) -> None:
|
||||||
self.custom_auth_password_disabled_test_body()
|
self.custom_auth_password_disabled_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
|
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
|
||||||
)
|
)
|
||||||
def test_custom_auth_password_disabled(self):
|
def test_custom_auth_password_disabled(self) -> None:
|
||||||
self.custom_auth_password_disabled_test_body()
|
self.custom_auth_password_disabled_test_body()
|
||||||
|
|
||||||
def custom_auth_password_disabled_test_body(self):
|
def custom_auth_password_disabled_test_body(self) -> None:
|
||||||
"""Test login with a custom auth provider where password login is disabled"""
|
"""Test login with a custom auth provider where password login is disabled"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"enabled": False, "localdb_enabled": False},
|
"password_config": {"enabled": False, "localdb_enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
|
def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None:
|
||||||
self.custom_auth_password_disabled_localdb_enabled_test_body()
|
self.custom_auth_password_disabled_localdb_enabled_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"enabled": False, "localdb_enabled": False},
|
"password_config": {"enabled": False, "localdb_enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_custom_auth_password_disabled_localdb_enabled(self):
|
def test_custom_auth_password_disabled_localdb_enabled(self) -> None:
|
||||||
self.custom_auth_password_disabled_localdb_enabled_test_body()
|
self.custom_auth_password_disabled_localdb_enabled_test_body()
|
||||||
|
|
||||||
def custom_auth_password_disabled_localdb_enabled_test_body(self):
|
def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
|
||||||
"""Check the localdb_enabled == enabled == False
|
"""Check the localdb_enabled == enabled == False
|
||||||
|
|
||||||
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
|
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
|
||||||
|
@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"enabled": False},
|
"password_config": {"enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_password_custom_auth_password_disabled_login_legacy(self):
|
def test_password_custom_auth_password_disabled_login_legacy(self) -> None:
|
||||||
self.password_custom_auth_password_disabled_login_test_body()
|
self.password_custom_auth_password_disabled_login_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"enabled": False},
|
"password_config": {"enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_password_custom_auth_password_disabled_login(self):
|
def test_password_custom_auth_password_disabled_login(self) -> None:
|
||||||
self.password_custom_auth_password_disabled_login_test_body()
|
self.password_custom_auth_password_disabled_login_test_body()
|
||||||
|
|
||||||
def password_custom_auth_password_disabled_login_test_body(self):
|
def password_custom_auth_password_disabled_login_test_body(self) -> None:
|
||||||
"""log in with a custom auth provider which implements password, but password
|
"""log in with a custom auth provider which implements password, but password
|
||||||
login is disabled"""
|
login is disabled"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"enabled": False},
|
"password_config": {"enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
|
def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None:
|
||||||
self.password_custom_auth_password_disabled_ui_auth_test_body()
|
self.password_custom_auth_password_disabled_ui_auth_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"enabled": False},
|
"password_config": {"enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_password_custom_auth_password_disabled_ui_auth(self):
|
def test_password_custom_auth_password_disabled_ui_auth(self) -> None:
|
||||||
self.password_custom_auth_password_disabled_ui_auth_test_body()
|
self.password_custom_auth_password_disabled_ui_auth_test_body()
|
||||||
|
|
||||||
def password_custom_auth_password_disabled_ui_auth_test_body(self):
|
def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None:
|
||||||
"""UI Auth with a custom auth provider which implements password, but password
|
"""UI Auth with a custom auth provider which implements password, but password
|
||||||
login is disabled"""
|
login is disabled"""
|
||||||
# register the user and log in twice via the test login type to get two devices,
|
# register the user and log in twice via the test login type to get two devices,
|
||||||
|
@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"localdb_enabled": False},
|
"password_config": {"localdb_enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_custom_auth_no_local_user_fallback_legacy(self):
|
def test_custom_auth_no_local_user_fallback_legacy(self) -> None:
|
||||||
self.custom_auth_no_local_user_fallback_test_body()
|
self.custom_auth_no_local_user_fallback_test_body()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"password_config": {"localdb_enabled": False},
|
"password_config": {"localdb_enabled": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_custom_auth_no_local_user_fallback(self):
|
def test_custom_auth_no_local_user_fallback(self) -> None:
|
||||||
self.custom_auth_no_local_user_fallback_test_body()
|
self.custom_auth_no_local_user_fallback_test_body()
|
||||||
|
|
||||||
def custom_auth_no_local_user_fallback_test_body(self):
|
def custom_auth_no_local_user_fallback_test_body(self) -> None:
|
||||||
"""Test login with a custom auth provider where the local db is disabled"""
|
"""Test login with a custom auth provider where the local db is disabled"""
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
channel = self._send_password_login("localuser", "localpass")
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||||
|
|
||||||
def test_on_logged_out(self):
|
def test_on_logged_out(self) -> None:
|
||||||
"""Tests that the on_logged_out callback is called when the user logs out."""
|
"""Tests that the on_logged_out callback is called when the user logs out."""
|
||||||
self.register_user("rin", "password")
|
self.register_user("rin", "password")
|
||||||
tok = self.login("rin", "password")
|
tok = self.login("rin", "password")
|
||||||
|
|
||||||
self.called = False
|
self.called = False
|
||||||
|
|
||||||
async def on_logged_out(user_id, device_id, access_token):
|
async def on_logged_out(
|
||||||
|
user_id: str, device_id: Optional[str], access_token: str
|
||||||
|
) -> None:
|
||||||
self.called = True
|
self.called = True
|
||||||
|
|
||||||
on_logged_out = Mock(side_effect=on_logged_out)
|
on_logged_out = Mock(side_effect=on_logged_out)
|
||||||
|
@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
on_logged_out.assert_called_once()
|
on_logged_out.assert_called_once()
|
||||||
self.assertTrue(self.called)
|
self.assertTrue(self.called)
|
||||||
|
|
||||||
def test_username(self):
|
def test_username(self) -> None:
|
||||||
"""Tests that the get_username_for_registration callback can define the username
|
"""Tests that the get_username_for_registration callback can define the username
|
||||||
of a user when registering.
|
of a user when registering.
|
||||||
"""
|
"""
|
||||||
|
@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mxid = channel.json_body["user_id"]
|
mxid = channel.json_body["user_id"]
|
||||||
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
|
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
|
||||||
|
|
||||||
def test_username_uia(self):
|
def test_username_uia(self) -> None:
|
||||||
"""Tests that the get_username_for_registration callback is only called at the
|
"""Tests that the get_username_for_registration callback is only called at the
|
||||||
end of the UIA flow.
|
end of the UIA flow.
|
||||||
"""
|
"""
|
||||||
|
@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Set some email configuration so the test doesn't fail because of its absence.
|
# Set some email configuration so the test doesn't fail because of its absence.
|
||||||
@override_config({"email": {"notif_from": "noreply@test"}})
|
@override_config({"email": {"notif_from": "noreply@test"}})
|
||||||
def test_3pid_allowed(self):
|
def test_3pid_allowed(self) -> None:
|
||||||
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
|
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
|
||||||
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
|
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
|
||||||
the 3PID. Also checks that the module is passed a boolean indicating whether the
|
the 3PID. Also checks that the module is passed a boolean indicating whether the
|
||||||
|
@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self._test_3pid_allowed("rin", False)
|
self._test_3pid_allowed("rin", False)
|
||||||
self._test_3pid_allowed("kitay", True)
|
self._test_3pid_allowed("kitay", True)
|
||||||
|
|
||||||
def test_displayname(self):
|
def test_displayname(self) -> None:
|
||||||
"""Tests that the get_displayname_for_registration callback can define the
|
"""Tests that the get_displayname_for_registration callback can define the
|
||||||
display name of a user when registering.
|
display name of a user when registering.
|
||||||
"""
|
"""
|
||||||
|
@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(display_name, username + "-foo")
|
self.assertEqual(display_name, username + "-foo")
|
||||||
|
|
||||||
def test_displayname_uia(self):
|
def test_displayname_uia(self) -> None:
|
||||||
"""Tests that the get_displayname_for_registration callback is only called at the
|
"""Tests that the get_displayname_for_registration callback is only called at the
|
||||||
end of the UIA flow.
|
end of the UIA flow.
|
||||||
"""
|
"""
|
||||||
|
@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
# Check that the callback has been called.
|
# Check that the callback has been called.
|
||||||
m.assert_called_once()
|
m.assert_called_once()
|
||||||
|
|
||||||
def _test_3pid_allowed(self, username: str, registration: bool):
|
def _test_3pid_allowed(self, username: str, registration: bool) -> None:
|
||||||
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
|
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
|
||||||
either /register or /account URLs depending on the arguments.
|
either /register or /account URLs depending on the arguments.
|
||||||
|
|
||||||
|
@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
client is trying to register.
|
client is trying to register.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def callback(uia_results, params):
|
async def callback(uia_results: JsonDict, params: JsonDict) -> str:
|
||||||
self.assertIn(LoginType.DUMMY, uia_results)
|
self.assertIn(LoginType.DUMMY, uia_results)
|
||||||
username = params["username"]
|
username = params["username"]
|
||||||
return username + "-foo"
|
return username + "-foo"
|
||||||
|
@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
def _send_password_login(self, user: str, password: str) -> FakeChannel:
|
def _send_password_login(self, user: str, password: str) -> FakeChannel:
|
||||||
return self._send_login(type="m.login.password", user=user, password=password)
|
return self._send_login(type="m.login.password", user=user, password=password)
|
||||||
|
|
||||||
def _send_login(self, type, user, **params) -> FakeChannel:
|
def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
|
||||||
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
|
params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
|
||||||
|
params.update(extra_params)
|
||||||
channel = self.make_request("POST", "/_matrix/client/r0/login", params)
|
channel = self.make_request("POST", "/_matrix/client/r0/login", params)
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
def _start_delete_device_session(self, access_token, device_id) -> str:
|
def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
|
||||||
"""Make an initial delete device request, and return the UI Auth session ID"""
|
"""Make an initial delete device request, and return the UI Auth session ID"""
|
||||||
channel = self._delete_device(access_token, device_id)
|
channel = self._delete_device(access_token, device_id)
|
||||||
self.assertEqual(channel.code, 401)
|
self.assertEqual(channel.code, 401)
|
||||||
|
|
|
@ -12,12 +12,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, cast
|
||||||
from unittest.mock import Mock, call
|
from unittest.mock import Mock, call
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from signedjson.key import generate_signing_key
|
from signedjson.key import generate_signing_key
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, PresenceState
|
from synapse.api.constants import EventTypes, Membership, PresenceState
|
||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
|
@ -35,7 +37,9 @@ from synapse.handlers.presence import (
|
||||||
)
|
)
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import room
|
from synapse.rest.client import room
|
||||||
from synapse.types import UserID, get_domain_from_id
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
|
@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [admin.register_servlets]
|
servlets = [admin.register_servlets]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
self.store = homeserver.get_datastores().main
|
self.store = homeserver.get_datastores().main
|
||||||
|
|
||||||
def test_offline_to_online(self):
|
def test_offline_to_online(self) -> None:
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
any_order=True,
|
any_order=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_online_to_online(self):
|
def test_online_to_online(self) -> None:
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
any_order=True,
|
any_order=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_online_to_online_last_active_noop(self):
|
def test_online_to_online_last_active_noop(self) -> None:
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
any_order=True,
|
any_order=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_online_to_online_last_active(self):
|
def test_online_to_online_last_active(self) -> None:
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
any_order=True,
|
any_order=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_remote_ping_timer(self):
|
def test_remote_ping_timer(self) -> None:
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
any_order=True,
|
any_order=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_online_to_offline(self):
|
def test_online_to_offline(self) -> None:
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(wheel_timer.insert.call_count, 0)
|
self.assertEqual(wheel_timer.insert.call_count, 0)
|
||||||
|
|
||||||
def test_online_to_idle(self):
|
def test_online_to_idle(self) -> None:
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
any_order=True,
|
any_order=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_persisting_presence_updates(self):
|
def test_persisting_presence_updates(self) -> None:
|
||||||
"""Tests that the latest presence state for each user is persisted correctly"""
|
"""Tests that the latest presence state for each user is persisted correctly"""
|
||||||
# Create some test users and presence states for them
|
# Create some test users and presence states for them
|
||||||
presence_states = []
|
presence_states = []
|
||||||
|
@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self.store.update_presence(presence_states))
|
self.get_success(self.store.update_presence(presence_states))
|
||||||
|
|
||||||
# Check that each update is present in the database
|
# Check that each update is present in the database
|
||||||
db_presence_states = self.get_success(
|
db_presence_states_raw = self.get_success(
|
||||||
self.store.get_all_presence_updates(
|
self.store.get_all_presence_updates(
|
||||||
instance_name="master",
|
instance_name="master",
|
||||||
last_id=0,
|
last_id=0,
|
||||||
|
@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract presence update user ID and state information into lists of tuples
|
# Extract presence update user ID and state information into lists of tuples
|
||||||
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
|
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]]
|
||||||
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
|
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
|
||||||
|
|
||||||
# Compare what we put into the storage with what we got out.
|
# Compare what we put into the storage with what we got out.
|
||||||
|
@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
class PresenceTimeoutTestCase(unittest.TestCase):
|
class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
"""Tests different timers and that the timer does not change `status_msg` of user."""
|
"""Tests different timers and that the timer does not change `status_msg` of user."""
|
||||||
|
|
||||||
def test_idle_timer(self):
|
def test_idle_timer(self) -> None:
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
status_msg = "I'm here!"
|
status_msg = "I'm here!"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
|
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
def test_busy_no_idle(self):
|
def test_busy_no_idle(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that a user setting their presence to busy but idling doesn't turn their
|
Tests that a user setting their presence to busy but idling doesn't turn their
|
||||||
presence state into unavailable.
|
presence state into unavailable.
|
||||||
|
@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
self.assertEqual(new_state.state, PresenceState.BUSY)
|
self.assertEqual(new_state.state, PresenceState.BUSY)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
def test_sync_timeout(self):
|
def test_sync_timeout(self) -> None:
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
status_msg = "I'm here!"
|
status_msg = "I'm here!"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
def test_sync_online(self):
|
def test_sync_online(self) -> None:
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
status_msg = "I'm here!"
|
status_msg = "I'm here!"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
self.assertEqual(new_state.state, PresenceState.ONLINE)
|
self.assertEqual(new_state.state, PresenceState.ONLINE)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
def test_federation_ping(self):
|
def test_federation_ping(self) -> None:
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
status_msg = "I'm here!"
|
status_msg = "I'm here!"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
self.assertEqual(state, new_state)
|
self.assertEqual(state, new_state)
|
||||||
|
|
||||||
def test_no_timeout(self):
|
def test_no_timeout(self) -> None:
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
|
||||||
|
@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertIsNone(new_state)
|
self.assertIsNone(new_state)
|
||||||
|
|
||||||
def test_federation_timeout(self):
|
def test_federation_timeout(self) -> None:
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
status_msg = "I'm here!"
|
status_msg = "I'm here!"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
def test_last_active(self):
|
def test_last_active(self) -> None:
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
status_msg = "I'm here!"
|
status_msg = "I'm here!"
|
||||||
now = 5000000
|
now = 5000000
|
||||||
|
@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
def test_external_process_timeout(self):
|
def test_external_process_timeout(self) -> None:
|
||||||
"""Test that if an external process doesn't update the records for a while
|
"""Test that if an external process doesn't update the records for a while
|
||||||
we time out their syncing users presence.
|
we time out their syncing users presence.
|
||||||
"""
|
"""
|
||||||
process_id = 1
|
process_id = "1"
|
||||||
user_id = "@test:server"
|
user_id = "@test:server"
|
||||||
|
|
||||||
# Notify handler that a user is now syncing.
|
# Notify handler that a user is now syncing.
|
||||||
|
@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(state.state, PresenceState.OFFLINE)
|
self.assertEqual(state.state, PresenceState.OFFLINE)
|
||||||
|
|
||||||
def test_user_goes_offline_by_timeout_status_msg_remain(self):
|
def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
|
||||||
"""Test that if a user doesn't update the records for a while
|
"""Test that if a user doesn't update the records for a while
|
||||||
users presence goes `OFFLINE` because of timeout and `status_msg` remains.
|
users presence goes `OFFLINE` because of timeout and `status_msg` remains.
|
||||||
"""
|
"""
|
||||||
|
@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertEqual(state.state, PresenceState.OFFLINE)
|
self.assertEqual(state.state, PresenceState.OFFLINE)
|
||||||
self.assertEqual(state.status_msg, status_msg)
|
self.assertEqual(state.status_msg, status_msg)
|
||||||
|
|
||||||
def test_user_goes_offline_manually_with_no_status_msg(self):
|
def test_user_goes_offline_manually_with_no_status_msg(self) -> None:
|
||||||
"""Test that if a user change presence manually to `OFFLINE`
|
"""Test that if a user change presence manually to `OFFLINE`
|
||||||
and no status is set, that `status_msg` is `None`.
|
and no status is set, that `status_msg` is `None`.
|
||||||
"""
|
"""
|
||||||
|
@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertEqual(state.state, PresenceState.OFFLINE)
|
self.assertEqual(state.state, PresenceState.OFFLINE)
|
||||||
self.assertEqual(state.status_msg, None)
|
self.assertEqual(state.status_msg, None)
|
||||||
|
|
||||||
def test_user_goes_offline_manually_with_status_msg(self):
|
def test_user_goes_offline_manually_with_status_msg(self) -> None:
|
||||||
"""Test that if a user change presence manually to `OFFLINE`
|
"""Test that if a user change presence manually to `OFFLINE`
|
||||||
and a status is set, that `status_msg` appears.
|
and a status is set, that `status_msg` appears.
|
||||||
"""
|
"""
|
||||||
|
@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
user_id, PresenceState.OFFLINE, "And now here."
|
user_id, PresenceState.OFFLINE, "And now here."
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_user_reset_online_with_no_status(self):
|
def test_user_reset_online_with_no_status(self) -> None:
|
||||||
"""Test that if a user set again the presence manually
|
"""Test that if a user set again the presence manually
|
||||||
and no status is set, that `status_msg` is `None`.
|
and no status is set, that `status_msg` is `None`.
|
||||||
"""
|
"""
|
||||||
|
@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertEqual(state.state, PresenceState.ONLINE)
|
self.assertEqual(state.state, PresenceState.ONLINE)
|
||||||
self.assertEqual(state.status_msg, None)
|
self.assertEqual(state.status_msg, None)
|
||||||
|
|
||||||
def test_set_presence_with_status_msg_none(self):
|
def test_set_presence_with_status_msg_none(self) -> None:
|
||||||
"""Test that if a user set again the presence manually
|
"""Test that if a user set again the presence manually
|
||||||
and status is `None`, that `status_msg` is `None`.
|
and status is `None`, that `status_msg` is `None`.
|
||||||
"""
|
"""
|
||||||
|
@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
# Mark user as online and `status_msg = None`
|
# Mark user as online and `status_msg = None`
|
||||||
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
|
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
|
||||||
|
|
||||||
def test_set_presence_from_syncing_not_set(self):
|
def test_set_presence_from_syncing_not_set(self) -> None:
|
||||||
"""Test that presence is not set by syncing if affect_presence is false"""
|
"""Test that presence is not set by syncing if affect_presence is false"""
|
||||||
user_id = "@test:server"
|
user_id = "@test:server"
|
||||||
status_msg = "I'm here!"
|
status_msg = "I'm here!"
|
||||||
|
@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
# and status message should still be the same
|
# and status message should still be the same
|
||||||
self.assertEqual(state.status_msg, status_msg)
|
self.assertEqual(state.status_msg, status_msg)
|
||||||
|
|
||||||
def test_set_presence_from_syncing_is_set(self):
|
def test_set_presence_from_syncing_is_set(self) -> None:
|
||||||
"""Test that presence is set by syncing if affect_presence is true"""
|
"""Test that presence is set by syncing if affect_presence is true"""
|
||||||
user_id = "@test:server"
|
user_id = "@test:server"
|
||||||
status_msg = "I'm here!"
|
status_msg = "I'm here!"
|
||||||
|
@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
# we should now be online
|
# we should now be online
|
||||||
self.assertEqual(state.state, PresenceState.ONLINE)
|
self.assertEqual(state.state, PresenceState.ONLINE)
|
||||||
|
|
||||||
def test_set_presence_from_syncing_keeps_status(self):
|
def test_set_presence_from_syncing_keeps_status(self) -> None:
|
||||||
"""Test that presence set by syncing retains status message"""
|
"""Test that presence set by syncing retains status message"""
|
||||||
user_id = "@test:server"
|
user_id = "@test:server"
|
||||||
status_msg = "I'm here!"
|
status_msg = "I'm here!"
|
||||||
|
@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
|
def test_set_presence_from_syncing_keeps_busy(
|
||||||
|
self, test_with_workers: bool
|
||||||
|
) -> None:
|
||||||
"""Test that presence set by syncing doesn't affect busy status
|
"""Test that presence set by syncing doesn't affect busy status
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
|
||||||
def _set_presencestate_with_status_msg(
|
def _set_presencestate_with_status_msg(
|
||||||
self, user_id: str, state: str, status_msg: Optional[str]
|
self, user_id: str, state: str, status_msg: Optional[str]
|
||||||
):
|
) -> None:
|
||||||
"""Set a PresenceState and status_msg and check the result.
|
"""Set a PresenceState and status_msg and check the result.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
|
||||||
|
|
||||||
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
|
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.instance_name = hs.get_instance_name()
|
self.instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
self.queue = self.presence_handler.get_federation_queue()
|
self.queue = self.presence_handler.get_federation_queue()
|
||||||
|
|
||||||
def test_send_and_get(self):
|
def test_send_and_get(self) -> None:
|
||||||
state1 = UserPresenceState.default("@user1:test")
|
state1 = UserPresenceState.default("@user1:test")
|
||||||
state2 = UserPresenceState.default("@user2:test")
|
state2 = UserPresenceState.default("@user2:test")
|
||||||
state3 = UserPresenceState.default("@user3:test")
|
state3 = UserPresenceState.default("@user3:test")
|
||||||
|
@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(limited)
|
self.assertFalse(limited)
|
||||||
self.assertCountEqual(rows, [])
|
self.assertCountEqual(rows, [])
|
||||||
|
|
||||||
def test_send_and_get_split(self):
|
def test_send_and_get_split(self) -> None:
|
||||||
state1 = UserPresenceState.default("@user1:test")
|
state1 = UserPresenceState.default("@user1:test")
|
||||||
state2 = UserPresenceState.default("@user2:test")
|
state2 = UserPresenceState.default("@user2:test")
|
||||||
state3 = UserPresenceState.default("@user3:test")
|
state3 = UserPresenceState.default("@user3:test")
|
||||||
|
@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertCountEqual(rows, expected_rows)
|
self.assertCountEqual(rows, expected_rows)
|
||||||
|
|
||||||
def test_clear_queue_all(self):
|
def test_clear_queue_all(self) -> None:
|
||||||
state1 = UserPresenceState.default("@user1:test")
|
state1 = UserPresenceState.default("@user1:test")
|
||||||
state2 = UserPresenceState.default("@user2:test")
|
state2 = UserPresenceState.default("@user2:test")
|
||||||
state3 = UserPresenceState.default("@user3:test")
|
state3 = UserPresenceState.default("@user3:test")
|
||||||
|
@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertCountEqual(rows, expected_rows)
|
self.assertCountEqual(rows, expected_rows)
|
||||||
|
|
||||||
def test_partially_clear_queue(self):
|
def test_partially_clear_queue(self) -> None:
|
||||||
state1 = UserPresenceState.default("@user1:test")
|
state1 = UserPresenceState.default("@user1:test")
|
||||||
state2 = UserPresenceState.default("@user2:test")
|
state2 = UserPresenceState.default("@user2:test")
|
||||||
state3 = UserPresenceState.default("@user3:test")
|
state3 = UserPresenceState.default("@user3:test")
|
||||||
|
@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [room.register_servlets]
|
servlets = [room.register_servlets]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
"server",
|
"server",
|
||||||
federation_http_client=None,
|
federation_http_client=None,
|
||||||
|
@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> JsonDict:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
# Enable federation sending on the main process.
|
# Enable federation sending on the main process.
|
||||||
config["federation_sender_instances"] = None
|
config["federation_sender_instances"] = None
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.federation_sender = hs.get_federation_sender()
|
self.federation_sender = cast(Mock, hs.get_federation_sender())
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
self.federation_event_handler = hs.get_federation_event_handler()
|
self.federation_event_handler = hs.get_federation_event_handler()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
# random key to use.
|
# random key to use.
|
||||||
self.random_signing_key = generate_signing_key("ver")
|
self.random_signing_key = generate_signing_key("ver")
|
||||||
|
|
||||||
def test_remote_joins(self):
|
def test_remote_joins(self) -> None:
|
||||||
# We advance time to something that isn't 0, as we use 0 as a special
|
# We advance time to something that isn't 0, as we use 0 as a special
|
||||||
# value.
|
# value.
|
||||||
self.reactor.advance(1000000000000)
|
self.reactor.advance(1000000000000)
|
||||||
|
@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
destinations={"server3"}, states=[expected_state]
|
destinations={"server3"}, states=[expected_state]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_remote_gets_presence_when_local_user_joins(self):
|
def test_remote_gets_presence_when_local_user_joins(self) -> None:
|
||||||
# We advance time to something that isn't 0, as we use 0 as a special
|
# We advance time to something that isn't 0, as we use 0 as a special
|
||||||
# value.
|
# value.
|
||||||
self.reactor.advance(1000000000000)
|
self.reactor.advance(1000000000000)
|
||||||
|
@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
destinations={"server2", "server3"}, states=[expected_state]
|
destinations={"server2", "server3"}, states=[expected_state]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_new_user(self, room_id, user_id):
|
def _add_new_user(self, room_id: str, user_id: str) -> None:
|
||||||
"""Add new user to the room by creating an event and poking the federation API."""
|
"""Add new user to the room by creating an event and poking the federation API."""
|
||||||
|
|
||||||
hostname = get_domain_from_id(user_id)
|
hostname = get_domain_from_id(user_id)
|
||||||
|
|
|
@ -332,7 +332,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
|
{"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
|
||||||
)
|
)
|
||||||
def test_avatar_constraint_on_local_server_with_port(self):
|
def test_avatar_constraint_on_local_server_with_port(self) -> None:
|
||||||
"""Test that avatar metadata is correctly fetched when the media is on a local
|
"""Test that avatar metadata is correctly fetched when the media is on a local
|
||||||
server and the server has an explicit port.
|
server and the server has an explicit port.
|
||||||
|
|
||||||
|
@ -376,7 +376,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
|
self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
|
||||||
)
|
)
|
||||||
|
|
||||||
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
|
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
|
||||||
"""Stores metadata about files in the database.
|
"""Stores metadata about files in the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -15,14 +15,18 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EduTypes, ReceiptTypes
|
from synapse.api.constants import EduTypes, ReceiptTypes
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class ReceiptsTestCase(unittest.HomeserverTestCase):
|
class ReceiptsTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.event_source = hs.get_event_sources().sources.receipt
|
self.event_source = hs.get_event_sources().sources.receipt
|
||||||
|
|
||||||
def test_filters_out_private_receipt(self) -> None:
|
def test_filters_out_private_receipt(self) -> None:
|
||||||
|
|
|
@ -12,8 +12,11 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Collection, List, Optional, Tuple
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.auth import Auth
|
from synapse.api.auth import Auth
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
|
@ -22,8 +25,18 @@ from synapse.api.errors import (
|
||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
from synapse.module_api import ModuleApi
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.spam_checker_api import RegistrationBehaviour
|
from synapse.spam_checker_api import RegistrationBehaviour
|
||||||
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
from synapse.types import (
|
||||||
|
JsonDict,
|
||||||
|
Requester,
|
||||||
|
RoomAlias,
|
||||||
|
RoomID,
|
||||||
|
UserID,
|
||||||
|
create_requester,
|
||||||
|
)
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
@ -33,94 +46,98 @@ from .. import unittest
|
||||||
|
|
||||||
|
|
||||||
class TestSpamChecker:
|
class TestSpamChecker:
|
||||||
def __init__(self, config, api):
|
def __init__(self, config: None, api: ModuleApi):
|
||||||
api.register_spam_checker_callbacks(
|
api.register_spam_checker_callbacks(
|
||||||
check_registration_for_spam=self.check_registration_for_spam,
|
check_registration_for_spam=self.check_registration_for_spam,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(config):
|
def parse_config(config: JsonDict) -> None:
|
||||||
return config
|
return None
|
||||||
|
|
||||||
async def check_registration_for_spam(
|
async def check_registration_for_spam(
|
||||||
self,
|
self,
|
||||||
email_threepid,
|
email_threepid: Optional[dict],
|
||||||
username,
|
username: Optional[str],
|
||||||
request_info,
|
request_info: Collection[Tuple[str, str]],
|
||||||
auth_provider_id,
|
auth_provider_id: Optional[str],
|
||||||
):
|
) -> RegistrationBehaviour:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DenyAll(TestSpamChecker):
|
class DenyAll(TestSpamChecker):
|
||||||
async def check_registration_for_spam(
|
async def check_registration_for_spam(
|
||||||
self,
|
self,
|
||||||
email_threepid,
|
email_threepid: Optional[dict],
|
||||||
username,
|
username: Optional[str],
|
||||||
request_info,
|
request_info: Collection[Tuple[str, str]],
|
||||||
auth_provider_id,
|
auth_provider_id: Optional[str],
|
||||||
):
|
) -> RegistrationBehaviour:
|
||||||
return RegistrationBehaviour.DENY
|
return RegistrationBehaviour.DENY
|
||||||
|
|
||||||
|
|
||||||
class BanAll(TestSpamChecker):
|
class BanAll(TestSpamChecker):
|
||||||
async def check_registration_for_spam(
|
async def check_registration_for_spam(
|
||||||
self,
|
self,
|
||||||
email_threepid,
|
email_threepid: Optional[dict],
|
||||||
username,
|
username: Optional[str],
|
||||||
request_info,
|
request_info: Collection[Tuple[str, str]],
|
||||||
auth_provider_id,
|
auth_provider_id: Optional[str],
|
||||||
):
|
) -> RegistrationBehaviour:
|
||||||
return RegistrationBehaviour.SHADOW_BAN
|
return RegistrationBehaviour.SHADOW_BAN
|
||||||
|
|
||||||
|
|
||||||
class BanBadIdPUser(TestSpamChecker):
|
class BanBadIdPUser(TestSpamChecker):
|
||||||
async def check_registration_for_spam(
|
async def check_registration_for_spam(
|
||||||
self, email_threepid, username, request_info, auth_provider_id=None
|
self,
|
||||||
):
|
email_threepid: Optional[dict],
|
||||||
|
username: Optional[str],
|
||||||
|
request_info: Collection[Tuple[str, str]],
|
||||||
|
auth_provider_id: Optional[str] = None,
|
||||||
|
) -> RegistrationBehaviour:
|
||||||
# Reject any user coming from CAS and whose username contains profanity
|
# Reject any user coming from CAS and whose username contains profanity
|
||||||
if auth_provider_id == "cas" and "flimflob" in username:
|
if auth_provider_id == "cas" and username and "flimflob" in username:
|
||||||
return RegistrationBehaviour.DENY
|
return RegistrationBehaviour.DENY
|
||||||
return RegistrationBehaviour.ALLOW
|
return RegistrationBehaviour.ALLOW
|
||||||
|
|
||||||
|
|
||||||
class TestLegacyRegistrationSpamChecker:
|
class TestLegacyRegistrationSpamChecker:
|
||||||
def __init__(self, config, api):
|
def __init__(self, config: None, api: ModuleApi):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def check_registration_for_spam(
|
async def check_registration_for_spam(
|
||||||
self,
|
self,
|
||||||
email_threepid,
|
email_threepid: Optional[dict],
|
||||||
username,
|
username: Optional[str],
|
||||||
request_info,
|
request_info: Collection[Tuple[str, str]],
|
||||||
):
|
) -> RegistrationBehaviour:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
|
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
|
||||||
async def check_registration_for_spam(
|
async def check_registration_for_spam(
|
||||||
self,
|
self,
|
||||||
email_threepid,
|
email_threepid: Optional[dict],
|
||||||
username,
|
username: Optional[str],
|
||||||
request_info,
|
request_info: Collection[Tuple[str, str]],
|
||||||
):
|
) -> RegistrationBehaviour:
|
||||||
return RegistrationBehaviour.ALLOW
|
return RegistrationBehaviour.ALLOW
|
||||||
|
|
||||||
|
|
||||||
class LegacyDenyAll(TestLegacyRegistrationSpamChecker):
|
class LegacyDenyAll(TestLegacyRegistrationSpamChecker):
|
||||||
async def check_registration_for_spam(
|
async def check_registration_for_spam(
|
||||||
self,
|
self,
|
||||||
email_threepid,
|
email_threepid: Optional[dict],
|
||||||
username,
|
username: Optional[str],
|
||||||
request_info,
|
request_info: Collection[Tuple[str, str]],
|
||||||
):
|
) -> RegistrationBehaviour:
|
||||||
return RegistrationBehaviour.DENY
|
return RegistrationBehaviour.DENY
|
||||||
|
|
||||||
|
|
||||||
class RegistrationTestCase(unittest.HomeserverTestCase):
|
class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
"""Tests the RegistrationHandler."""
|
"""Tests the RegistrationHandler."""
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs_config = self.default_config()
|
hs_config = self.default_config()
|
||||||
|
|
||||||
# some of the tests rely on us having a user consent version
|
# some of the tests rely on us having a user consent version
|
||||||
|
@ -145,7 +162,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = self.hs.get_registration_handler()
|
self.handler = self.hs.get_registration_handler()
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
self.lots_of_users = 100
|
self.lots_of_users = 100
|
||||||
|
@ -153,7 +170,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.requester = create_requester("@requester:test")
|
self.requester = create_requester("@requester:test")
|
||||||
|
|
||||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
def test_user_is_created_and_logged_in_if_doesnt_exist(self) -> None:
|
||||||
frank = UserID.from_string("@frank:test")
|
frank = UserID.from_string("@frank:test")
|
||||||
user_id = frank.to_string()
|
user_id = frank.to_string()
|
||||||
requester = create_requester(user_id)
|
requester = create_requester(user_id)
|
||||||
|
@ -164,7 +181,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertIsInstance(result_token, str)
|
self.assertIsInstance(result_token, str)
|
||||||
self.assertGreater(len(result_token), 20)
|
self.assertGreater(len(result_token), 20)
|
||||||
|
|
||||||
def test_if_user_exists(self):
|
def test_if_user_exists(self) -> None:
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
frank = UserID.from_string("@frank:test")
|
frank = UserID.from_string("@frank:test")
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -180,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(result_token is not None)
|
self.assertTrue(result_token is not None)
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": False})
|
@override_config({"limit_usage_by_mau": False})
|
||||||
def test_mau_limits_when_disabled(self):
|
def test_mau_limits_when_disabled(self) -> None:
|
||||||
# Ensure does not throw exception
|
# Ensure does not throw exception
|
||||||
self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
|
self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": True})
|
@override_config({"limit_usage_by_mau": True})
|
||||||
def test_get_or_create_user_mau_not_blocked(self):
|
def test_get_or_create_user_mau_not_blocked(self) -> None:
|
||||||
self.store.count_monthly_users = Mock(
|
self.store.count_monthly_users = Mock(
|
||||||
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
|
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
|
||||||
)
|
)
|
||||||
|
@ -193,7 +210,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
|
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": True})
|
@override_config({"limit_usage_by_mau": True})
|
||||||
def test_get_or_create_user_mau_blocked(self):
|
def test_get_or_create_user_mau_blocked(self) -> None:
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=make_awaitable(self.lots_of_users)
|
return_value=make_awaitable(self.lots_of_users)
|
||||||
)
|
)
|
||||||
|
@ -211,7 +228,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": True})
|
@override_config({"limit_usage_by_mau": True})
|
||||||
def test_register_mau_blocked(self):
|
def test_register_mau_blocked(self) -> None:
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=make_awaitable(self.lots_of_users)
|
return_value=make_awaitable(self.lots_of_users)
|
||||||
)
|
)
|
||||||
|
@ -229,7 +246,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False}
|
{"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False}
|
||||||
)
|
)
|
||||||
def test_auto_join_rooms_for_guests(self):
|
def test_auto_join_rooms_for_guests(self) -> None:
|
||||||
user_id = self.get_success(
|
user_id = self.get_success(
|
||||||
self.handler.register_user(localpart="jeff", make_guest=True),
|
self.handler.register_user(localpart="jeff", make_guest=True),
|
||||||
)
|
)
|
||||||
|
@ -237,7 +254,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(rooms), 0)
|
self.assertEqual(len(rooms), 0)
|
||||||
|
|
||||||
@override_config({"auto_join_rooms": ["#room:test"]})
|
@override_config({"auto_join_rooms": ["#room:test"]})
|
||||||
def test_auto_create_auto_join_rooms(self):
|
def test_auto_create_auto_join_rooms(self) -> None:
|
||||||
room_alias_str = "#room:test"
|
room_alias_str = "#room:test"
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
|
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||||
|
@ -249,7 +266,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(rooms), 1)
|
self.assertEqual(len(rooms), 1)
|
||||||
|
|
||||||
@override_config({"auto_join_rooms": []})
|
@override_config({"auto_join_rooms": []})
|
||||||
def test_auto_create_auto_join_rooms_with_no_rooms(self):
|
def test_auto_create_auto_join_rooms_with_no_rooms(self) -> None:
|
||||||
frank = UserID.from_string("@frank:test")
|
frank = UserID.from_string("@frank:test")
|
||||||
user_id = self.get_success(self.handler.register_user(frank.localpart))
|
user_id = self.get_success(self.handler.register_user(frank.localpart))
|
||||||
self.assertEqual(user_id, frank.to_string())
|
self.assertEqual(user_id, frank.to_string())
|
||||||
|
@ -257,7 +274,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(rooms), 0)
|
self.assertEqual(len(rooms), 0)
|
||||||
|
|
||||||
@override_config({"auto_join_rooms": ["#room:another"]})
|
@override_config({"auto_join_rooms": ["#room:another"]})
|
||||||
def test_auto_create_auto_join_where_room_is_another_domain(self):
|
def test_auto_create_auto_join_where_room_is_another_domain(self) -> None:
|
||||||
frank = UserID.from_string("@frank:test")
|
frank = UserID.from_string("@frank:test")
|
||||||
user_id = self.get_success(self.handler.register_user(frank.localpart))
|
user_id = self.get_success(self.handler.register_user(frank.localpart))
|
||||||
self.assertEqual(user_id, frank.to_string())
|
self.assertEqual(user_id, frank.to_string())
|
||||||
|
@ -267,13 +284,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False}
|
{"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False}
|
||||||
)
|
)
|
||||||
def test_auto_create_auto_join_where_auto_create_is_false(self):
|
def test_auto_create_auto_join_where_auto_create_is_false(self) -> None:
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
|
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||||
self.assertEqual(len(rooms), 0)
|
self.assertEqual(len(rooms), 0)
|
||||||
|
|
||||||
@override_config({"auto_join_rooms": ["#room:test"]})
|
@override_config({"auto_join_rooms": ["#room:test"]})
|
||||||
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self):
|
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
|
||||||
room_alias_str = "#room:test"
|
room_alias_str = "#room:test"
|
||||||
self.store.is_real_user = Mock(return_value=make_awaitable(False))
|
self.store.is_real_user = Mock(return_value=make_awaitable(False))
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="support"))
|
user_id = self.get_success(self.handler.register_user(localpart="support"))
|
||||||
|
@ -284,7 +301,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
|
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
|
||||||
|
|
||||||
@override_config({"auto_join_rooms": ["#room:test"]})
|
@override_config({"auto_join_rooms": ["#room:test"]})
|
||||||
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
|
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
|
||||||
room_alias_str = "#room:test"
|
room_alias_str = "#room:test"
|
||||||
|
|
||||||
self.store.count_real_users = Mock(return_value=make_awaitable(1))
|
self.store.count_real_users = Mock(return_value=make_awaitable(1))
|
||||||
|
@ -299,7 +316,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(rooms), 1)
|
self.assertEqual(len(rooms), 1)
|
||||||
|
|
||||||
@override_config({"auto_join_rooms": ["#room:test"]})
|
@override_config({"auto_join_rooms": ["#room:test"]})
|
||||||
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self):
|
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
self.store.count_real_users = Mock(return_value=make_awaitable(2))
|
self.store.count_real_users = Mock(return_value=make_awaitable(2))
|
||||||
self.store.is_real_user = Mock(return_value=make_awaitable(True))
|
self.store.is_real_user = Mock(return_value=make_awaitable(True))
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||||
|
@ -312,7 +331,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
"autocreate_auto_join_rooms_federated": False,
|
"autocreate_auto_join_rooms_federated": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_auto_create_auto_join_rooms_federated(self):
|
def test_auto_create_auto_join_rooms_federated(self) -> None:
|
||||||
"""
|
"""
|
||||||
Auto-created rooms that are private require an invite to go to the user
|
Auto-created rooms that are private require an invite to go to the user
|
||||||
(instead of directly joining it).
|
(instead of directly joining it).
|
||||||
|
@ -339,7 +358,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
|
{"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
|
||||||
)
|
)
|
||||||
def test_auto_join_mxid_localpart(self):
|
def test_auto_join_mxid_localpart(self) -> None:
|
||||||
"""
|
"""
|
||||||
Ensure the user still needs up in the room created by a different user.
|
Ensure the user still needs up in the room created by a different user.
|
||||||
"""
|
"""
|
||||||
|
@ -376,7 +395,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
"auto_join_mxid_localpart": "support",
|
"auto_join_mxid_localpart": "support",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_auto_create_auto_join_room_preset(self):
|
def test_auto_create_auto_join_room_preset(self) -> None:
|
||||||
"""
|
"""
|
||||||
Auto-created rooms that are private require an invite to go to the user
|
Auto-created rooms that are private require an invite to go to the user
|
||||||
(instead of directly joining it).
|
(instead of directly joining it).
|
||||||
|
@ -416,7 +435,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
"auto_join_mxid_localpart": "support",
|
"auto_join_mxid_localpart": "support",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_auto_create_auto_join_room_preset_guest(self):
|
def test_auto_create_auto_join_room_preset_guest(self) -> None:
|
||||||
"""
|
"""
|
||||||
Auto-created rooms that are private require an invite to go to the user
|
Auto-created rooms that are private require an invite to go to the user
|
||||||
(instead of directly joining it).
|
(instead of directly joining it).
|
||||||
|
@ -454,7 +473,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
"auto_join_mxid_localpart": "support",
|
"auto_join_mxid_localpart": "support",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_auto_create_auto_join_room_preset_invalid_permissions(self):
|
def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None:
|
||||||
"""
|
"""
|
||||||
Auto-created rooms that are private require an invite, check that
|
Auto-created rooms that are private require an invite, check that
|
||||||
registration doesn't completely break if the inviter doesn't have proper
|
registration doesn't completely break if the inviter doesn't have proper
|
||||||
|
@ -525,7 +544,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
"auto_join_rooms": ["#room:test"],
|
"auto_join_rooms": ["#room:test"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def test_auto_create_auto_join_where_no_consent(self):
|
def test_auto_create_auto_join_where_no_consent(self) -> None:
|
||||||
"""Test to ensure that the first user is not auto-joined to a room if
|
"""Test to ensure that the first user is not auto-joined to a room if
|
||||||
they have not given general consent.
|
they have not given general consent.
|
||||||
"""
|
"""
|
||||||
|
@ -550,19 +569,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||||
self.assertEqual(len(rooms), 1)
|
self.assertEqual(len(rooms), 1)
|
||||||
|
|
||||||
def test_register_support_user(self):
|
def test_register_support_user(self) -> None:
|
||||||
user_id = self.get_success(
|
user_id = self.get_success(
|
||||||
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
|
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
|
||||||
)
|
)
|
||||||
d = self.store.is_support_user(user_id)
|
d = self.store.is_support_user(user_id)
|
||||||
self.assertTrue(self.get_success(d))
|
self.assertTrue(self.get_success(d))
|
||||||
|
|
||||||
def test_register_not_support_user(self):
|
def test_register_not_support_user(self) -> None:
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="user"))
|
user_id = self.get_success(self.handler.register_user(localpart="user"))
|
||||||
d = self.store.is_support_user(user_id)
|
d = self.store.is_support_user(user_id)
|
||||||
self.assertFalse(self.get_success(d))
|
self.assertFalse(self.get_success(d))
|
||||||
|
|
||||||
def test_invalid_user_id_length(self):
|
def test_invalid_user_id_length(self) -> None:
|
||||||
invalid_user_id = "x" * 256
|
invalid_user_id = "x" * 256
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
||||||
|
@ -577,7 +596,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_spam_checker_deny(self):
|
def test_spam_checker_deny(self) -> None:
|
||||||
"""A spam checker can deny registration, which results in an error."""
|
"""A spam checker can deny registration, which results in an error."""
|
||||||
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
|
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
|
||||||
|
|
||||||
|
@ -590,7 +609,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_spam_checker_legacy_allow(self):
|
def test_spam_checker_legacy_allow(self) -> None:
|
||||||
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
|
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
|
||||||
check_registration_for_spam callback is correctly called.
|
check_registration_for_spam callback is correctly called.
|
||||||
|
|
||||||
|
@ -610,7 +629,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_spam_checker_legacy_deny(self):
|
def test_spam_checker_legacy_deny(self) -> None:
|
||||||
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
|
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
|
||||||
check_registration_for_spam callback is correctly called.
|
check_registration_for_spam callback is correctly called.
|
||||||
|
|
||||||
|
@ -630,7 +649,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_spam_checker_shadow_ban(self):
|
def test_spam_checker_shadow_ban(self) -> None:
|
||||||
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
|
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="user"))
|
user_id = self.get_success(self.handler.register_user(localpart="user"))
|
||||||
|
|
||||||
|
@ -660,7 +679,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_spam_checker_receives_sso_type(self):
|
def test_spam_checker_receives_sso_type(self) -> None:
|
||||||
"""Test rejecting registration based on SSO type"""
|
"""Test rejecting registration based on SSO type"""
|
||||||
f = self.get_failure(
|
f = self.get_failure(
|
||||||
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
|
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
|
||||||
|
@ -678,8 +697,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_or_create_user(
|
async def get_or_create_user(
|
||||||
self, requester, localpart, displayname, password_hash=None
|
self,
|
||||||
):
|
requester: Requester,
|
||||||
|
localpart: str,
|
||||||
|
displayname: Optional[str],
|
||||||
|
password_hash: Optional[str] = None,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
"""Creates a new user if the user does not exist,
|
"""Creates a new user if the user does not exist,
|
||||||
else revokes all previous access tokens and generates a new one.
|
else revokes all previous access tokens and generates a new one.
|
||||||
|
|
||||||
|
@ -734,13 +757,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
|
class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
|
||||||
"""Tests auto-join on remote rooms."""
|
"""Tests auto-join on remote rooms."""
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.room_id = "!roomid:remotetest"
|
self.room_id = "!roomid:remotetest"
|
||||||
|
|
||||||
async def update_membership(*args, **kwargs):
|
async def update_membership(*args: Any, **kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def lookup_room_alias(*args, **kwargs):
|
async def lookup_room_alias(
|
||||||
|
*args: Any, **kwargs: Any
|
||||||
|
) -> Tuple[RoomID, List[str]]:
|
||||||
return RoomID.from_string(self.room_id), ["remotetest"]
|
return RoomID.from_string(self.room_id), ["remotetest"]
|
||||||
|
|
||||||
self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"])
|
self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"])
|
||||||
|
@ -750,12 +775,12 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
|
||||||
hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler)
|
hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = self.hs.get_registration_handler()
|
self.handler = self.hs.get_registration_handler()
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
@override_config({"auto_join_rooms": ["#room:remotetest"]})
|
@override_config({"auto_join_rooms": ["#room:remotetest"]})
|
||||||
def test_auto_create_auto_join_remote_room(self):
|
def test_auto_create_auto_join_remote_room(self) -> None:
|
||||||
"""Tests that we don't attempt to create remote rooms, and that we don't attempt
|
"""Tests that we don't attempt to create remote rooms, and that we don't attempt
|
||||||
to invite ourselves to rooms we're not in."""
|
to invite ourselves to rooms we're not in."""
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
@override_config({"encryption_enabled_by_default_for_room_type": "all"})
|
@override_config({"encryption_enabled_by_default_for_room_type": "all"})
|
||||||
def test_encrypted_by_default_config_option_all(self):
|
def test_encrypted_by_default_config_option_all(self) -> None:
|
||||||
"""Tests that invite-only and non-invite-only rooms have encryption enabled by
|
"""Tests that invite-only and non-invite-only rooms have encryption enabled by
|
||||||
default when the config option encryption_enabled_by_default_for_room_type is "all".
|
default when the config option encryption_enabled_by_default_for_room_type is "all".
|
||||||
"""
|
"""
|
||||||
|
@ -45,7 +45,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
|
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
|
||||||
|
|
||||||
@override_config({"encryption_enabled_by_default_for_room_type": "invite"})
|
@override_config({"encryption_enabled_by_default_for_room_type": "invite"})
|
||||||
def test_encrypted_by_default_config_option_invite(self):
|
def test_encrypted_by_default_config_option_invite(self) -> None:
|
||||||
"""Tests that only new, invite-only rooms have encryption enabled by default when
|
"""Tests that only new, invite-only rooms have encryption enabled by default when
|
||||||
the config option encryption_enabled_by_default_for_room_type is "invite".
|
the config option encryption_enabled_by_default_for_room_type is "invite".
|
||||||
"""
|
"""
|
||||||
|
@ -76,7 +76,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"encryption_enabled_by_default_for_room_type": "off"})
|
@override_config({"encryption_enabled_by_default_for_room_type": "off"})
|
||||||
def test_encrypted_by_default_config_option_off(self):
|
def test_encrypted_by_default_config_option_off(self) -> None:
|
||||||
"""Tests that neither new invite-only nor non-invite-only rooms have encryption
|
"""Tests that neither new invite-only nor non-invite-only rooms have encryption
|
||||||
enabled by default when the config option
|
enabled by default when the config option
|
||||||
encryption_enabled_by_default_for_room_type is "off".
|
encryption_enabled_by_default_for_room_type is "off".
|
||||||
|
|
|
@ -11,10 +11,11 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Iterable, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from twisted.internet.defer import ensureDeferred
|
from twisted.internet.defer import ensureDeferred
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventContentFields,
|
EventContentFields,
|
||||||
|
@ -34,11 +35,14 @@ from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, UserID, create_requester
|
from synapse.types import JsonDict, UserID, create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0):
|
def _create_event(
|
||||||
|
room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0
|
||||||
|
) -> mock.Mock:
|
||||||
result = mock.Mock(name=room_id)
|
result = mock.Mock(name=room_id)
|
||||||
result.room_id = room_id
|
result.room_id = room_id
|
||||||
result.content = {}
|
result.content = {}
|
||||||
|
@ -48,40 +52,40 @@ def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: i
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _order(*events):
|
def _order(*events: mock.Mock) -> List[mock.Mock]:
|
||||||
return sorted(events, key=_child_events_comparison_key)
|
return sorted(events, key=_child_events_comparison_key)
|
||||||
|
|
||||||
|
|
||||||
class TestSpaceSummarySort(unittest.TestCase):
|
class TestSpaceSummarySort(unittest.TestCase):
|
||||||
def test_no_order_last(self):
|
def test_no_order_last(self) -> None:
|
||||||
"""An event with no ordering is placed behind those with an ordering."""
|
"""An event with no ordering is placed behind those with an ordering."""
|
||||||
ev1 = _create_event("!abc:test")
|
ev1 = _create_event("!abc:test")
|
||||||
ev2 = _create_event("!xyz:test", "xyz")
|
ev2 = _create_event("!xyz:test", "xyz")
|
||||||
|
|
||||||
self.assertEqual([ev2, ev1], _order(ev1, ev2))
|
self.assertEqual([ev2, ev1], _order(ev1, ev2))
|
||||||
|
|
||||||
def test_order(self):
|
def test_order(self) -> None:
|
||||||
"""The ordering should be used."""
|
"""The ordering should be used."""
|
||||||
ev1 = _create_event("!abc:test", "xyz")
|
ev1 = _create_event("!abc:test", "xyz")
|
||||||
ev2 = _create_event("!xyz:test", "abc")
|
ev2 = _create_event("!xyz:test", "abc")
|
||||||
|
|
||||||
self.assertEqual([ev2, ev1], _order(ev1, ev2))
|
self.assertEqual([ev2, ev1], _order(ev1, ev2))
|
||||||
|
|
||||||
def test_order_origin_server_ts(self):
|
def test_order_origin_server_ts(self) -> None:
|
||||||
"""Origin server is a tie-breaker for ordering."""
|
"""Origin server is a tie-breaker for ordering."""
|
||||||
ev1 = _create_event("!abc:test", origin_server_ts=10)
|
ev1 = _create_event("!abc:test", origin_server_ts=10)
|
||||||
ev2 = _create_event("!xyz:test", origin_server_ts=30)
|
ev2 = _create_event("!xyz:test", origin_server_ts=30)
|
||||||
|
|
||||||
self.assertEqual([ev1, ev2], _order(ev1, ev2))
|
self.assertEqual([ev1, ev2], _order(ev1, ev2))
|
||||||
|
|
||||||
def test_order_room_id(self):
|
def test_order_room_id(self) -> None:
|
||||||
"""Room ID is a final tie-breaker for ordering."""
|
"""Room ID is a final tie-breaker for ordering."""
|
||||||
ev1 = _create_event("!abc:test")
|
ev1 = _create_event("!abc:test")
|
||||||
ev2 = _create_event("!xyz:test")
|
ev2 = _create_event("!xyz:test")
|
||||||
|
|
||||||
self.assertEqual([ev1, ev2], _order(ev1, ev2))
|
self.assertEqual([ev1, ev2], _order(ev1, ev2))
|
||||||
|
|
||||||
def test_invalid_ordering_type(self):
|
def test_invalid_ordering_type(self) -> None:
|
||||||
"""Invalid orderings are considered the same as missing."""
|
"""Invalid orderings are considered the same as missing."""
|
||||||
ev1 = _create_event("!abc:test", 1)
|
ev1 = _create_event("!abc:test", 1)
|
||||||
ev2 = _create_event("!xyz:test", "xyz")
|
ev2 = _create_event("!xyz:test", "xyz")
|
||||||
|
@ -97,7 +101,7 @@ class TestSpaceSummarySort(unittest.TestCase):
|
||||||
ev1 = _create_event("!abc:test", True)
|
ev1 = _create_event("!abc:test", True)
|
||||||
self.assertEqual([ev2, ev1], _order(ev1, ev2))
|
self.assertEqual([ev2, ev1], _order(ev1, ev2))
|
||||||
|
|
||||||
def test_invalid_ordering_value(self):
|
def test_invalid_ordering_value(self) -> None:
|
||||||
"""Invalid orderings are considered the same as missing."""
|
"""Invalid orderings are considered the same as missing."""
|
||||||
ev1 = _create_event("!abc:test", "foo\n")
|
ev1 = _create_event("!abc:test", "foo\n")
|
||||||
ev2 = _create_event("!xyz:test", "xyz")
|
ev2 = _create_event("!xyz:test", "xyz")
|
||||||
|
@ -115,7 +119,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.handler = self.hs.get_room_summary_handler()
|
self.handler = self.hs.get_room_summary_handler()
|
||||||
|
|
||||||
|
@ -223,7 +227,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
|
fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_simple_space(self):
|
def test_simple_space(self) -> None:
|
||||||
"""Test a simple space with a single room."""
|
"""Test a simple space with a single room."""
|
||||||
# The result should have the space and the room in it, along with a link
|
# The result should have the space and the room in it, along with a link
|
||||||
# from space -> room.
|
# from space -> room.
|
||||||
|
@ -234,7 +238,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
|
|
||||||
def test_large_space(self):
|
def test_large_space(self) -> None:
|
||||||
"""Test a space with a large number of rooms."""
|
"""Test a space with a large number of rooms."""
|
||||||
rooms = [self.room]
|
rooms = [self.room]
|
||||||
# Make at least 51 rooms that are part of the space.
|
# Make at least 51 rooms that are part of the space.
|
||||||
|
@ -260,7 +264,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
result["rooms"] += result2["rooms"]
|
result["rooms"] += result2["rooms"]
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
|
|
||||||
def test_visibility(self):
|
def test_visibility(self) -> None:
|
||||||
"""A user not in a space cannot inspect it."""
|
"""A user not in a space cannot inspect it."""
|
||||||
user2 = self.register_user("user2", "pass")
|
user2 = self.register_user("user2", "pass")
|
||||||
token2 = self.login("user2", "pass")
|
token2 = self.login("user2", "pass")
|
||||||
|
@ -380,7 +384,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
self._assert_hierarchy(result2, [(self.space, [self.room])])
|
self._assert_hierarchy(result2, [(self.space, [self.room])])
|
||||||
|
|
||||||
def _create_room_with_join_rule(
|
def _create_room_with_join_rule(
|
||||||
self, join_rule: str, room_version: Optional[str] = None, **extra_content
|
self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a room with the given join rule and add it to the space."""
|
"""Create a room with the given join rule and add it to the space."""
|
||||||
room_id = self.helper.create_room_as(
|
room_id = self.helper.create_room_as(
|
||||||
|
@ -403,7 +407,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
self._add_child(self.space, room_id, self.token)
|
self._add_child(self.space, room_id, self.token)
|
||||||
return room_id
|
return room_id
|
||||||
|
|
||||||
def test_filtering(self):
|
def test_filtering(self) -> None:
|
||||||
"""
|
"""
|
||||||
Rooms should be properly filtered to only include rooms the user has access to.
|
Rooms should be properly filtered to only include rooms the user has access to.
|
||||||
"""
|
"""
|
||||||
|
@ -476,7 +480,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
|
|
||||||
def test_complex_space(self):
|
def test_complex_space(self) -> None:
|
||||||
"""
|
"""
|
||||||
Create a "complex" space to see how it handles things like loops and subspaces.
|
Create a "complex" space to see how it handles things like loops and subspaces.
|
||||||
"""
|
"""
|
||||||
|
@ -516,7 +520,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
|
|
||||||
def test_pagination(self):
|
def test_pagination(self) -> None:
|
||||||
"""Test simple pagination works."""
|
"""Test simple pagination works."""
|
||||||
room_ids = []
|
room_ids = []
|
||||||
for i in range(1, 10):
|
for i in range(1, 10):
|
||||||
|
@ -553,7 +557,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
self.assertNotIn("next_batch", result)
|
self.assertNotIn("next_batch", result)
|
||||||
|
|
||||||
def test_invalid_pagination_token(self):
|
def test_invalid_pagination_token(self) -> None:
|
||||||
"""An invalid pagination token, or changing other parameters, shoudl be rejected."""
|
"""An invalid pagination token, or changing other parameters, shoudl be rejected."""
|
||||||
room_ids = []
|
room_ids = []
|
||||||
for i in range(1, 10):
|
for i in range(1, 10):
|
||||||
|
@ -604,7 +608,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_max_depth(self):
|
def test_max_depth(self) -> None:
|
||||||
"""Create a deep tree to test the max depth against."""
|
"""Create a deep tree to test the max depth against."""
|
||||||
spaces = [self.space]
|
spaces = [self.space]
|
||||||
rooms = [self.room]
|
rooms = [self.room]
|
||||||
|
@ -659,7 +663,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
|
|
||||||
def test_unknown_room_version(self):
|
def test_unknown_room_version(self) -> None:
|
||||||
"""
|
"""
|
||||||
If a room with an unknown room version is encountered it should not cause
|
If a room with an unknown room version is encountered it should not cause
|
||||||
the entire summary to skip.
|
the entire summary to skip.
|
||||||
|
@ -685,7 +689,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
|
|
||||||
def test_fed_complex(self):
|
def test_fed_complex(self) -> None:
|
||||||
"""
|
"""
|
||||||
Return data over federation and ensure that it is handled properly.
|
Return data over federation and ensure that it is handled properly.
|
||||||
"""
|
"""
|
||||||
|
@ -722,7 +726,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
"world_readable": True,
|
"world_readable": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def summarize_remote_room_hierarchy(_self, room, suggested_only):
|
async def summarize_remote_room_hierarchy(
|
||||||
|
_self: Any, room: Any, suggested_only: bool
|
||||||
|
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
|
||||||
return requested_room_entry, {subroom: child_room}, set()
|
return requested_room_entry, {subroom: child_room}, set()
|
||||||
|
|
||||||
# Add a room to the space which is on another server.
|
# Add a room to the space which is on another server.
|
||||||
|
@ -744,7 +750,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
|
|
||||||
def test_fed_filtering(self):
|
def test_fed_filtering(self) -> None:
|
||||||
"""
|
"""
|
||||||
Rooms returned over federation should be properly filtered to only include
|
Rooms returned over federation should be properly filtered to only include
|
||||||
rooms the user has access to.
|
rooms the user has access to.
|
||||||
|
@ -853,7 +859,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def summarize_remote_room_hierarchy(_self, room, suggested_only):
|
async def summarize_remote_room_hierarchy(
|
||||||
|
_self: Any, room: Any, suggested_only: bool
|
||||||
|
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
|
||||||
return subspace_room_entry, dict(children_rooms), set()
|
return subspace_room_entry, dict(children_rooms), set()
|
||||||
|
|
||||||
# Add a room to the space which is on another server.
|
# Add a room to the space which is on another server.
|
||||||
|
@ -892,7 +900,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
|
|
||||||
def test_fed_invited(self):
|
def test_fed_invited(self) -> None:
|
||||||
"""
|
"""
|
||||||
A room which the user was invited to should be included in the response.
|
A room which the user was invited to should be included in the response.
|
||||||
|
|
||||||
|
@ -915,7 +923,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def summarize_remote_room_hierarchy(_self, room, suggested_only):
|
async def summarize_remote_room_hierarchy(
|
||||||
|
_self: Any, room: Any, suggested_only: bool
|
||||||
|
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
|
||||||
return fed_room_entry, {}, set()
|
return fed_room_entry, {}, set()
|
||||||
|
|
||||||
# Add a room to the space which is on another server.
|
# Add a room to the space which is on another server.
|
||||||
|
@ -936,7 +946,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self._assert_hierarchy(result, expected)
|
self._assert_hierarchy(result, expected)
|
||||||
|
|
||||||
def test_fed_caching(self):
|
def test_fed_caching(self) -> None:
|
||||||
"""
|
"""
|
||||||
Federation `/hierarchy` responses should be cached.
|
Federation `/hierarchy` responses should be cached.
|
||||||
"""
|
"""
|
||||||
|
@ -1023,7 +1033,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.handler = self.hs.get_room_summary_handler()
|
self.handler = self.hs.get_room_summary_handler()
|
||||||
|
|
||||||
|
@ -1040,12 +1050,12 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
tok=self.token,
|
tok=self.token,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_own_room(self):
|
def test_own_room(self) -> None:
|
||||||
"""Test a simple room created by the requester."""
|
"""Test a simple room created by the requester."""
|
||||||
result = self.get_success(self.handler.get_room_summary(self.user, self.room))
|
result = self.get_success(self.handler.get_room_summary(self.user, self.room))
|
||||||
self.assertEqual(result.get("room_id"), self.room)
|
self.assertEqual(result.get("room_id"), self.room)
|
||||||
|
|
||||||
def test_visibility(self):
|
def test_visibility(self) -> None:
|
||||||
"""A user not in a private room cannot get its summary."""
|
"""A user not in a private room cannot get its summary."""
|
||||||
user2 = self.register_user("user2", "pass")
|
user2 = self.register_user("user2", "pass")
|
||||||
token2 = self.login("user2", "pass")
|
token2 = self.login("user2", "pass")
|
||||||
|
@ -1093,7 +1103,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
result = self.get_success(self.handler.get_room_summary(user2, self.room))
|
result = self.get_success(self.handler.get_room_summary(user2, self.room))
|
||||||
self.assertEqual(result.get("room_id"), self.room)
|
self.assertEqual(result.get("room_id"), self.room)
|
||||||
|
|
||||||
def test_fed(self):
|
def test_fed(self) -> None:
|
||||||
"""
|
"""
|
||||||
Return data over federation and ensure that it is handled properly.
|
Return data over federation and ensure that it is handled properly.
|
||||||
"""
|
"""
|
||||||
|
@ -1105,7 +1115,9 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
|
||||||
{"room_id": fed_room, "world_readable": True},
|
{"room_id": fed_room, "world_readable": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def summarize_remote_room_hierarchy(_self, room, suggested_only):
|
async def summarize_remote_room_hierarchy(
|
||||||
|
_self: Any, room: Any, suggested_only: bool
|
||||||
|
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
|
||||||
return requested_room_entry, {}, set()
|
return requested_room_entry, {}, set()
|
||||||
|
|
||||||
with mock.patch(
|
with mock.patch(
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Set, Tuple
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -20,7 +20,9 @@ import attr
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import RedirectException
|
from synapse.api.errors import RedirectException
|
||||||
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import simple_async_mock
|
from tests.test_utils import simple_async_mock
|
||||||
|
@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config
|
||||||
# Check if we have the dependencies to run the tests.
|
# Check if we have the dependencies to run the tests.
|
||||||
try:
|
try:
|
||||||
import saml2.config
|
import saml2.config
|
||||||
|
import saml2.response
|
||||||
from saml2.sigver import SigverError
|
from saml2.sigver import SigverError
|
||||||
|
|
||||||
has_saml2 = True
|
has_saml2 = True
|
||||||
|
@ -56,31 +59,39 @@ class FakeAuthnResponse:
|
||||||
|
|
||||||
|
|
||||||
class TestMappingProvider:
|
class TestMappingProvider:
|
||||||
def __init__(self, config, module):
|
def __init__(self, config: None, module: ModuleApi):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(config):
|
def parse_config(config: JsonDict) -> None:
|
||||||
return
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_saml_attributes(config):
|
def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]:
|
||||||
return {"uid"}, {"displayName"}
|
return {"uid"}, {"displayName"}
|
||||||
|
|
||||||
def get_remote_user_id(self, saml_response, client_redirect_url):
|
def get_remote_user_id(
|
||||||
|
self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str
|
||||||
|
) -> str:
|
||||||
return saml_response.ava["uid"]
|
return saml_response.ava["uid"]
|
||||||
|
|
||||||
def saml_response_to_user_attributes(
|
def saml_response_to_user_attributes(
|
||||||
self, saml_response, failures, client_redirect_url
|
self,
|
||||||
):
|
saml_response: "saml2.response.AuthnResponse",
|
||||||
|
failures: int,
|
||||||
|
client_redirect_url: str,
|
||||||
|
) -> dict:
|
||||||
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
|
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
|
||||||
return {"mxid_localpart": localpart, "displayname": None}
|
return {"mxid_localpart": localpart, "displayname": None}
|
||||||
|
|
||||||
|
|
||||||
class TestRedirectMappingProvider(TestMappingProvider):
|
class TestRedirectMappingProvider(TestMappingProvider):
|
||||||
def saml_response_to_user_attributes(
|
def saml_response_to_user_attributes(
|
||||||
self, saml_response, failures, client_redirect_url
|
self,
|
||||||
):
|
saml_response: "saml2.response.AuthnResponse",
|
||||||
|
failures: int,
|
||||||
|
client_redirect_url: str,
|
||||||
|
) -> dict:
|
||||||
raise RedirectException(b"https://custom-saml-redirect/")
|
raise RedirectException(b"https://custom-saml-redirect/")
|
||||||
|
|
||||||
|
|
||||||
|
@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _mock_request():
|
def _mock_request() -> Mock:
|
||||||
"""Returns a mock which will stand in as a SynapseRequest"""
|
"""Returns a mock which will stand in as a SynapseRequest"""
|
||||||
mock = Mock(
|
mock = Mock(
|
||||||
spec=[
|
spec=[
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from typing import List, Tuple
|
from typing import Callable, List, Tuple
|
||||||
|
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
@ -28,20 +28,27 @@ from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
@implementer(interfaces.IMessageDelivery)
|
@implementer(interfaces.IMessageDelivery)
|
||||||
class _DummyMessageDelivery:
|
class _DummyMessageDelivery:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
# (recipient, message) tuples
|
# (recipient, message) tuples
|
||||||
self.messages: List[Tuple[smtp.Address, bytes]] = []
|
self.messages: List[Tuple[smtp.Address, bytes]] = []
|
||||||
|
|
||||||
def receivedHeader(self, helo, origin, recipients):
|
def receivedHeader(
|
||||||
|
self,
|
||||||
|
helo: Tuple[bytes, bytes],
|
||||||
|
origin: smtp.Address,
|
||||||
|
recipients: List[smtp.User],
|
||||||
|
) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def validateFrom(self, helo, origin):
|
def validateFrom(
|
||||||
|
self, helo: Tuple[bytes, bytes], origin: smtp.Address
|
||||||
|
) -> smtp.Address:
|
||||||
return origin
|
return origin
|
||||||
|
|
||||||
def record_message(self, recipient: smtp.Address, message: bytes):
|
def record_message(self, recipient: smtp.Address, message: bytes) -> None:
|
||||||
self.messages.append((recipient, message))
|
self.messages.append((recipient, message))
|
||||||
|
|
||||||
def validateTo(self, user: smtp.User):
|
def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
|
||||||
return lambda: _DummyMessage(self, user)
|
return lambda: _DummyMessage(self, user)
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,20 +63,20 @@ class _DummyMessage:
|
||||||
self._user = user
|
self._user = user
|
||||||
self._buffer: List[bytes] = []
|
self._buffer: List[bytes] = []
|
||||||
|
|
||||||
def lineReceived(self, line):
|
def lineReceived(self, line: bytes) -> None:
|
||||||
self._buffer.append(line)
|
self._buffer.append(line)
|
||||||
|
|
||||||
def eomReceived(self):
|
def eomReceived(self) -> "defer.Deferred[bytes]":
|
||||||
message = b"\n".join(self._buffer) + b"\n"
|
message = b"\n".join(self._buffer) + b"\n"
|
||||||
self._delivery.record_message(self._user.dest, message)
|
self._delivery.record_message(self._user.dest, message)
|
||||||
return defer.succeed(b"saved")
|
return defer.succeed(b"saved")
|
||||||
|
|
||||||
def connectionLost(self):
|
def connectionLost(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SendEmailHandlerTestCase(HomeserverTestCase):
|
class SendEmailHandlerTestCase(HomeserverTestCase):
|
||||||
def test_send_email(self):
|
def test_send_email(self) -> None:
|
||||||
"""Happy-path test that we can send email to a non-TLS server."""
|
"""Happy-path test that we can send email to a non-TLS server."""
|
||||||
h = self.hs.get_send_email_handler()
|
h = self.hs.get_send_email_handler()
|
||||||
d = ensureDeferred(
|
d = ensureDeferred(
|
||||||
|
@ -119,7 +126,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_send_email_force_tls(self):
|
def test_send_email_force_tls(self) -> None:
|
||||||
"""Happy-path test that we can send email to an Implicit TLS server."""
|
"""Happy-path test that we can send email to an Implicit TLS server."""
|
||||||
h = self.hs.get_send_email_handler()
|
h = self.hs.get_send_email_handler()
|
||||||
d = ensureDeferred(
|
d = ensureDeferred(
|
||||||
|
|
|
@ -12,9 +12,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main import stats
|
from synapse.storage.databases.main import stats
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
@ -32,11 +38,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.handler = self.hs.get_stats_handler()
|
self.handler = self.hs.get_stats_handler()
|
||||||
|
|
||||||
def _add_background_updates(self):
|
def _add_background_updates(self) -> None:
|
||||||
"""
|
"""
|
||||||
Add the background updates we need to run.
|
Add the background updates we need to run.
|
||||||
"""
|
"""
|
||||||
|
@ -63,12 +69,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_all_room_state(self):
|
async def get_all_room_state(self) -> List[Dict[str, Any]]:
|
||||||
return await self.store.db_pool.simple_select_list(
|
return await self.store.db_pool.simple_select_list(
|
||||||
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
|
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_current_stats(self, stats_type, stat_id):
|
def _get_current_stats(
|
||||||
|
self, stats_type: str, stat_id: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
table, id_col = stats.TYPE_TO_TABLE[stats_type]
|
table, id_col = stats.TYPE_TO_TABLE[stats_type]
|
||||||
|
|
||||||
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
|
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
|
||||||
|
@ -82,13 +90,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _perform_background_initial_update(self):
|
def _perform_background_initial_update(self) -> None:
|
||||||
# Do the initial population of the stats via the background update
|
# Do the initial population of the stats via the background update
|
||||||
self._add_background_updates()
|
self._add_background_updates()
|
||||||
|
|
||||||
self.wait_for_background_updates()
|
self.wait_for_background_updates()
|
||||||
|
|
||||||
def test_initial_room(self):
|
def test_initial_room(self) -> None:
|
||||||
"""
|
"""
|
||||||
The background updates will build the table from scratch.
|
The background updates will build the table from scratch.
|
||||||
"""
|
"""
|
||||||
|
@ -125,7 +133,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(r), 1)
|
self.assertEqual(len(r), 1)
|
||||||
self.assertEqual(r[0]["topic"], "foo")
|
self.assertEqual(r[0]["topic"], "foo")
|
||||||
|
|
||||||
def test_create_user(self):
|
def test_create_user(self) -> None:
|
||||||
"""
|
"""
|
||||||
When we create a user, it should have statistics already ready.
|
When we create a user, it should have statistics already ready.
|
||||||
"""
|
"""
|
||||||
|
@ -134,12 +142,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
u1stats = self._get_current_stats("user", u1)
|
u1stats = self._get_current_stats("user", u1)
|
||||||
|
|
||||||
self.assertIsNotNone(u1stats)
|
assert u1stats is not None
|
||||||
|
|
||||||
# not in any rooms by default
|
# not in any rooms by default
|
||||||
self.assertEqual(u1stats["joined_rooms"], 0)
|
self.assertEqual(u1stats["joined_rooms"], 0)
|
||||||
|
|
||||||
def test_create_room(self):
|
def test_create_room(self) -> None:
|
||||||
"""
|
"""
|
||||||
When we create a room, it should have statistics already ready.
|
When we create a room, it should have statistics already ready.
|
||||||
"""
|
"""
|
||||||
|
@ -153,8 +161,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
|
r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
|
||||||
r2stats = self._get_current_stats("room", r2)
|
r2stats = self._get_current_stats("room", r2)
|
||||||
|
|
||||||
self.assertIsNotNone(r1stats)
|
assert r1stats is not None
|
||||||
self.assertIsNotNone(r2stats)
|
assert r2stats is not None
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
|
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
|
||||||
|
@ -171,7 +179,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(r2stats["invited_members"], 0)
|
self.assertEqual(r2stats["invited_members"], 0)
|
||||||
self.assertEqual(r2stats["banned_members"], 0)
|
self.assertEqual(r2stats["banned_members"], 0)
|
||||||
|
|
||||||
def test_updating_profile_information_does_not_increase_joined_members_count(self):
|
def test_updating_profile_information_does_not_increase_joined_members_count(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Check that the joined_members count does not increase when a user changes their
|
Check that the joined_members count does not increase when a user changes their
|
||||||
profile information (which is done by sending another join membership event into
|
profile information (which is done by sending another join membership event into
|
||||||
|
@ -186,6 +196,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Get the current room stats
|
# Get the current room stats
|
||||||
r1stats_ante = self._get_current_stats("room", r1)
|
r1stats_ante = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_ante is not None
|
||||||
|
|
||||||
# Send a profile update into the room
|
# Send a profile update into the room
|
||||||
new_profile = {"displayname": "bob"}
|
new_profile = {"displayname": "bob"}
|
||||||
|
@ -195,6 +206,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Get the new room stats
|
# Get the new room stats
|
||||||
r1stats_post = self._get_current_stats("room", r1)
|
r1stats_post = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_post is not None
|
||||||
|
|
||||||
# Ensure that the user count did not changed
|
# Ensure that the user count did not changed
|
||||||
self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
|
self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
|
||||||
|
@ -202,7 +214,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
|
r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_send_state_event_nonoverwriting(self):
|
def test_send_state_event_nonoverwriting(self) -> None:
|
||||||
"""
|
"""
|
||||||
When we send a non-overwriting state event, it increments current_state_events
|
When we send a non-overwriting state event, it increments current_state_events
|
||||||
"""
|
"""
|
||||||
|
@ -218,19 +230,21 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
r1stats_ante = self._get_current_stats("room", r1)
|
r1stats_ante = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_ante is not None
|
||||||
|
|
||||||
self.helper.send_state(
|
self.helper.send_state(
|
||||||
r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy"
|
r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy"
|
||||||
)
|
)
|
||||||
|
|
||||||
r1stats_post = self._get_current_stats("room", r1)
|
r1stats_post = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_post is not None
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_join_first_time(self):
|
def test_join_first_time(self) -> None:
|
||||||
"""
|
"""
|
||||||
When a user joins a room for the first time, current_state_events and
|
When a user joins a room for the first time, current_state_events and
|
||||||
joined_members should increase by exactly 1.
|
joined_members should increase by exactly 1.
|
||||||
|
@ -246,10 +260,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
u2token = self.login("u2", "pass")
|
u2token = self.login("u2", "pass")
|
||||||
|
|
||||||
r1stats_ante = self._get_current_stats("room", r1)
|
r1stats_ante = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_ante is not None
|
||||||
|
|
||||||
self.helper.join(r1, u2, tok=u2token)
|
self.helper.join(r1, u2, tok=u2token)
|
||||||
|
|
||||||
r1stats_post = self._get_current_stats("room", r1)
|
r1stats_post = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_post is not None
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||||
|
@ -259,7 +275,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1
|
r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_join_after_leave(self):
|
def test_join_after_leave(self) -> None:
|
||||||
"""
|
"""
|
||||||
When a user joins a room after being previously left,
|
When a user joins a room after being previously left,
|
||||||
joined_members should increase by exactly 1.
|
joined_members should increase by exactly 1.
|
||||||
|
@ -280,10 +296,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
self.helper.leave(r1, u2, tok=u2token)
|
self.helper.leave(r1, u2, tok=u2token)
|
||||||
|
|
||||||
r1stats_ante = self._get_current_stats("room", r1)
|
r1stats_ante = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_ante is not None
|
||||||
|
|
||||||
self.helper.join(r1, u2, tok=u2token)
|
self.helper.join(r1, u2, tok=u2token)
|
||||||
|
|
||||||
r1stats_post = self._get_current_stats("room", r1)
|
r1stats_post = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_post is not None
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||||
|
@ -296,7 +314,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r1stats_post["left_members"] - r1stats_ante["left_members"], -1
|
r1stats_post["left_members"] - r1stats_ante["left_members"], -1
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_invited(self):
|
def test_invited(self) -> None:
|
||||||
"""
|
"""
|
||||||
When a user invites another user, current_state_events and
|
When a user invites another user, current_state_events and
|
||||||
invited_members should increase by exactly 1.
|
invited_members should increase by exactly 1.
|
||||||
|
@ -311,10 +329,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
u2 = self.register_user("u2", "pass")
|
u2 = self.register_user("u2", "pass")
|
||||||
|
|
||||||
r1stats_ante = self._get_current_stats("room", r1)
|
r1stats_ante = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_ante is not None
|
||||||
|
|
||||||
self.helper.invite(r1, u1, u2, tok=u1token)
|
self.helper.invite(r1, u1, u2, tok=u1token)
|
||||||
|
|
||||||
r1stats_post = self._get_current_stats("room", r1)
|
r1stats_post = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_post is not None
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||||
|
@ -324,7 +344,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1
|
r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_join_after_invite(self):
|
def test_join_after_invite(self) -> None:
|
||||||
"""
|
"""
|
||||||
When a user joins a room after being invited and
|
When a user joins a room after being invited and
|
||||||
joined_members should increase by exactly 1.
|
joined_members should increase by exactly 1.
|
||||||
|
@ -344,10 +364,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
self.helper.invite(r1, u1, u2, tok=u1token)
|
self.helper.invite(r1, u1, u2, tok=u1token)
|
||||||
|
|
||||||
r1stats_ante = self._get_current_stats("room", r1)
|
r1stats_ante = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_ante is not None
|
||||||
|
|
||||||
self.helper.join(r1, u2, tok=u2token)
|
self.helper.join(r1, u2, tok=u2token)
|
||||||
|
|
||||||
r1stats_post = self._get_current_stats("room", r1)
|
r1stats_post = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_post is not None
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||||
|
@ -360,7 +382,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1
|
r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_left(self):
|
def test_left(self) -> None:
|
||||||
"""
|
"""
|
||||||
When a user leaves a room after joining and
|
When a user leaves a room after joining and
|
||||||
left_members should increase by exactly 1.
|
left_members should increase by exactly 1.
|
||||||
|
@ -380,10 +402,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
self.helper.join(r1, u2, tok=u2token)
|
self.helper.join(r1, u2, tok=u2token)
|
||||||
|
|
||||||
r1stats_ante = self._get_current_stats("room", r1)
|
r1stats_ante = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_ante is not None
|
||||||
|
|
||||||
self.helper.leave(r1, u2, tok=u2token)
|
self.helper.leave(r1, u2, tok=u2token)
|
||||||
|
|
||||||
r1stats_post = self._get_current_stats("room", r1)
|
r1stats_post = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_post is not None
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||||
|
@ -396,7 +420,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
|
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_banned(self):
|
def test_banned(self) -> None:
|
||||||
"""
|
"""
|
||||||
When a user is banned from a room after joining and
|
When a user is banned from a room after joining and
|
||||||
left_members should increase by exactly 1.
|
left_members should increase by exactly 1.
|
||||||
|
@ -416,10 +440,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
self.helper.join(r1, u2, tok=u2token)
|
self.helper.join(r1, u2, tok=u2token)
|
||||||
|
|
||||||
r1stats_ante = self._get_current_stats("room", r1)
|
r1stats_ante = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_ante is not None
|
||||||
|
|
||||||
self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)
|
self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)
|
||||||
|
|
||||||
r1stats_post = self._get_current_stats("room", r1)
|
r1stats_post = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_post is not None
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||||
|
@ -432,7 +458,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
|
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_initial_background_update(self):
|
def test_initial_background_update(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that statistics can be generated by the initial background update
|
Test that statistics can be generated by the initial background update
|
||||||
handler.
|
handler.
|
||||||
|
@ -462,6 +488,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r1stats = self._get_current_stats("room", r1)
|
r1stats = self._get_current_stats("room", r1)
|
||||||
u1stats = self._get_current_stats("user", u1)
|
u1stats = self._get_current_stats("user", u1)
|
||||||
|
|
||||||
|
assert r1stats is not None
|
||||||
|
assert u1stats is not None
|
||||||
|
|
||||||
self.assertEqual(r1stats["joined_members"], 1)
|
self.assertEqual(r1stats["joined_members"], 1)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
|
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
|
||||||
|
@ -469,7 +498,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(u1stats["joined_rooms"], 1)
|
self.assertEqual(u1stats["joined_rooms"], 1)
|
||||||
|
|
||||||
def test_incomplete_stats(self):
|
def test_incomplete_stats(self) -> None:
|
||||||
"""
|
"""
|
||||||
This tests that we track incomplete statistics.
|
This tests that we track incomplete statistics.
|
||||||
|
|
||||||
|
@ -533,8 +562,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
self.wait_for_background_updates()
|
self.wait_for_background_updates()
|
||||||
|
|
||||||
r1stats_complete = self._get_current_stats("room", r1)
|
r1stats_complete = self._get_current_stats("room", r1)
|
||||||
|
assert r1stats_complete is not None
|
||||||
u1stats_complete = self._get_current_stats("user", u1)
|
u1stats_complete = self._get_current_stats("user", u1)
|
||||||
|
assert u1stats_complete is not None
|
||||||
u2stats_complete = self._get_current_stats("user", u2)
|
u2stats_complete = self._get_current_stats("user", u2)
|
||||||
|
assert u2stats_complete is not None
|
||||||
|
|
||||||
# now we make our assertions
|
# now we make our assertions
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, JoinRules
|
from synapse.api.constants import EventTypes, JoinRules
|
||||||
from synapse.api.errors import Codes, ResourceLimitError
|
from synapse.api.errors import Codes, ResourceLimitError
|
||||||
from synapse.api.filtering import Filtering
|
from synapse.api.filtering import Filtering
|
||||||
|
@ -23,6 +25,7 @@ from synapse.rest import admin
|
||||||
from synapse.rest.client import knock, login, room
|
from synapse.rest.client import knock, login, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
import tests.unittest
|
import tests.unittest
|
||||||
import tests.utils
|
import tests.utils
|
||||||
|
@ -39,7 +42,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.sync_handler = self.hs.get_sync_handler()
|
self.sync_handler = self.hs.get_sync_handler()
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
@ -47,7 +50,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
||||||
# modify its config instead of the hs'
|
# modify its config instead of the hs'
|
||||||
self.auth_blocking = self.hs.get_auth_blocking()
|
self.auth_blocking = self.hs.get_auth_blocking()
|
||||||
|
|
||||||
def test_wait_for_sync_for_user_auth_blocking(self):
|
def test_wait_for_sync_for_user_auth_blocking(self) -> None:
|
||||||
user_id1 = "@user1:test"
|
user_id1 = "@user1:test"
|
||||||
user_id2 = "@user2:test"
|
user_id2 = "@user2:test"
|
||||||
sync_config = generate_sync_config(user_id1)
|
sync_config = generate_sync_config(user_id1)
|
||||||
|
@ -82,7 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
|
|
||||||
def test_unknown_room_version(self):
|
def test_unknown_room_version(self) -> None:
|
||||||
"""
|
"""
|
||||||
A room with an unknown room version should not break sync (and should be excluded).
|
A room with an unknown room version should not break sync (and should be excluded).
|
||||||
"""
|
"""
|
||||||
|
@ -186,7 +189,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
||||||
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
|
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
|
||||||
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
|
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
|
||||||
|
|
||||||
def test_ban_wins_race_with_join(self):
|
def test_ban_wins_race_with_join(self) -> None:
|
||||||
"""Rooms shouldn't appear under "joined" if a join loses a race to a ban.
|
"""Rooms shouldn't appear under "joined" if a join loses a race to a ban.
|
||||||
|
|
||||||
A complicated edge case. Imagine the following scenario:
|
A complicated edge case. Imagine the following scenario:
|
||||||
|
|
Loading…
Reference in New Issue