From 2ffaf30803f93273a4d8a65c9e6c3110c8433488 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 2 Mar 2022 17:34:14 +0100 Subject: [PATCH] Add type hints to `tests/rest/client` (#12108) * Add type hints to `tests/rest/client` * newsfile * fix imports * add `test_account.py` * Remove one type hint in `test_report_event.py` * change `on_create_room` to `async` * update new functions in `test_third_party_rules.py` * Add `test_filter.py` * add `test_rooms.py` * change to `assertEquals` to `assertEqual` * lint --- changelog.d/12108.misc | 1 + mypy.ini | 6 - tests/rest/client/test_account.py | 286 +++++++++++--------- tests/rest/client/test_filter.py | 29 +- tests/rest/client/test_relations.py | 4 +- tests/rest/client/test_report_event.py | 25 +- tests/rest/client/test_rooms.py | 271 ++++++++++--------- tests/rest/client/test_third_party_rules.py | 108 +++++--- tests/rest/client/test_typing.py | 41 +-- 9 files changed, 421 insertions(+), 350 deletions(-) create mode 100644 changelog.d/12108.misc diff --git a/changelog.d/12108.misc b/changelog.d/12108.misc new file mode 100644 index 0000000000..0360dbd61e --- /dev/null +++ b/changelog.d/12108.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/mypy.ini b/mypy.ini index 6b1e995e64..23ca4eaa5a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -78,13 +78,7 @@ exclude = (?x) |tests/push/test_http.py |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py - |tests/rest/client/test_account.py - |tests/rest/client/test_filter.py - |tests/rest/client/test_report_event.py - |tests/rest/client/test_rooms.py - |tests/rest/client/test_third_party_rules.py |tests/rest/client/test_transactions.py - |tests/rest/client/test_typing.py |tests/rest/key/v2/test_remote_key_resource.py |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 6c4462e74a..def836054d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -15,11 +15,12 @@ import json import os import re from email.parser import Parser -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock import pkg_resources +from twisted.internet.interfaces import IReactorTCP from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,6 +31,7 @@ from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Email config. @@ -67,20 +69,27 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + *args: Any, + **kwargs: Any, + ) -> None: + self.email_attempts.append(msg_bytes) - self.email_attempts = [] + self.email_attempts: List[bytes] = [] hs.get_send_email_handler()._sendmail = sendmail return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) - def test_basic_password_reset(self): + def test_basic_password_reset(self) -> None: """Test basic password reset flow""" old_password = "monkey" new_password = "kangeroo" @@ -118,7 +127,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.attempt_wrong_password_login("kermit", old_password) @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_email(self): + def test_ratelimit_by_email(self) -> None: """Test that we ratelimit /requestToken for the same email.""" old_password = "monkey" new_password = "kangeroo" @@ -139,7 +148,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): ) ) - def reset(ip): + def reset(ip: str) -> None: client_secret = "foobar" session_id = self._request_token(email, client_secret, ip) @@ -166,7 +175,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertEqual(cm.exception.code, 429) - def test_basic_password_reset_canonicalise_email(self): + def test_basic_password_reset_canonicalise_email(self) -> None: """Test basic password reset flow Request password reset with different spelling """ @@ -206,7 +215,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) - def test_cant_reset_password_without_clicking_link(self): + def test_cant_reset_password_without_clicking_link(self) -> None: """Test that we do actually need to click the link in the email""" old_password = "monkey" new_password = "kangeroo" @@ -241,7 +250,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the new password self.attempt_wrong_password_login("kermit", new_password) - def test_no_valid_token(self): + def test_no_valid_token(self) -> None: """Test that we do actually need to request a token and can't just make a session up. """ @@ -277,7 +286,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.attempt_wrong_password_login("kermit", new_password) @unittest.override_config({"request_token_inhibit_3pid_errors": True}) - def test_password_reset_bad_email_inhibit_error(self): + def test_password_reset_bad_email_inhibit_error(self) -> None: """Test that triggering a password reset with an email address that isn't bound to an account doesn't leak the lack of binding for that address if configured that way. @@ -292,7 +301,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) - def _request_token(self, email, client_secret, ip="127.0.0.1"): + def _request_token( + self, + email: str, + client_secret: str, + ip: str = "127.0.0.1", + ) -> str: channel = self.make_request( "POST", b"account/password/email/requestToken", @@ -309,7 +323,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return channel.json_body["sid"] - def _validate_token(self, link): + def _validate_token(self, link: str) -> None: # Remove the host path = link.replace("https://example.com", "") @@ -339,7 +353,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code, channel.result) - def _get_link_from_email(self): + def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" raw_msg = self.email_attempts[-1].decode("UTF-8") @@ -354,14 +368,19 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): if not text: self.fail("Could not find text portion of email to parse") + assert text is not None match = re.search(r"https://example.com\S+", text) assert match, "Could not find link in email" return match.group(0) def _reset_password( - self, new_password, session_id, client_secret, expected_code=200 - ): + self, + new_password: str, + session_id: str, + client_secret: str, + expected_code: int = 200, + ) -> None: channel = self.make_request( "POST", b"account/password", @@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() return self.hs - def test_deactivate_account(self): + def test_deactivate_account(self) -> None: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -407,7 +426,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", "account/whoami", access_token=tok) self.assertEqual(channel.code, 401) - def test_pending_invites(self): + def test_pending_invites(self) -> None: """Tests that deactivating a user rejects every pending invite for them.""" store = self.hs.get_datastores().main @@ -448,7 +467,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertEqual(len(memberships), 1, memberships) self.assertEqual(memberships[0].room_id, room_id, memberships) - def deactivate(self, user_id, tok): + def deactivate(self, user_id: str, tok: str) -> None: request_data = json.dumps( { "auth": { @@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["allow_guest_access"] = True return config - def test_GET_whoami(self): + def test_GET_whoami(self) -> None: device_id = "wouldgohere" user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test", device_id=device_id) @@ -496,7 +515,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): }, ) - def test_GET_whoami_guests(self): + def test_GET_whoami_guests(self) -> None: channel = self.make_request( b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}" ) @@ -516,7 +535,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): }, ) - def test_GET_whoami_appservices(self): + def test_GET_whoami_appservices(self) -> None: user_id = "@as:test" as_token = "i_am_an_app_service" @@ -541,7 +560,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): ) self.assertFalse(hasattr(whoami, "device_id")) - def _whoami(self, tok): + def _whoami(self, tok: str) -> JsonDict: channel = self.make_request("GET", "account/whoami", {}, access_token=tok) self.assertEqual(channel.code, 200) return channel.json_body @@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Email config. @@ -576,16 +595,23 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + *args: Any, + **kwargs: Any, + ) -> None: + self.email_attempts.append(msg_bytes) - self.email_attempts = [] + self.email_attempts: List[bytes] = [] self.hs.get_send_email_handler()._sendmail = sendmail return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("kermit", "test") @@ -593,83 +619,73 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.email = "test@example.com" self.url_3pid = b"account/3pid" - def test_add_valid_email(self): - self.get_success(self._add_email(self.email, self.email)) + def test_add_valid_email(self) -> None: + self._add_email(self.email, self.email) - def test_add_valid_email_second_time(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - self.email, - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) + def test_add_valid_email_second_time(self) -> None: + self._add_email(self.email, self.email) + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", ) - def test_add_valid_email_second_time_canonicalise(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - "TEST@EXAMPLE.COM", - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) + def test_add_valid_email_second_time_canonicalise(self) -> None: + self._add_email(self.email, self.email) + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", ) - def test_add_email_no_at(self): - self.get_success( - self._request_token_invalid_email( - "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_no_at(self) -> None: + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_two_at(self): - self.get_success( - self._request_token_invalid_email( - "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_two_at(self) -> None: + self._request_token_invalid_email( + "foo@foo@test.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_bad_format(self): - self.get_success( - self._request_token_invalid_email( - "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_bad_format(self) -> None: + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_domain_to_lower(self): - self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + def test_add_email_domain_to_lower(self) -> None: + self._add_email("foo@TEST.BAR", "foo@test.bar") - def test_add_email_domain_with_umlaut(self): - self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) + def test_add_email_domain_with_umlaut(self) -> None: + self._add_email("foo@Öumlaut.com", "foo@öumlaut.com") - def test_add_email_address_casefold(self): - self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + def test_add_email_address_casefold(self) -> None: + self._add_email("Strauß@Example.com", "strauss@example.com") - def test_address_trim(self): - self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + def test_address_trim(self) -> None: + self._add_email(" foo@test.bar ", "foo@test.bar") @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_ip(self): + def test_ratelimit_by_ip(self) -> None: """Tests that adding emails is ratelimited by IP""" # We expect to be able to set three emails before getting ratelimited. - self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) - self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) - self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + self._add_email("foo1@test.bar", "foo1@test.bar") + self._add_email("foo2@test.bar", "foo2@test.bar") + self._add_email("foo3@test.bar", "foo3@test.bar") with self.assertRaises(HttpResponseException) as cm: - self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + self._add_email("foo4@test.bar", "foo4@test.bar") self.assertEqual(cm.exception.code, 429) - def test_add_email_if_disabled(self): + def test_add_email_if_disabled(self) -> None: """Test adding email to profile when doing so is disallowed""" self.hs.config.registration.enable_3pid_changes = False @@ -695,7 +711,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -705,10 +721,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_delete_email(self): + def test_delete_email(self) -> None: """Test deleting an email from profile""" # Add a threepid self.get_success( @@ -727,7 +743,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -736,10 +752,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_delete_email_if_disabled(self): + def test_delete_email_if_disabled(self) -> None: """Test deleting an email from profile when disallowed""" self.hs.config.registration.enable_3pid_changes = False @@ -761,7 +777,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -771,11 +787,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - def test_cant_add_email_without_clicking_link(self): + def test_cant_add_email_without_clicking_link(self) -> None: """Test that we do actually need to click the link in the email""" client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -797,7 +813,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -807,10 +823,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_no_valid_token(self): + def test_no_valid_token(self) -> None: """Test that we do actually need to request a token and can't just make a session up. """ @@ -832,7 +848,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -842,11 +858,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @override_config({"next_link_domain_whitelist": None}) - def test_next_link(self): + def test_next_link(self) -> None: """Tests a valid next_link parameter value with no whitelist (good case)""" self._request_token( "something@example.com", @@ -856,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": None}) - def test_next_link_exotic_protocol(self): + def test_next_link_exotic_protocol(self) -> None: """Tests using a esoteric protocol as a next_link parameter value. Someone may be hosting a client on IPFS etc. """ @@ -868,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": None}) - def test_next_link_file_uri(self): + def test_next_link_file_uri(self) -> None: """Tests next_link parameters cannot be file URI""" # Attempt to use a next_link value that points to the local disk self._request_token( @@ -879,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) - def test_next_link_domain_whitelist(self): + def test_next_link_domain_whitelist(self) -> None: """Tests next_link parameters must fit the whitelist if provided""" # Ensure not providing a next_link parameter still works @@ -912,7 +928,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": []}) - def test_empty_next_link_domain_whitelist(self): + def test_empty_next_link_domain_whitelist(self) -> None: """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially disallowed """ @@ -962,28 +978,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _request_token_invalid_email( self, - email, - expected_errcode, - expected_error, - client_secret="foobar", - ): + email: str, + expected_errcode: str, + expected_error: str, + client_secret: str = "foobar", + ) -> None: channel = self.make_request( "POST", b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_error, channel.json_body["error"]) - def _validate_token(self, link): + def _validate_token(self, link: str) -> None: # Remove the host path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) self.assertEqual(200, channel.code, channel.result) - def _get_link_from_email(self): + def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" raw_msg = self.email_attempts[-1].decode("UTF-8") @@ -998,12 +1014,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): if not text: self.fail("Could not find text portion of email to parse") + assert text is not None match = re.search(r"https://example.com\S+", text) assert match, "Could not find link in email" return match.group(0) - def _add_email(self, request_email, expected_email): + def _add_email(self, request_email: str, expected_email: str) -> None: """Test adding an email to profile""" previous_email_attempts = len(self.email_attempts) @@ -1030,7 +1047,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -1039,7 +1056,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} @@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): url = "/_matrix/client/unstable/org.matrix.msc3720/account_status" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["experimental_features"] = {"msc3720_enabled": True} return self.setup_test_homeserver(config=config) - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.requester = self.register_user("requester", "password") self.requester_tok = self.login("requester", "password") - self.server_name = homeserver.config.server.server_name + self.server_name = hs.config.server.server_name - def test_missing_mxid(self): + def test_missing_mxid(self) -> None: """Tests that not providing any MXID raises an error.""" self._test_status( users=None, @@ -1074,7 +1091,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_errcode=Codes.MISSING_PARAM, ) - def test_invalid_mxid(self): + def test_invalid_mxid(self) -> None: """Tests that providing an invalid MXID raises an error.""" self._test_status( users=["bad:test"], @@ -1082,7 +1099,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_errcode=Codes.INVALID_PARAM, ) - def test_local_user_not_exists(self): + def test_local_user_not_exists(self) -> None: """Tests that the account status endpoints correctly reports that a user doesn't exist. """ @@ -1098,7 +1115,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[], ) - def test_local_user_exists(self): + def test_local_user_exists(self) -> None: """Tests that the account status endpoint correctly reports that a user doesn't exist. """ @@ -1115,7 +1132,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[], ) - def test_local_user_deactivated(self): + def test_local_user_deactivated(self) -> None: """Tests that the account status endpoint correctly reports a deactivated user.""" user = self.register_user("someuser", "password") self.get_success( @@ -1135,7 +1152,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[], ) - def test_mixed_local_and_remote_users(self): + def test_mixed_local_and_remote_users(self) -> None: """Tests that if some users are remote the account status endpoint correctly merges the remote responses with the local result. """ @@ -1150,7 +1167,13 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): "@bad:badremote", ] - async def post_json(destination, path, data, *a, **kwa): + async def post_json( + destination: str, + path: str, + data: Optional[JsonDict] = None, + *a: Any, + **kwa: Any, + ) -> Union[JsonDict, list]: if destination == "remote": return { "account_statuses": { @@ -1160,9 +1183,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): }, } } - if destination == "otherremote": - return {} - if destination == "badremote": + elif destination == "badremote": # badremote tries to overwrite the status of a user that doesn't belong # to it (i.e. users[1]) with false data, which Synapse is expected to # ignore. @@ -1176,6 +1197,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): }, } } + # if destination == "otherremote" + else: + return {} # Register a mock that will return the expected result depending on the remote. self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) @@ -1205,7 +1229,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_failures: Optional[List[str]] = None, expected_errcode: Optional[str] = None, - ): + ) -> None: """Send a request to the account status endpoint and check that the response matches with what's expected. diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 5c31a54421..823e8ab8c4 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.rest.client import filter +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -30,11 +32,11 @@ class FilterTestCase(unittest.HomeserverTestCase): EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' servlets = [filter.register_servlets] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.filtering = hs.get_filtering() self.store = hs.get_datastores().main - def test_add_filter(self): + def test_add_filter(self) -> None: channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), @@ -43,11 +45,13 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, {"filter_id": "0"}) - filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) + filter = self.get_success( + self.store.get_user_filter(user_localpart="apple", filter_id=0) + ) self.pump() - self.assertEqual(filter.result, self.EXAMPLE_FILTER) + self.assertEqual(filter, self.EXAMPLE_FILTER) - def test_add_filter_for_other_user(self): + def test_add_filter_for_other_user(self) -> None: channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), @@ -57,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) - def test_add_filter_non_local_user(self): + def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False channel = self.make_request( @@ -70,14 +74,13 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) - def test_get_filter(self): - filter_id = defer.ensureDeferred( + def test_get_filter(self) -> None: + filter_id = self.get_success( self.filtering.add_user_filter( user_localpart="apple", user_filter=self.EXAMPLE_FILTER ) ) self.reactor.advance(1) - filter_id = filter_id.result channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) @@ -85,7 +88,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) - def test_get_filter_non_existant(self): + def test_get_filter_non_existant(self) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) @@ -95,7 +98,7 @@ class FilterTestCase(unittest.HomeserverTestCase): # Currently invalid params do not have an appropriate errcode # in errors.py - def test_get_filter_invalid_id(self): + def test_get_filter_invalid_id(self) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) @@ -103,7 +106,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"400") # No ID also returns an invalid_id error - def test_get_filter_no_id(self): + def test_get_filter_no_id(self) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index a087cd7b21..709f851a38 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import itertools import urllib.parse -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -45,7 +45,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): ] hijack_auth = False - def default_config(self) -> dict: + def default_config(self) -> Dict[str, Any]: # We need to enable msc1849 support for aggregations config = super().default_config() diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py index ee6b0b9ebf..20a259fc43 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py @@ -14,8 +14,13 @@ import json +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, report_event, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") @@ -42,35 +47,35 @@ class ReportEventTestCase(unittest.HomeserverTestCase): self.event_id = resp["event_id"] self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" - def test_reason_str_and_score_int(self): + def test_reason_str_and_score_int(self) -> None: data = {"reason": "this makes me sad", "score": -100} self._assert_status(200, data) - def test_no_reason(self): + def test_no_reason(self) -> None: data = {"score": 0} self._assert_status(200, data) - def test_no_score(self): + def test_no_score(self) -> None: data = {"reason": "this makes me sad"} self._assert_status(200, data) - def test_no_reason_and_no_score(self): - data = {} + def test_no_reason_and_no_score(self) -> None: + data: JsonDict = {} self._assert_status(200, data) - def test_reason_int_and_score_str(self): + def test_reason_int_and_score_str(self) -> None: data = {"reason": 10, "score": "string"} self._assert_status(400, data) - def test_reason_zero_and_score_blank(self): + def test_reason_zero_and_score_blank(self) -> None: data = {"reason": 0, "score": ""} self._assert_status(400, data) - def test_reason_and_score_null(self): + def test_reason_and_score_null(self) -> None: data = {"reason": None, "score": None} self._assert_status(400, data) - def _assert_status(self, response_status, data): + def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( "POST", self.report_path, diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index e0b11e7264..37866ee330 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,11 +18,12 @@ """Tests REST events for /rooms paths.""" import json -from typing import Iterable, List +from typing import Any, Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( @@ -35,7 +36,9 @@ from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync +from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -45,11 +48,11 @@ PATH_PREFIX = b"/_matrix/client/api/v1" class RoomBase(unittest.HomeserverTestCase): - rmcreator_id = None + rmcreator_id: Optional[str] = None servlets = [room.register_servlets, room.register_deprecated_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver( "red", @@ -57,15 +60,15 @@ class RoomBase(unittest.HomeserverTestCase): federation_client=Mock(), ) - self.hs.get_federation_handler = Mock() + self.hs.get_federation_handler = Mock() # type: ignore[assignment] self.hs.get_federation_handler.return_value.maybe_backfill = Mock( return_value=make_awaitable(None) ) - async def _insert_client_ip(*args, **kwargs): + async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - self.hs.get_datastores().main.insert_client_ip = _insert_client_ip + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] return self.hs @@ -76,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase): user_id = "@sid1:red" rmcreator_id = "@notme:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id @@ -108,12 +111,12 @@ class RoomPermissionsTestCase(RoomBase): # auth as user_id now self.helper.auth_user_id = self.user_id - def test_can_do_action(self): + def test_can_do_action(self) -> None: msg_content = b'{"msgtype":"m.text","body":"hello"}' seq = iter(range(100)) - def send_msg_path(): + def send_msg_path() -> str: return "/rooms/%s/send/m.room.message/mid%s" % ( self.created_rmid, str(next(seq)), @@ -148,7 +151,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_topic_perms(self): + def test_topic_perms(self) -> None: topic_content = b'{"topic":"My Topic Name"}' topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid @@ -214,14 +217,14 @@ class RoomPermissionsTestCase(RoomBase): self.assertEqual(403, channel.code, msg=channel.result["body"]) def _test_get_membership( - self, room=None, members: Iterable = frozenset(), expect_code=None - ): + self, room: str, members: Iterable = frozenset(), expect_code: int = 200 + ) -> None: for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) self.assertEqual(expect_code, channel.code) - def test_membership_basic_room_perms(self): + def test_membership_basic_room_perms(self) -> None: # === room does not exist === room = self.uncreated_rmid # get membership of self, get membership of other, uncreated room @@ -241,7 +244,7 @@ class RoomPermissionsTestCase(RoomBase): self.helper.join(room=room, user=usr, expect_code=404) self.helper.leave(room=room, user=usr, expect_code=404) - def test_membership_private_room_perms(self): + def test_membership_private_room_perms(self) -> None: room = self.created_rmid # get membership of self, get membership of other, private room + invite # expect all 403s @@ -264,7 +267,7 @@ class RoomPermissionsTestCase(RoomBase): members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 ) - def test_membership_public_room_perms(self): + def test_membership_public_room_perms(self) -> None: room = self.created_public_rmid # get membership of self, get membership of other, public room + invite # expect 403 @@ -287,7 +290,7 @@ class RoomPermissionsTestCase(RoomBase): members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 ) - def test_invited_permissions(self): + def test_invited_permissions(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) @@ -310,7 +313,7 @@ class RoomPermissionsTestCase(RoomBase): expect_code=403, ) - def test_joined_permissions(self): + def test_joined_permissions(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.join(room=room, user=self.user_id) @@ -348,7 +351,7 @@ class RoomPermissionsTestCase(RoomBase): # set left of self, expect 200 self.helper.leave(room=room, user=self.user_id) - def test_leave_permissions(self): + def test_leave_permissions(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.join(room=room, user=self.user_id) @@ -383,7 +386,7 @@ class RoomPermissionsTestCase(RoomBase): ) # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember - def test_member_event_from_ban(self): + def test_member_event_from_ban(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.join(room=room, user=self.user_id) @@ -475,21 +478,21 @@ class RoomsMemberListTestCase(RoomBase): user_id = "@sid1:red" - def test_get_member_list(self): + def test_get_member_list(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEqual(200, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_room(self): + def test_get_member_list_no_room(self) -> None: channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission(self): + def test_get_member_list_no_permission(self) -> None: room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission_with_at_token(self): + def test_get_member_list_no_permission_with_at_token(self) -> None: """ Tests that a stranger to the room cannot get the member list (in the case that they use an at token). @@ -509,7 +512,7 @@ class RoomsMemberListTestCase(RoomBase): ) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission_former_member(self): + def test_get_member_list_no_permission_former_member(self) -> None: """ Tests that a former member of the room can not get the member list. """ @@ -529,7 +532,7 @@ class RoomsMemberListTestCase(RoomBase): channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission_former_member_with_at_token(self): + def test_get_member_list_no_permission_former_member_with_at_token(self) -> None: """ Tests that a former member of the room can not get the member list (in the case that they use an at token). @@ -569,7 +572,7 @@ class RoomsMemberListTestCase(RoomBase): ) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_mixed_memberships(self): + def test_get_member_list_mixed_memberships(self) -> None: room_creator = "@some_other_guy:red" room_id = self.helper.create_room_as(room_creator) room_path = "/rooms/%s/members" % room_id @@ -594,26 +597,26 @@ class RoomsCreateTestCase(RoomBase): user_id = "@sid1:red" - def test_post_room_no_keys(self): + def test_post_room_no_keys(self) -> None: # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") self.assertEqual(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) - def test_post_room_visibility_key(self): + def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) - def test_post_room_custom_key(self): + def test_post_room_custom_key(self) -> None: # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) - def test_post_room_known_and_unknown_keys(self): + def test_post_room_known_and_unknown_keys(self) -> None: # POST with custom + known config keys, expect new room id channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' @@ -621,7 +624,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) - def test_post_room_invalid_content(self): + def test_post_room_invalid_content(self) -> None: # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') self.assertEqual(400, channel.code) @@ -629,7 +632,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request("POST", "/createRoom", b'["hello"]') self.assertEqual(400, channel.code) - def test_post_room_invitees_invalid_mxid(self): + def test_post_room_invitees_invalid_mxid(self) -> None: # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 # Note the trailing space in the MXID here! channel = self.make_request( @@ -638,7 +641,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(400, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) - def test_post_room_invitees_ratelimit(self): + def test_post_room_invitees_ratelimit(self) -> None: """Test that invites sent when creating a room are ratelimited by a RateLimiter, which ratelimits them correctly, including by not limiting when the requester is exempt from ratelimiting. @@ -674,7 +677,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request("POST", "/createRoom", content) self.assertEqual(200, channel.code) - def test_spam_checker_may_join_room(self): + def test_spam_checker_may_join_room(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. """ @@ -704,12 +707,12 @@ class RoomTopicTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # create the room self.room_id = self.helper.create_room_as(self.user_id) self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) - def test_invalid_puts(self): + def test_invalid_puts(self) -> None: # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") self.assertEqual(400, channel.code, msg=channel.result["body"]) @@ -736,7 +739,7 @@ class RoomTopicTestCase(RoomBase): channel = self.make_request("PUT", self.path, content) self.assertEqual(400, channel.code, msg=channel.result["body"]) - def test_rooms_topic(self): + def test_rooms_topic(self) -> None: # nothing should be there channel = self.make_request("GET", self.path) self.assertEqual(404, channel.code, msg=channel.result["body"]) @@ -751,7 +754,7 @@ class RoomTopicTestCase(RoomBase): self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) - def test_rooms_topic_with_extra_keys(self): + def test_rooms_topic_with_extra_keys(self) -> None: # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) @@ -768,10 +771,10 @@ class RoomMemberStateTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_invalid_puts(self): + def test_invalid_puts(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") @@ -801,7 +804,7 @@ class RoomMemberStateTestCase(RoomBase): channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEqual(400, channel.code, msg=channel.result["body"]) - def test_rooms_members_self(self): + def test_rooms_members_self(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % ( urlparse.quote(self.room_id), self.user_id, @@ -812,13 +815,13 @@ class RoomMemberStateTestCase(RoomBase): channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEqual(200, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, content=b"") self.assertEqual(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEqual(expected_response, channel.json_body) - def test_rooms_members_other(self): + def test_rooms_members_other(self) -> None: self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( urlparse.quote(self.room_id), @@ -830,11 +833,11 @@ class RoomMemberStateTestCase(RoomBase): channel = self.make_request("PUT", path, content) self.assertEqual(200, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, content=b"") self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) - def test_rooms_members_other_custom_keys(self): + def test_rooms_members_other_custom_keys(self) -> None: self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( urlparse.quote(self.room_id), @@ -849,7 +852,7 @@ class RoomMemberStateTestCase(RoomBase): channel = self.make_request("PUT", path, content) self.assertEqual(200, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, content=b"") self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) @@ -866,7 +869,7 @@ class RoomInviteRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invites_by_rooms_ratelimit(self): + def test_invites_by_rooms_ratelimit(self) -> None: """Tests that invites in a room are actually rate-limited.""" room_id = self.helper.create_room_as(self.user_id) @@ -878,7 +881,7 @@ class RoomInviteRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invites_by_users_ratelimit(self): + def test_invites_by_users_ratelimit(self) -> None: """Tests that invites to a specific user are actually rate-limited.""" for _ in range(3): @@ -897,7 +900,7 @@ class RoomJoinTestCase(RoomBase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user1 = self.register_user("thomas", "hackme") self.tok1 = self.login("thomas", "hackme") @@ -908,7 +911,7 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - def test_spam_checker_may_join_room(self): + def test_spam_checker_may_join_room(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. """ @@ -975,8 +978,8 @@ class RoomJoinRatelimitTestCase(RoomBase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): - super().prepare(reactor, clock, homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) # profile changes expect that the user is actually registered user = UserID.from_string(self.user_id) self.get_success(self.register_user(user.localpart, "supersecretpassword")) @@ -984,7 +987,7 @@ class RoomJoinRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) - def test_join_local_ratelimit(self): + def test_join_local_ratelimit(self) -> None: """Tests that local joins are actually rate-limited.""" for _ in range(3): self.helper.create_room_as(self.user_id) @@ -994,7 +997,7 @@ class RoomJoinRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) - def test_join_local_ratelimit_profile_change(self): + def test_join_local_ratelimit_profile_change(self) -> None: """Tests that sending a profile update into all of the user's joined rooms isn't rate-limited by the rate-limiter on joins.""" @@ -1031,7 +1034,7 @@ class RoomJoinRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) - def test_join_local_ratelimit_idempotent(self): + def test_join_local_ratelimit_idempotent(self) -> None: """Tests that the room join endpoints remain idempotent despite rate-limiting on room joins.""" room_id = self.helper.create_room_as(self.user_id) @@ -1056,7 +1059,7 @@ class RoomJoinRatelimitTestCase(RoomBase): "autocreate_auto_join_rooms": True, }, ) - def test_autojoin_rooms(self): + def test_autojoin_rooms(self) -> None: user_id = self.register_user("testuser", "password") # Check that the new user successfully joined the four rooms @@ -1071,10 +1074,10 @@ class RoomMessagesTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_invalid_puts(self): + def test_invalid_puts(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") @@ -1095,7 +1098,7 @@ class RoomMessagesTestCase(RoomBase): channel = self.make_request("PUT", path, b"") self.assertEqual(400, channel.code, msg=channel.result["body"]) - def test_rooms_messages_sent(self): + def test_rooms_messages_sent(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' @@ -1119,11 +1122,11 @@ class RoomInitialSyncTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # create the room self.room_id = self.helper.create_room_as(self.user_id) - def test_initial_sync(self): + def test_initial_sync(self) -> None: channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) self.assertEqual(200, channel.code) @@ -1131,7 +1134,7 @@ class RoomInitialSyncTestCase(RoomBase): self.assertEqual("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict - state = {} + state: JsonDict = {} for event in channel.json_body["state"]: if "state_key" not in event: continue @@ -1160,10 +1163,10 @@ class RoomMessageListTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_topo_token_is_accepted(self): + def test_topo_token_is_accepted(self) -> None: token = "t1-0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) @@ -1174,7 +1177,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_stream_token_is_accepted_for_fwd_pagianation(self): + def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: token = "s0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) @@ -1185,7 +1188,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_room_messages_purge(self): + def test_room_messages_purge(self) -> None: store = self.hs.get_datastores().main pagination_handler = self.hs.get_pagination_handler() @@ -1278,10 +1281,10 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Register the user who does the searching - self.user_id = self.register_user("user", "pass") + self.user_id2 = self.register_user("user", "pass") self.access_token = self.login("user", "pass") # Register the user who sends the message @@ -1289,12 +1292,12 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.other_access_token = self.login("otheruser", "pass") # Create a room - self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.room = self.helper.create_room_as(self.user_id2, tok=self.access_token) # Invite the other person self.helper.invite( room=self.room, - src=self.user_id, + src=self.user_id2, tok=self.access_token, targ=self.other_user_id, ) @@ -1304,7 +1307,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): room=self.room, user=self.other_user_id, tok=self.other_access_token ) - def test_finds_message(self): + def test_finds_message(self) -> None: """ The search functionality will search for content in messages if asked to do so. @@ -1333,7 +1336,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): # No context was requested, so we should get none. self.assertEqual(results["results"][0]["context"], {}) - def test_include_context(self): + def test_include_context(self) -> None: """ When event_context includes include_profile, profile information will be included in the search response. @@ -1379,7 +1382,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.url = b"/_matrix/client/r0/publicRooms" @@ -1389,11 +1392,11 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): return self.hs - def test_restricted_no_auth(self): + def test_restricted_no_auth(self) -> None: channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401, channel.result) - def test_restricted_auth(self): + def test_restricted_auth(self) -> None: self.register_user("user", "pass") tok = self.login("user", "pass") @@ -1412,19 +1415,19 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(federation_client=Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.register_user("user", "pass") self.token = self.login("user", "pass") self.federation_client = hs.get_federation_client() - def test_simple(self): + def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = ( - lambda *a, **k: defer.succeed({}) + self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined] + {} ) search_filter = {"generic_search_term": "foobar"} @@ -1437,7 +1440,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.result) - self.federation_client.get_public_rooms.assert_called_once_with( + self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined] "testserv", limit=100, since_token=None, @@ -1446,12 +1449,12 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): third_party_instance_id=None, ) - def test_fallback(self): + def test_fallback(self) -> None: "Test that searching public rooms over federation falls back if it gets a 404" # The `get_public_rooms` should be called again if the first call fails # with a 404, when using search filters. - self.federation_client.get_public_rooms.side_effect = ( + self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(404, "Not Found", b""), defer.succeed({}), ) @@ -1466,7 +1469,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.result) - self.federation_client.get_public_rooms.assert_has_calls( + self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined] [ call( "testserv", @@ -1497,14 +1500,14 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["allow_per_room_profiles"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") @@ -1522,7 +1525,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_per_room_profile_forbidden(self): + def test_per_room_profile_forbidden(self) -> None: data = {"membership": "join", "displayname": "other test user"} request_data = json.dumps(data) channel = self.make_request( @@ -1557,7 +1560,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.creator = self.register_user("creator", "test") self.creator_tok = self.login("creator", "test") @@ -1566,7 +1569,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) - def test_join_reason(self): + def test_join_reason(self) -> None: reason = "hello" channel = self.make_request( "POST", @@ -1578,7 +1581,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_leave_reason(self): + def test_leave_reason(self) -> None: self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" @@ -1592,7 +1595,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_kick_reason(self): + def test_kick_reason(self) -> None: self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" @@ -1606,7 +1609,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_ban_reason(self): + def test_ban_reason(self) -> None: self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" @@ -1620,7 +1623,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_unban_reason(self): + def test_unban_reason(self) -> None: reason = "hello" channel = self.make_request( "POST", @@ -1632,7 +1635,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_invite_reason(self): + def test_invite_reason(self) -> None: reason = "hello" channel = self.make_request( "POST", @@ -1644,7 +1647,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_reject_invite_reason(self): + def test_reject_invite_reason(self) -> None: self.helper.invite( self.room_id, src=self.creator, @@ -1663,7 +1666,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def _check_for_reason(self, reason): + def _check_for_reason(self, reason: str) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( @@ -1704,12 +1707,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): "org.matrix.not_labels": ["#notfun"], } - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_context_filter_labels(self): + def test_context_filter_labels(self) -> None: """Test that we can filter by a label on a /context request.""" event_id = self._send_labelled_messages_in_room() @@ -1739,7 +1742,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): events_after[0]["content"]["body"], "with right label", events_after[0] ) - def test_context_filter_not_labels(self): + def test_context_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /context request.""" event_id = self._send_labelled_messages_in_room() @@ -1772,7 +1775,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): events_after[1]["content"]["body"], "with two wrong labels", events_after[1] ) - def test_context_filter_labels_not_labels(self): + def test_context_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label on a /context request. """ @@ -1801,7 +1804,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): events_after[0]["content"]["body"], "with wrong label", events_after[0] ) - def test_messages_filter_labels(self): + def test_messages_filter_labels(self) -> None: """Test that we can filter by a label on a /messages request.""" self._send_labelled_messages_in_room() @@ -1818,7 +1821,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - def test_messages_filter_not_labels(self): + def test_messages_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /messages request.""" self._send_labelled_messages_in_room() @@ -1839,7 +1842,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): events[3]["content"]["body"], "with two wrong labels", events[3] ) - def test_messages_filter_labels_not_labels(self): + def test_messages_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label on a /messages request. """ @@ -1862,7 +1865,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - def test_search_filter_labels(self): + def test_search_filter_labels(self) -> None: """Test that we can filter by a label on a /search request.""" request_data = json.dumps( { @@ -1899,7 +1902,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results[1]["result"]["content"]["body"], ) - def test_search_filter_not_labels(self): + def test_search_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /search request.""" request_data = json.dumps( { @@ -1946,7 +1949,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results[3]["result"]["content"]["body"], ) - def test_search_filter_labels_not_labels(self): + def test_search_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label on a /search request. """ @@ -1980,7 +1983,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results[0]["result"]["content"]["body"], ) - def _send_labelled_messages_in_room(self): + def _send_labelled_messages_in_room(self) -> str: """Sends several messages to a room with different labels (or without any) to test filtering by label. Returns: @@ -2056,12 +2059,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["experimental_features"] = {"msc3440_enabled": True} return config - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) @@ -2136,7 +2139,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): return channel.json_body["chunk"] - def test_filter_relation_senders(self): + def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. filter = {"io.element.relation_senders": [self.second_user_id]} chunk = self._filter_messages(filter) @@ -2159,7 +2162,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] ) - def test_filter_relation_type(self): + def test_filter_relation_type(self) -> None: # Messages which have annotations. filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) @@ -2185,7 +2188,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] ) - def test_filter_relation_senders_and_type(self): + def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. filter = { "io.element.relation_senders": [self.second_user_id], @@ -2205,7 +2208,7 @@ class ContextTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.tok = self.login("user", "password") self.room_id = self.helper.create_room_as( @@ -2218,7 +2221,7 @@ class ContextTestCase(unittest.HomeserverTestCase): self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) - def test_erased_sender(self): + def test_erased_sender(self) -> None: """Test that an erasure request results in the requester's events being hidden from any new member of the room. """ @@ -2332,7 +2335,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -2340,17 +2343,17 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): self.room_owner, tok=self.room_owner_tok ) - def test_no_aliases(self): + def test_no_aliases(self) -> None: res = self._get_aliases(self.room_owner_tok) self.assertEqual(res["aliases"], []) - def test_not_in_room(self): + def test_not_in_room(self) -> None: self.register_user("user", "test") user_tok = self.login("user", "test") res = self._get_aliases(user_tok, expected_code=403) self.assertEqual(res["errcode"], "M_FORBIDDEN") - def test_admin_user(self): + def test_admin_user(self) -> None: alias1 = self._random_alias() self._set_alias_via_directory(alias1) @@ -2360,7 +2363,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): res = self._get_aliases(user_tok) self.assertEqual(res["aliases"], [alias1]) - def test_with_aliases(self): + def test_with_aliases(self) -> None: alias1 = self._random_alias() alias2 = self._random_alias() @@ -2370,7 +2373,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): res = self._get_aliases(self.room_owner_tok) self.assertEqual(set(res["aliases"]), {alias1, alias2}) - def test_peekable_room(self): + def test_peekable_room(self) -> None: alias1 = self._random_alias() self._set_alias_via_directory(alias1) @@ -2404,7 +2407,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _random_alias(self) -> str: return RoomAlias(random_string(5), self.hs.hostname).to_string() - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias data = {"room_id": self.room_id} request_data = json.dumps(data) @@ -2423,7 +2426,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -2434,7 +2437,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): self.alias = "#alias:test" self._set_alias_via_directory(self.alias) - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias data = {"room_id": self.room_id} request_data = json.dumps(data) @@ -2456,7 +2459,9 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): self.assertIsInstance(res, dict) return res - def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + def _set_canonical_alias( + self, content: JsonDict, expected_code: int = 200 + ) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" channel = self.make_request( "PUT", @@ -2469,7 +2474,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): self.assertIsInstance(res, dict) return res - def test_canonical_alias(self): + def test_canonical_alias(self) -> None: """Test a basic alias message.""" # There is no canonical alias to start with. self._get_canonical_alias(expected_code=404) @@ -2488,7 +2493,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): res = self._get_canonical_alias() self.assertEqual(res, {}) - def test_alt_aliases(self): + def test_alt_aliases(self) -> None: """Test a canonical alias message with alt_aliases.""" # Create an alias. self._set_canonical_alias({"alt_aliases": [self.alias]}) @@ -2504,7 +2509,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): res = self._get_canonical_alias() self.assertEqual(res, {}) - def test_alias_alt_aliases(self): + def test_alias_alt_aliases(self) -> None: """Test a canonical alias message with an alias and alt_aliases.""" # Create an alias. self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) @@ -2520,7 +2525,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): res = self._get_canonical_alias() self.assertEqual(res, {}) - def test_partial_modify(self): + def test_partial_modify(self) -> None: """Test removing only the alt_aliases.""" # Create an alias. self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) @@ -2536,7 +2541,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): res = self._get_canonical_alias() self.assertEqual(res, {"alias": self.alias}) - def test_add_alias(self): + def test_add_alias(self) -> None: """Test removing only the alt_aliases.""" # Create an additional alias. second_alias = "#second:test" @@ -2556,7 +2561,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} ) - def test_bad_data(self): + def test_bad_data(self) -> None: """Invalid data for alt_aliases should cause errors.""" self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": None}, expected_code=400) @@ -2566,7 +2571,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): self._set_canonical_alias({"alt_aliases": True}, expected_code=400) self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) - def test_bad_alias(self): + def test_bad_alias(self) -> None: """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) @@ -2580,13 +2585,13 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("thomas", "hackme") self.tok = self.login("thomas", "hackme") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_threepid_invite_spamcheck(self): + def test_threepid_invite_spamcheck(self) -> None: # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we # can check its call_count later on during the test. diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index bfc04785b7..58f1ea11b7 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import threading -from typing import TYPE_CHECKING, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError +from synapse.api.room_versions import RoomVersion from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin from synapse.rest.client import account, login, profile, room +from synapse.server import HomeServer from synapse.types import JsonDict, Requester, StateMap +from synapse.util import Clock from synapse.util.frozenutils import unfreeze from tests import unittest @@ -34,7 +40,7 @@ thread_local = threading.local() class LegacyThirdPartyRulesTestModule: - def __init__(self, config: Dict, module_api: "ModuleApi"): + def __init__(self, config: Dict, module_api: "ModuleApi") -> None: # keep a record of the "current" rules module, so that the test can patch # it if desired. thread_local.rules_module = self @@ -42,32 +48,36 @@ class LegacyThirdPartyRulesTestModule: async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ): + ) -> bool: return True - async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + async def check_event_allowed( + self, event: EventBase, state: StateMap[EventBase] + ) -> Union[bool, dict]: return True @staticmethod - def parse_config(config): + def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): - def __init__(self, config: Dict, module_api: "ModuleApi"): + def __init__(self, config: Dict, module_api: "ModuleApi") -> None: super().__init__(config, module_api) - def on_create_room( + async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ): + ) -> bool: return False class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): - def __init__(self, config: Dict, module_api: "ModuleApi"): + def __init__(self, config: Dict, module_api: "ModuleApi") -> None: super().__init__(config, module_api) - async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + async def check_event_allowed( + self, event: EventBase, state: StateMap[EventBase] + ) -> JsonDict: d = event.get_dict() content = unfreeze(event.content) content["foo"] = "bar" @@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() load_legacy_third_party_event_rules(hs) @@ -94,22 +104,30 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Note that these checks are not relevant to this test case. # Have this homeserver auto-approve all event signature checking. - async def approve_all_signature_checking(_, pdu): + async def approve_all_signature_checking( + _: RoomVersion, pdu: EventBase + ) -> EventBase: return pdu - hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking + hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment] # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth(origin, event, context, *args, **kwargs): + async def _check_event_auth( + origin: str, + event: EventBase, + context: EventContext, + *args: Any, + **kwargs: Any, + ) -> EventContext: return context - hs.get_federation_event_handler()._check_event_auth = _check_event_auth + hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] return hs - def prepare(self, reactor, clock, homeserver): - super().prepare(reactor, clock, homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) # Create some users and a room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.invitee = self.register_user("invitee", "hackme") @@ -121,13 +139,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): except Exception: pass - def test_third_party_rules(self): + def test_third_party_rules(self) -> None: """Tests that a forbidden event is forbidden from being sent, but an allowed one can be sent. """ # patch the rules module with a Mock which will return False for some event # types - async def check(ev, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: return ev.type != "foo.bar.forbidden", None callback = Mock(spec=[], side_effect=check) @@ -161,7 +181,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): ) self.assertEqual(channel.result["code"], b"403", channel.result) - def test_third_party_rules_workaround_synapse_errors_pass_through(self): + def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None: """ Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042 is functional: that SynapseErrors are passed through from check_event_allowed @@ -172,7 +192,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """ class NastyHackException(SynapseError): - def error_dict(self): + def error_dict(self) -> JsonDict: """ This overrides SynapseError's `error_dict` to nastily inject JSON into the error response. @@ -182,7 +202,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): return result # add a callback that will raise our hacky exception - async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]: + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: raise NastyHackException(429, "message") self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] @@ -202,11 +224,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"}, ) - def test_cannot_modify_event(self): + def test_cannot_modify_event(self) -> None: """cannot accidentally modify an event before it is persisted""" # first patch the event checker so that it will try to modify the event - async def check(ev: EventBase, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: ev.content = {"x": "y"} return True, None @@ -223,10 +247,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # 500 Internal Server Error self.assertEqual(channel.code, 500, channel.result) - def test_modify_event(self): + def test_modify_event(self) -> None: """The module can return a modified version of the event""" # first patch the event checker so that it will modify the event - async def check(ev: EventBase, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: d = ev.get_dict() d["content"] = {"x": "y"} return True, d @@ -253,10 +279,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") - def test_message_edit(self): + def test_message_edit(self) -> None: """Ensure that the module doesn't cause issues with edited messages.""" # first patch the event checker so that it will modify the event - async def check(ev: EventBase, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: d = ev.get_dict() d["content"] = { "msgtype": "m.text", @@ -315,7 +343,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): ev = channel.json_body self.assertEqual(ev["content"]["body"], "EDITED BODY") - def test_send_event(self): + def test_send_event(self) -> None: """Tests that a module can send an event into a room via the module api""" content = { "msgtype": "m.text", @@ -344,7 +372,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): } } ) - def test_legacy_check_event_allowed(self): + def test_legacy_check_event_allowed(self) -> None: """Tests that the wrapper for legacy check_event_allowed callbacks works correctly. """ @@ -379,13 +407,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): } } ) - def test_legacy_on_create_room(self): + def test_legacy_on_create_room(self) -> None: """Tests that the wrapper for legacy on_create_room callbacks works correctly. """ self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) - def test_sent_event_end_up_in_room_state(self): + def test_sent_event_end_up_in_room_state(self) -> None: """Tests that a state event sent by a module while processing another state event doesn't get dropped from the state of the room. This is to guard against a bug where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830 @@ -400,7 +428,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): api = self.hs.get_module_api() # Define a callback that sends a custom event on power levels update. - async def test_fn(event: EventBase, state_events): + async def test_fn( + event: EventBase, state_events: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: if event.is_state and event.type == EventTypes.PowerLevels: await api.create_and_send_event_into_room( { @@ -436,7 +466,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["i"], i) - def test_on_new_event(self): + def test_on_new_event(self) -> None: """Test that the on_new_event callback is called on new events""" on_new_event = Mock(make_awaitable(None)) self.hs.get_third_party_event_rules()._on_new_event_callbacks.append( @@ -501,7 +531,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) - def _update_power_levels(self, event_default: int = 0): + def _update_power_levels(self, event_default: int = 0) -> None: """Updates the room's power levels. Args: @@ -533,7 +563,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): tok=self.tok, ) - def test_on_profile_update(self): + def test_on_profile_update(self) -> None: """Tests that the on_profile_update module callback is correctly called on profile updates. """ @@ -592,7 +622,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(profile_info.display_name, displayname) self.assertEqual(profile_info.avatar_url, avatar_url) - def test_on_profile_update_admin(self): + def test_on_profile_update_admin(self) -> None: """Tests that the on_profile_update module callback is correctly called on profile updates triggered by a server admin. """ @@ -634,7 +664,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(profile_info.display_name, displayname) self.assertEqual(profile_info.avatar_url, avatar_url) - def test_on_user_deactivation_status_changed(self): + def test_on_user_deactivation_status_changed(self) -> None: """Tests that the on_user_deactivation_status_changed module callback is called correctly when processing a user's deactivation. """ @@ -691,7 +721,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): args = profile_mock.call_args[0] self.assertTrue(args[3]) - def test_on_user_deactivation_status_changed_admin(self): + def test_on_user_deactivation_status_changed_admin(self) -> None: """Tests that the on_user_deactivation_status_changed module callback is called correctly when processing a user's deactivation triggered by a server admin as well as a reactivation. diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 8b2da88e8a..43be711a64 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -14,11 +14,16 @@ # limitations under the License. """Tests REST events for /rooms paths.""" - +from typing import Any from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -33,7 +38,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): user = UserID.from_string(user_id) servlets = [room.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( "red", @@ -43,30 +48,34 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources.typing - hs.get_federation_handler = Mock() + hs.get_federation_handler = Mock() # type: ignore[assignment] - async def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } + async def get_user_by_access_token( + token: str, + rights: str = "access", + allow_expired: bool = False, + ) -> TokenLookupResult: + return TokenLookupResult( + user_id=self.user_id, + is_guest=False, + token_id=1, + ) - hs.get_auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] - async def _insert_client_ip(*args, **kwargs): + async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - hs.get_datastores().main.insert_client_ip = _insert_client_ip + hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) # Need another user to make notifications actually work self.helper.join(self.room_id, user="@jim:red") - def test_set_typing(self): + def test_set_typing(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), @@ -95,7 +104,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): ], ) - def test_set_not_typing(self): + def test_set_not_typing(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), @@ -103,7 +112,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code) - def test_typing_timeout(self): + def test_typing_timeout(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id),