Add type hints to `tests/rest/client` (#12108)

* Add type hints to `tests/rest/client`

* newsfile

* fix imports

* add `test_account.py`

* Remove one type hint in `test_report_event.py`

* change `on_create_room` to `async`

* update new functions in `test_third_party_rules.py`

* Add `test_filter.py`

* add `test_rooms.py`

* change to `assertEquals` to `assertEqual`

* lint
This commit is contained in:
Dirk Klimpel 2022-03-02 17:34:14 +01:00 committed by GitHub
parent b4461e7d8a
commit 2ffaf30803
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 421 additions and 350 deletions

1
changelog.d/12108.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `tests/rest/client`.

View File

@ -78,13 +78,7 @@ exclude = (?x)
|tests/push/test_http.py |tests/push/test_http.py
|tests/push/test_presentable_names.py |tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py |tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_account.py
|tests/rest/client/test_filter.py
|tests/rest/client/test_report_event.py
|tests/rest/client/test_rooms.py
|tests/rest/client/test_third_party_rules.py
|tests/rest/client/test_transactions.py |tests/rest/client/test_transactions.py
|tests/rest/client/test_typing.py
|tests/rest/key/v2/test_remote_key_resource.py |tests/rest/key/v2/test_remote_key_resource.py
|tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_base.py
|tests/rest/media/v1/test_media_storage.py |tests/rest/media/v1/test_media_storage.py

View File

@ -15,11 +15,12 @@ import json
import os import os
import re import re
from email.parser import Parser from email.parser import Parser
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock from unittest.mock import Mock
import pkg_resources import pkg_resources
from twisted.internet.interfaces import IReactorTCP
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
@ -30,6 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
# Email config. # Email config.
@ -67,20 +69,27 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config) hs = self.setup_test_homeserver(config=config)
async def sendmail( async def sendmail(
reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs reactor: IReactorTCP,
): smtphost: str,
self.email_attempts.append(msg) smtpport: int,
from_addr: str,
to_addr: str,
msg_bytes: bytes,
*args: Any,
**kwargs: Any,
) -> None:
self.email_attempts.append(msg_bytes)
self.email_attempts = [] self.email_attempts: List[bytes] = []
hs.get_send_email_handler()._sendmail = sendmail hs.get_send_email_handler()._sendmail = sendmail
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.submit_token_resource = PasswordResetSubmitTokenResource(hs) self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
def test_basic_password_reset(self): def test_basic_password_reset(self) -> None:
"""Test basic password reset flow""" """Test basic password reset flow"""
old_password = "monkey" old_password = "monkey"
new_password = "kangeroo" new_password = "kangeroo"
@ -118,7 +127,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.attempt_wrong_password_login("kermit", old_password) self.attempt_wrong_password_login("kermit", old_password)
@override_config({"rc_3pid_validation": {"burst_count": 3}}) @override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_email(self): def test_ratelimit_by_email(self) -> None:
"""Test that we ratelimit /requestToken for the same email.""" """Test that we ratelimit /requestToken for the same email."""
old_password = "monkey" old_password = "monkey"
new_password = "kangeroo" new_password = "kangeroo"
@ -139,7 +148,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
) )
) )
def reset(ip): def reset(ip: str) -> None:
client_secret = "foobar" client_secret = "foobar"
session_id = self._request_token(email, client_secret, ip) session_id = self._request_token(email, client_secret, ip)
@ -166,7 +175,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertEqual(cm.exception.code, 429) self.assertEqual(cm.exception.code, 429)
def test_basic_password_reset_canonicalise_email(self): def test_basic_password_reset_canonicalise_email(self) -> None:
"""Test basic password reset flow """Test basic password reset flow
Request password reset with different spelling Request password reset with different spelling
""" """
@ -206,7 +215,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password # Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password) self.attempt_wrong_password_login("kermit", old_password)
def test_cant_reset_password_without_clicking_link(self): def test_cant_reset_password_without_clicking_link(self) -> None:
"""Test that we do actually need to click the link in the email""" """Test that we do actually need to click the link in the email"""
old_password = "monkey" old_password = "monkey"
new_password = "kangeroo" new_password = "kangeroo"
@ -241,7 +250,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the new password # Assert we can't log in with the new password
self.attempt_wrong_password_login("kermit", new_password) self.attempt_wrong_password_login("kermit", new_password)
def test_no_valid_token(self): def test_no_valid_token(self) -> None:
"""Test that we do actually need to request a token and can't just """Test that we do actually need to request a token and can't just
make a session up. make a session up.
""" """
@ -277,7 +286,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.attempt_wrong_password_login("kermit", new_password) self.attempt_wrong_password_login("kermit", new_password)
@unittest.override_config({"request_token_inhibit_3pid_errors": True}) @unittest.override_config({"request_token_inhibit_3pid_errors": True})
def test_password_reset_bad_email_inhibit_error(self): def test_password_reset_bad_email_inhibit_error(self) -> None:
"""Test that triggering a password reset with an email address that isn't bound """Test that triggering a password reset with an email address that isn't bound
to an account doesn't leak the lack of binding for that address if configured to an account doesn't leak the lack of binding for that address if configured
that way. that way.
@ -292,7 +301,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id) self.assertIsNotNone(session_id)
def _request_token(self, email, client_secret, ip="127.0.0.1"): def _request_token(
self,
email: str,
client_secret: str,
ip: str = "127.0.0.1",
) -> str:
channel = self.make_request( channel = self.make_request(
"POST", "POST",
b"account/password/email/requestToken", b"account/password/email/requestToken",
@ -309,7 +323,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
return channel.json_body["sid"] return channel.json_body["sid"]
def _validate_token(self, link): def _validate_token(self, link: str) -> None:
# Remove the host # Remove the host
path = link.replace("https://example.com", "") path = link.replace("https://example.com", "")
@ -339,7 +353,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
def _get_link_from_email(self): def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent" assert self.email_attempts, "No emails have been sent"
raw_msg = self.email_attempts[-1].decode("UTF-8") raw_msg = self.email_attempts[-1].decode("UTF-8")
@ -354,14 +368,19 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
if not text: if not text:
self.fail("Could not find text portion of email to parse") self.fail("Could not find text portion of email to parse")
assert text is not None
match = re.search(r"https://example.com\S+", text) match = re.search(r"https://example.com\S+", text)
assert match, "Could not find link in email" assert match, "Could not find link in email"
return match.group(0) return match.group(0)
def _reset_password( def _reset_password(
self, new_password, session_id, client_secret, expected_code=200 self,
): new_password: str,
session_id: str,
client_secret: str,
expected_code: int = 200,
) -> None:
channel = self.make_request( channel = self.make_request(
"POST", "POST",
b"account/password", b"account/password",
@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
return self.hs return self.hs
def test_deactivate_account(self): def test_deactivate_account(self) -> None:
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
@ -407,7 +426,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", "account/whoami", access_token=tok) channel = self.make_request("GET", "account/whoami", access_token=tok)
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)
def test_pending_invites(self): def test_pending_invites(self) -> None:
"""Tests that deactivating a user rejects every pending invite for them.""" """Tests that deactivating a user rejects every pending invite for them."""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
@ -448,7 +467,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(memberships), 1, memberships) self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships) self.assertEqual(memberships[0].room_id, room_id, memberships)
def deactivate(self, user_id, tok): def deactivate(self, user_id: str, tok: str) -> None:
request_data = json.dumps( request_data = json.dumps(
{ {
"auth": { "auth": {
@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
register.register_servlets, register.register_servlets,
] ]
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["allow_guest_access"] = True config["allow_guest_access"] = True
return config return config
def test_GET_whoami(self): def test_GET_whoami(self) -> None:
device_id = "wouldgohere" device_id = "wouldgohere"
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test", device_id=device_id) tok = self.login("kermit", "test", device_id=device_id)
@ -496,7 +515,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_GET_whoami_guests(self): def test_GET_whoami_guests(self) -> None:
channel = self.make_request( channel = self.make_request(
b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}" b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
) )
@ -516,7 +535,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_GET_whoami_appservices(self): def test_GET_whoami_appservices(self) -> None:
user_id = "@as:test" user_id = "@as:test"
as_token = "i_am_an_app_service" as_token = "i_am_an_app_service"
@ -541,7 +560,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
) )
self.assertFalse(hasattr(whoami, "device_id")) self.assertFalse(hasattr(whoami, "device_id"))
def _whoami(self, tok): def _whoami(self, tok: str) -> JsonDict:
channel = self.make_request("GET", "account/whoami", {}, access_token=tok) channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
return channel.json_body return channel.json_body
@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
# Email config. # Email config.
@ -576,16 +595,23 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver(config=config) self.hs = self.setup_test_homeserver(config=config)
async def sendmail( async def sendmail(
reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs reactor: IReactorTCP,
): smtphost: str,
self.email_attempts.append(msg) smtpport: int,
from_addr: str,
to_addr: str,
msg_bytes: bytes,
*args: Any,
**kwargs: Any,
) -> None:
self.email_attempts.append(msg_bytes)
self.email_attempts = [] self.email_attempts: List[bytes] = []
self.hs.get_send_email_handler()._sendmail = sendmail self.hs.get_send_email_handler()._sendmail = sendmail
return self.hs return self.hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.user_id = self.register_user("kermit", "test") self.user_id = self.register_user("kermit", "test")
@ -593,83 +619,73 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.email = "test@example.com" self.email = "test@example.com"
self.url_3pid = b"account/3pid" self.url_3pid = b"account/3pid"
def test_add_valid_email(self): def test_add_valid_email(self) -> None:
self.get_success(self._add_email(self.email, self.email)) self._add_email(self.email, self.email)
def test_add_valid_email_second_time(self): def test_add_valid_email_second_time(self) -> None:
self.get_success(self._add_email(self.email, self.email)) self._add_email(self.email, self.email)
self.get_success(
self._request_token_invalid_email( self._request_token_invalid_email(
self.email, self.email,
expected_errcode=Codes.THREEPID_IN_USE, expected_errcode=Codes.THREEPID_IN_USE,
expected_error="Email is already in use", expected_error="Email is already in use",
) )
)
def test_add_valid_email_second_time_canonicalise(self): def test_add_valid_email_second_time_canonicalise(self) -> None:
self.get_success(self._add_email(self.email, self.email)) self._add_email(self.email, self.email)
self.get_success(
self._request_token_invalid_email( self._request_token_invalid_email(
"TEST@EXAMPLE.COM", "TEST@EXAMPLE.COM",
expected_errcode=Codes.THREEPID_IN_USE, expected_errcode=Codes.THREEPID_IN_USE,
expected_error="Email is already in use", expected_error="Email is already in use",
) )
)
def test_add_email_no_at(self): def test_add_email_no_at(self) -> None:
self.get_success(
self._request_token_invalid_email( self._request_token_invalid_email(
"address-without-at.bar", "address-without-at.bar",
expected_errcode=Codes.UNKNOWN, expected_errcode=Codes.UNKNOWN,
expected_error="Unable to parse email address", expected_error="Unable to parse email address",
) )
)
def test_add_email_two_at(self): def test_add_email_two_at(self) -> None:
self.get_success(
self._request_token_invalid_email( self._request_token_invalid_email(
"foo@foo@test.bar", "foo@foo@test.bar",
expected_errcode=Codes.UNKNOWN, expected_errcode=Codes.UNKNOWN,
expected_error="Unable to parse email address", expected_error="Unable to parse email address",
) )
)
def test_add_email_bad_format(self): def test_add_email_bad_format(self) -> None:
self.get_success(
self._request_token_invalid_email( self._request_token_invalid_email(
"user@bad.example.net@good.example.com", "user@bad.example.net@good.example.com",
expected_errcode=Codes.UNKNOWN, expected_errcode=Codes.UNKNOWN,
expected_error="Unable to parse email address", expected_error="Unable to parse email address",
) )
)
def test_add_email_domain_to_lower(self): def test_add_email_domain_to_lower(self) -> None:
self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) self._add_email("foo@TEST.BAR", "foo@test.bar")
def test_add_email_domain_with_umlaut(self): def test_add_email_domain_with_umlaut(self) -> None:
self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")
def test_add_email_address_casefold(self): def test_add_email_address_casefold(self) -> None:
self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) self._add_email("Strauß@Example.com", "strauss@example.com")
def test_address_trim(self): def test_address_trim(self) -> None:
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) self._add_email(" foo@test.bar ", "foo@test.bar")
@override_config({"rc_3pid_validation": {"burst_count": 3}}) @override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_ip(self): def test_ratelimit_by_ip(self) -> None:
"""Tests that adding emails is ratelimited by IP""" """Tests that adding emails is ratelimited by IP"""
# We expect to be able to set three emails before getting ratelimited. # We expect to be able to set three emails before getting ratelimited.
self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) self._add_email("foo1@test.bar", "foo1@test.bar")
self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) self._add_email("foo2@test.bar", "foo2@test.bar")
self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) self._add_email("foo3@test.bar", "foo3@test.bar")
with self.assertRaises(HttpResponseException) as cm: with self.assertRaises(HttpResponseException) as cm:
self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) self._add_email("foo4@test.bar", "foo4@test.bar")
self.assertEqual(cm.exception.code, 429) self.assertEqual(cm.exception.code, 429)
def test_add_email_if_disabled(self): def test_add_email_if_disabled(self) -> None:
"""Test adding email to profile when doing so is disallowed""" """Test adding email to profile when doing so is disallowed"""
self.hs.config.registration.enable_3pid_changes = False self.hs.config.registration.enable_3pid_changes = False
@ -695,7 +711,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
}, },
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user # Get user
@ -705,10 +721,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"]) self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self): def test_delete_email(self) -> None:
"""Test deleting an email from profile""" """Test deleting an email from profile"""
# Add a threepid # Add a threepid
self.get_success( self.get_success(
@ -727,7 +743,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email}, {"medium": "email", "address": self.email},
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -736,10 +752,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"]) self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self): def test_delete_email_if_disabled(self) -> None:
"""Test deleting an email from profile when disallowed""" """Test deleting an email from profile when disallowed"""
self.hs.config.registration.enable_3pid_changes = False self.hs.config.registration.enable_3pid_changes = False
@ -761,7 +777,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user # Get user
@ -771,11 +787,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
def test_cant_add_email_without_clicking_link(self): def test_cant_add_email_without_clicking_link(self) -> None:
"""Test that we do actually need to click the link in the email""" """Test that we do actually need to click the link in the email"""
client_secret = "foobar" client_secret = "foobar"
session_id = self._request_token(self.email, client_secret) session_id = self._request_token(self.email, client_secret)
@ -797,7 +813,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
}, },
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user # Get user
@ -807,10 +823,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"]) self.assertFalse(channel.json_body["threepids"])
def test_no_valid_token(self): def test_no_valid_token(self) -> None:
"""Test that we do actually need to request a token and can't just """Test that we do actually need to request a token and can't just
make a session up. make a session up.
""" """
@ -832,7 +848,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
}, },
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user # Get user
@ -842,11 +858,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"]) self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None}) @override_config({"next_link_domain_whitelist": None})
def test_next_link(self): def test_next_link(self) -> None:
"""Tests a valid next_link parameter value with no whitelist (good case)""" """Tests a valid next_link parameter value with no whitelist (good case)"""
self._request_token( self._request_token(
"something@example.com", "something@example.com",
@ -856,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
) )
@override_config({"next_link_domain_whitelist": None}) @override_config({"next_link_domain_whitelist": None})
def test_next_link_exotic_protocol(self): def test_next_link_exotic_protocol(self) -> None:
"""Tests using a esoteric protocol as a next_link parameter value. """Tests using a esoteric protocol as a next_link parameter value.
Someone may be hosting a client on IPFS etc. Someone may be hosting a client on IPFS etc.
""" """
@ -868,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
) )
@override_config({"next_link_domain_whitelist": None}) @override_config({"next_link_domain_whitelist": None})
def test_next_link_file_uri(self): def test_next_link_file_uri(self) -> None:
"""Tests next_link parameters cannot be file URI""" """Tests next_link parameters cannot be file URI"""
# Attempt to use a next_link value that points to the local disk # Attempt to use a next_link value that points to the local disk
self._request_token( self._request_token(
@ -879,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
) )
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
def test_next_link_domain_whitelist(self): def test_next_link_domain_whitelist(self) -> None:
"""Tests next_link parameters must fit the whitelist if provided""" """Tests next_link parameters must fit the whitelist if provided"""
# Ensure not providing a next_link parameter still works # Ensure not providing a next_link parameter still works
@ -912,7 +928,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
) )
@override_config({"next_link_domain_whitelist": []}) @override_config({"next_link_domain_whitelist": []})
def test_empty_next_link_domain_whitelist(self): def test_empty_next_link_domain_whitelist(self) -> None:
"""Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
disallowed disallowed
""" """
@ -962,28 +978,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _request_token_invalid_email( def _request_token_invalid_email(
self, self,
email, email: str,
expected_errcode, expected_errcode: str,
expected_error, expected_error: str,
client_secret="foobar", client_secret: str = "foobar",
): ) -> None:
channel = self.make_request( channel = self.make_request(
"POST", "POST",
b"account/3pid/email/requestToken", b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1}, {"client_secret": client_secret, "email": email, "send_attempt": 1},
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_errcode, channel.json_body["errcode"])
self.assertEqual(expected_error, channel.json_body["error"]) self.assertEqual(expected_error, channel.json_body["error"])
def _validate_token(self, link): def _validate_token(self, link: str) -> None:
# Remove the host # Remove the host
path = link.replace("https://example.com", "") path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False) channel = self.make_request("GET", path, shorthand=False)
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
def _get_link_from_email(self): def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent" assert self.email_attempts, "No emails have been sent"
raw_msg = self.email_attempts[-1].decode("UTF-8") raw_msg = self.email_attempts[-1].decode("UTF-8")
@ -998,12 +1014,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
if not text: if not text:
self.fail("Could not find text portion of email to parse") self.fail("Could not find text portion of email to parse")
assert text is not None
match = re.search(r"https://example.com\S+", text) match = re.search(r"https://example.com\S+", text)
assert match, "Could not find link in email" assert match, "Could not find link in email"
return match.group(0) return match.group(0)
def _add_email(self, request_email, expected_email): def _add_email(self, request_email: str, expected_email: str) -> None:
"""Test adding an email to profile""" """Test adding an email to profile"""
previous_email_attempts = len(self.email_attempts) previous_email_attempts = len(self.email_attempts)
@ -1030,7 +1047,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -1039,7 +1056,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/unstable/org.matrix.msc3720/account_status" url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
config["experimental_features"] = {"msc3720_enabled": True} config["experimental_features"] = {"msc3720_enabled": True}
return self.setup_test_homeserver(config=config) return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.requester = self.register_user("requester", "password") self.requester = self.register_user("requester", "password")
self.requester_tok = self.login("requester", "password") self.requester_tok = self.login("requester", "password")
self.server_name = homeserver.config.server.server_name self.server_name = hs.config.server.server_name
def test_missing_mxid(self): def test_missing_mxid(self) -> None:
"""Tests that not providing any MXID raises an error.""" """Tests that not providing any MXID raises an error."""
self._test_status( self._test_status(
users=None, users=None,
@ -1074,7 +1091,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_errcode=Codes.MISSING_PARAM, expected_errcode=Codes.MISSING_PARAM,
) )
def test_invalid_mxid(self): def test_invalid_mxid(self) -> None:
"""Tests that providing an invalid MXID raises an error.""" """Tests that providing an invalid MXID raises an error."""
self._test_status( self._test_status(
users=["bad:test"], users=["bad:test"],
@ -1082,7 +1099,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_errcode=Codes.INVALID_PARAM, expected_errcode=Codes.INVALID_PARAM,
) )
def test_local_user_not_exists(self): def test_local_user_not_exists(self) -> None:
"""Tests that the account status endpoints correctly reports that a user doesn't """Tests that the account status endpoints correctly reports that a user doesn't
exist. exist.
""" """
@ -1098,7 +1115,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[], expected_failures=[],
) )
def test_local_user_exists(self): def test_local_user_exists(self) -> None:
"""Tests that the account status endpoint correctly reports that a user doesn't """Tests that the account status endpoint correctly reports that a user doesn't
exist. exist.
""" """
@ -1115,7 +1132,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[], expected_failures=[],
) )
def test_local_user_deactivated(self): def test_local_user_deactivated(self) -> None:
"""Tests that the account status endpoint correctly reports a deactivated user.""" """Tests that the account status endpoint correctly reports a deactivated user."""
user = self.register_user("someuser", "password") user = self.register_user("someuser", "password")
self.get_success( self.get_success(
@ -1135,7 +1152,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[], expected_failures=[],
) )
def test_mixed_local_and_remote_users(self): def test_mixed_local_and_remote_users(self) -> None:
"""Tests that if some users are remote the account status endpoint correctly """Tests that if some users are remote the account status endpoint correctly
merges the remote responses with the local result. merges the remote responses with the local result.
""" """
@ -1150,7 +1167,13 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"@bad:badremote", "@bad:badremote",
] ]
async def post_json(destination, path, data, *a, **kwa): async def post_json(
destination: str,
path: str,
data: Optional[JsonDict] = None,
*a: Any,
**kwa: Any,
) -> Union[JsonDict, list]:
if destination == "remote": if destination == "remote":
return { return {
"account_statuses": { "account_statuses": {
@ -1160,9 +1183,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
}, },
} }
} }
if destination == "otherremote": elif destination == "badremote":
return {}
if destination == "badremote":
# badremote tries to overwrite the status of a user that doesn't belong # badremote tries to overwrite the status of a user that doesn't belong
# to it (i.e. users[1]) with false data, which Synapse is expected to # to it (i.e. users[1]) with false data, which Synapse is expected to
# ignore. # ignore.
@ -1176,6 +1197,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
}, },
} }
} }
# if destination == "otherremote"
else:
return {}
# Register a mock that will return the expected result depending on the remote. # Register a mock that will return the expected result depending on the remote.
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
@ -1205,7 +1229,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None, expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None, expected_errcode: Optional[str] = None,
): ) -> None:
"""Send a request to the account status endpoint and check that the response """Send a request to the account status endpoint and check that the response
matches with what's expected. matches with what's expected.

View File

@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import filter from synapse.rest.client import filter
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -30,11 +32,11 @@ class FilterTestCase(unittest.HomeserverTestCase):
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
servlets = [filter.register_servlets] servlets = [filter.register_servlets]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
def test_add_filter(self): def test_add_filter(self) -> None:
channel = self.make_request( channel = self.make_request(
"POST", "POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id), "/_matrix/client/r0/user/%s/filter" % (self.user_id),
@ -43,11 +45,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"}) self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0)
)
self.pump() self.pump()
self.assertEqual(filter.result, self.EXAMPLE_FILTER) self.assertEqual(filter, self.EXAMPLE_FILTER)
def test_add_filter_for_other_user(self): def test_add_filter_for_other_user(self) -> None:
channel = self.make_request( channel = self.make_request(
"POST", "POST",
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
@ -57,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_add_filter_non_local_user(self): def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine _is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False self.hs.is_mine = lambda target_user: False
channel = self.make_request( channel = self.make_request(
@ -70,14 +74,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self): def test_get_filter(self) -> None:
filter_id = defer.ensureDeferred( filter_id = self.get_success(
self.filtering.add_user_filter( self.filtering.add_user_filter(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER user_localpart="apple", user_filter=self.EXAMPLE_FILTER
) )
) )
self.reactor.advance(1) self.reactor.advance(1)
filter_id = filter_id.result
channel = self.make_request( channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
) )
@ -85,7 +88,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self): def test_get_filter_non_existant(self) -> None:
channel = self.make_request( channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
) )
@ -95,7 +98,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
# Currently invalid params do not have an appropriate errcode # Currently invalid params do not have an appropriate errcode
# in errors.py # in errors.py
def test_get_filter_invalid_id(self): def test_get_filter_invalid_id(self) -> None:
channel = self.make_request( channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
) )
@ -103,7 +106,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error # No ID also returns an invalid_id error
def test_get_filter_no_id(self): def test_get_filter_no_id(self) -> None:
channel = self.make_request( channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
) )

View File

@ -15,7 +15,7 @@
import itertools import itertools
import urllib.parse import urllib.parse
from typing import Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -45,7 +45,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
] ]
hijack_auth = False hijack_auth = False
def default_config(self) -> dict: def default_config(self) -> Dict[str, Any]:
# We need to enable msc1849 support for aggregations # We need to enable msc1849 support for aggregations
config = super().default_config() config = super().default_config()

View File

@ -14,8 +14,13 @@
import json import json
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.rest.client import login, report_event, room from synapse.rest.client import login, report_event, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
report_event.register_servlets, report_event.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass") self.other_user = self.register_user("user", "pass")
@ -42,35 +47,35 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
self.event_id = resp["event_id"] self.event_id = resp["event_id"]
self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
def test_reason_str_and_score_int(self): def test_reason_str_and_score_int(self) -> None:
data = {"reason": "this makes me sad", "score": -100} data = {"reason": "this makes me sad", "score": -100}
self._assert_status(200, data) self._assert_status(200, data)
def test_no_reason(self): def test_no_reason(self) -> None:
data = {"score": 0} data = {"score": 0}
self._assert_status(200, data) self._assert_status(200, data)
def test_no_score(self): def test_no_score(self) -> None:
data = {"reason": "this makes me sad"} data = {"reason": "this makes me sad"}
self._assert_status(200, data) self._assert_status(200, data)
def test_no_reason_and_no_score(self): def test_no_reason_and_no_score(self) -> None:
data = {} data: JsonDict = {}
self._assert_status(200, data) self._assert_status(200, data)
def test_reason_int_and_score_str(self): def test_reason_int_and_score_str(self) -> None:
data = {"reason": 10, "score": "string"} data = {"reason": 10, "score": "string"}
self._assert_status(400, data) self._assert_status(400, data)
def test_reason_zero_and_score_blank(self): def test_reason_zero_and_score_blank(self) -> None:
data = {"reason": 0, "score": ""} data = {"reason": 0, "score": ""}
self._assert_status(400, data) self._assert_status(400, data)
def test_reason_and_score_null(self): def test_reason_and_score_null(self) -> None:
data = {"reason": None, "score": None} data = {"reason": None, "score": None}
self._assert_status(400, data) self._assert_status(400, data)
def _assert_status(self, response_status, data): def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request( channel = self.make_request(
"POST", "POST",
self.report_path, self.report_path,

View File

@ -18,11 +18,12 @@
"""Tests REST events for /rooms paths.""" """Tests REST events for /rooms paths."""
import json import json
from typing import Iterable, List from typing import Any, Dict, Iterable, List, Optional
from unittest.mock import Mock, call from unittest.mock import Mock, call
from urllib import parse as urlparse from urllib import parse as urlparse
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import ( from synapse.api.constants import (
@ -35,7 +36,9 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, room, sync from synapse.rest.client import account, directory, login, profile, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
@ -45,11 +48,11 @@ PATH_PREFIX = b"/_matrix/client/api/v1"
class RoomBase(unittest.HomeserverTestCase): class RoomBase(unittest.HomeserverTestCase):
rmcreator_id = None rmcreator_id: Optional[str] = None
servlets = [room.register_servlets, room.register_deprecated_servlets] servlets = [room.register_servlets, room.register_deprecated_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver( self.hs = self.setup_test_homeserver(
"red", "red",
@ -57,15 +60,15 @@ class RoomBase(unittest.HomeserverTestCase):
federation_client=Mock(), federation_client=Mock(),
) )
self.hs.get_federation_handler = Mock() self.hs.get_federation_handler = Mock() # type: ignore[assignment]
self.hs.get_federation_handler.return_value.maybe_backfill = Mock( self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
return_value=make_awaitable(None) return_value=make_awaitable(None)
) )
async def _insert_client_ip(*args, **kwargs): async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
return None return None
self.hs.get_datastores().main.insert_client_ip = _insert_client_ip self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment]
return self.hs return self.hs
@ -76,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
rmcreator_id = "@notme:red" rmcreator_id = "@notme:red"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.helper.auth_user_id = self.rmcreator_id self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id # create some rooms under the name rmcreator_id
@ -108,12 +111,12 @@ class RoomPermissionsTestCase(RoomBase):
# auth as user_id now # auth as user_id now
self.helper.auth_user_id = self.user_id self.helper.auth_user_id = self.user_id
def test_can_do_action(self): def test_can_do_action(self) -> None:
msg_content = b'{"msgtype":"m.text","body":"hello"}' msg_content = b'{"msgtype":"m.text","body":"hello"}'
seq = iter(range(100)) seq = iter(range(100))
def send_msg_path(): def send_msg_path() -> str:
return "/rooms/%s/send/m.room.message/mid%s" % ( return "/rooms/%s/send/m.room.message/mid%s" % (
self.created_rmid, self.created_rmid,
str(next(seq)), str(next(seq)),
@ -148,7 +151,7 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request("PUT", send_msg_path(), msg_content) channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEqual(403, channel.code, msg=channel.result["body"]) self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_topic_perms(self): def test_topic_perms(self) -> None:
topic_content = b'{"topic":"My Topic Name"}' topic_content = b'{"topic":"My Topic Name"}'
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
@ -214,14 +217,14 @@ class RoomPermissionsTestCase(RoomBase):
self.assertEqual(403, channel.code, msg=channel.result["body"]) self.assertEqual(403, channel.code, msg=channel.result["body"])
def _test_get_membership( def _test_get_membership(
self, room=None, members: Iterable = frozenset(), expect_code=None self, room: str, members: Iterable = frozenset(), expect_code: int = 200
): ) -> None:
for member in members: for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member) path = "/rooms/%s/state/m.room.member/%s" % (room, member)
channel = self.make_request("GET", path) channel = self.make_request("GET", path)
self.assertEqual(expect_code, channel.code) self.assertEqual(expect_code, channel.code)
def test_membership_basic_room_perms(self): def test_membership_basic_room_perms(self) -> None:
# === room does not exist === # === room does not exist ===
room = self.uncreated_rmid room = self.uncreated_rmid
# get membership of self, get membership of other, uncreated room # get membership of self, get membership of other, uncreated room
@ -241,7 +244,7 @@ class RoomPermissionsTestCase(RoomBase):
self.helper.join(room=room, user=usr, expect_code=404) self.helper.join(room=room, user=usr, expect_code=404)
self.helper.leave(room=room, user=usr, expect_code=404) self.helper.leave(room=room, user=usr, expect_code=404)
def test_membership_private_room_perms(self): def test_membership_private_room_perms(self) -> None:
room = self.created_rmid room = self.created_rmid
# get membership of self, get membership of other, private room + invite # get membership of self, get membership of other, private room + invite
# expect all 403s # expect all 403s
@ -264,7 +267,7 @@ class RoomPermissionsTestCase(RoomBase):
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
) )
def test_membership_public_room_perms(self): def test_membership_public_room_perms(self) -> None:
room = self.created_public_rmid room = self.created_public_rmid
# get membership of self, get membership of other, public room + invite # get membership of self, get membership of other, public room + invite
# expect 403 # expect 403
@ -287,7 +290,7 @@ class RoomPermissionsTestCase(RoomBase):
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
) )
def test_invited_permissions(self): def test_invited_permissions(self) -> None:
room = self.created_rmid room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
@ -310,7 +313,7 @@ class RoomPermissionsTestCase(RoomBase):
expect_code=403, expect_code=403,
) )
def test_joined_permissions(self): def test_joined_permissions(self) -> None:
room = self.created_rmid room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
self.helper.join(room=room, user=self.user_id) self.helper.join(room=room, user=self.user_id)
@ -348,7 +351,7 @@ class RoomPermissionsTestCase(RoomBase):
# set left of self, expect 200 # set left of self, expect 200
self.helper.leave(room=room, user=self.user_id) self.helper.leave(room=room, user=self.user_id)
def test_leave_permissions(self): def test_leave_permissions(self) -> None:
room = self.created_rmid room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
self.helper.join(room=room, user=self.user_id) self.helper.join(room=room, user=self.user_id)
@ -383,7 +386,7 @@ class RoomPermissionsTestCase(RoomBase):
) )
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
def test_member_event_from_ban(self): def test_member_event_from_ban(self) -> None:
room = self.created_rmid room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
self.helper.join(room=room, user=self.user_id) self.helper.join(room=room, user=self.user_id)
@ -475,21 +478,21 @@ class RoomsMemberListTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def test_get_member_list(self): def test_get_member_list(self) -> None:
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request("GET", "/rooms/%s/members" % room_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self): def test_get_member_list_no_room(self) -> None:
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
self.assertEqual(403, channel.code, msg=channel.result["body"]) self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self): def test_get_member_list_no_permission(self) -> None:
room_id = self.helper.create_room_as("@some_other_guy:red") room_id = self.helper.create_room_as("@some_other_guy:red")
channel = self.make_request("GET", "/rooms/%s/members" % room_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEqual(403, channel.code, msg=channel.result["body"]) self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_with_at_token(self): def test_get_member_list_no_permission_with_at_token(self) -> None:
""" """
Tests that a stranger to the room cannot get the member list Tests that a stranger to the room cannot get the member list
(in the case that they use an at token). (in the case that they use an at token).
@ -509,7 +512,7 @@ class RoomsMemberListTestCase(RoomBase):
) )
self.assertEqual(403, channel.code, msg=channel.result["body"]) self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member(self): def test_get_member_list_no_permission_former_member(self) -> None:
""" """
Tests that a former member of the room can not get the member list. Tests that a former member of the room can not get the member list.
""" """
@ -529,7 +532,7 @@ class RoomsMemberListTestCase(RoomBase):
channel = self.make_request("GET", "/rooms/%s/members" % room_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEqual(403, channel.code, msg=channel.result["body"]) self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member_with_at_token(self): def test_get_member_list_no_permission_former_member_with_at_token(self) -> None:
""" """
Tests that a former member of the room can not get the member list Tests that a former member of the room can not get the member list
(in the case that they use an at token). (in the case that they use an at token).
@ -569,7 +572,7 @@ class RoomsMemberListTestCase(RoomBase):
) )
self.assertEqual(403, channel.code, msg=channel.result["body"]) self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self): def test_get_member_list_mixed_memberships(self) -> None:
room_creator = "@some_other_guy:red" room_creator = "@some_other_guy:red"
room_id = self.helper.create_room_as(room_creator) room_id = self.helper.create_room_as(room_creator)
room_path = "/rooms/%s/members" % room_id room_path = "/rooms/%s/members" % room_id
@ -594,26 +597,26 @@ class RoomsCreateTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def test_post_room_no_keys(self): def test_post_room_no_keys(self) -> None:
# POST with no config keys, expect new room id # POST with no config keys, expect new room id
channel = self.make_request("POST", "/createRoom", "{}") channel = self.make_request("POST", "/createRoom", "{}")
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
def test_post_room_visibility_key(self): def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id # POST with visibility config key, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
self.assertEqual(200, channel.code) self.assertEqual(200, channel.code)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self): def test_post_room_custom_key(self) -> None:
# POST with custom config keys, expect new room id # POST with custom config keys, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
self.assertEqual(200, channel.code) self.assertEqual(200, channel.code)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self): def test_post_room_known_and_unknown_keys(self) -> None:
# POST with custom + known config keys, expect new room id # POST with custom + known config keys, expect new room id
channel = self.make_request( channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}' "POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
@ -621,7 +624,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(200, channel.code) self.assertEqual(200, channel.code)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self): def test_post_room_invalid_content(self) -> None:
# POST with invalid content / paths, expect 400 # POST with invalid content / paths, expect 400
channel = self.make_request("POST", "/createRoom", b'{"visibili') channel = self.make_request("POST", "/createRoom", b'{"visibili')
self.assertEqual(400, channel.code) self.assertEqual(400, channel.code)
@ -629,7 +632,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", b'["hello"]') channel = self.make_request("POST", "/createRoom", b'["hello"]')
self.assertEqual(400, channel.code) self.assertEqual(400, channel.code)
def test_post_room_invitees_invalid_mxid(self): def test_post_room_invitees_invalid_mxid(self) -> None:
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
# Note the trailing space in the MXID here! # Note the trailing space in the MXID here!
channel = self.make_request( channel = self.make_request(
@ -638,7 +641,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(400, channel.code) self.assertEqual(400, channel.code)
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
def test_post_room_invitees_ratelimit(self): def test_post_room_invitees_ratelimit(self) -> None:
"""Test that invites sent when creating a room are ratelimited by a RateLimiter, """Test that invites sent when creating a room are ratelimited by a RateLimiter,
which ratelimits them correctly, including by not limiting when the requester is which ratelimits them correctly, including by not limiting when the requester is
exempt from ratelimiting. exempt from ratelimiting.
@ -674,7 +677,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content) channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code) self.assertEqual(200, channel.code)
def test_spam_checker_may_join_room(self): def test_spam_checker_may_join_room(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed """Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room. when creating a new room.
""" """
@ -704,12 +707,12 @@ class RoomTopicTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# create the room # create the room
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
def test_invalid_puts(self): def test_invalid_puts(self) -> None:
# missing keys or invalid json # missing keys or invalid json
channel = self.make_request("PUT", self.path, "{}") channel = self.make_request("PUT", self.path, "{}")
self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.result["body"])
@ -736,7 +739,7 @@ class RoomTopicTestCase(RoomBase):
channel = self.make_request("PUT", self.path, content) channel = self.make_request("PUT", self.path, content)
self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.result["body"])
def test_rooms_topic(self): def test_rooms_topic(self) -> None:
# nothing should be there # nothing should be there
channel = self.make_request("GET", self.path) channel = self.make_request("GET", self.path)
self.assertEqual(404, channel.code, msg=channel.result["body"]) self.assertEqual(404, channel.code, msg=channel.result["body"])
@ -751,7 +754,7 @@ class RoomTopicTestCase(RoomBase):
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body) self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self): def test_rooms_topic_with_extra_keys(self) -> None:
# valid put with extra keys # valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}' content = '{"topic":"Seasons","subtopic":"Summer"}'
channel = self.make_request("PUT", self.path, content) channel = self.make_request("PUT", self.path, content)
@ -768,10 +771,10 @@ class RoomMemberStateTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_invalid_puts(self): def test_invalid_puts(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json # missing keys or invalid json
channel = self.make_request("PUT", path, "{}") channel = self.make_request("PUT", path, "{}")
@ -801,7 +804,7 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content.encode("ascii")) channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self): def test_rooms_members_self(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % ( path = "/rooms/%s/state/m.room.member/%s" % (
urlparse.quote(self.room_id), urlparse.quote(self.room_id),
self.user_id, self.user_id,
@ -812,13 +815,13 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content.encode("ascii")) channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, None) channel = self.make_request("GET", path, content=b"")
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN} expected_response = {"membership": Membership.JOIN}
self.assertEqual(expected_response, channel.json_body) self.assertEqual(expected_response, channel.json_body)
def test_rooms_members_other(self): def test_rooms_members_other(self) -> None:
self.other_id = "@zzsid1:red" self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % ( path = "/rooms/%s/state/m.room.member/%s" % (
urlparse.quote(self.room_id), urlparse.quote(self.room_id),
@ -830,11 +833,11 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content) channel = self.make_request("PUT", path, content)
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, None) channel = self.make_request("GET", path, content=b"")
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body) self.assertEqual(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self): def test_rooms_members_other_custom_keys(self) -> None:
self.other_id = "@zzsid1:red" self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % ( path = "/rooms/%s/state/m.room.member/%s" % (
urlparse.quote(self.room_id), urlparse.quote(self.room_id),
@ -849,7 +852,7 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content) channel = self.make_request("PUT", path, content)
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, None) channel = self.make_request("GET", path, content=b"")
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body) self.assertEqual(json.loads(content), channel.json_body)
@ -866,7 +869,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
@unittest.override_config( @unittest.override_config(
{"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
) )
def test_invites_by_rooms_ratelimit(self): def test_invites_by_rooms_ratelimit(self) -> None:
"""Tests that invites in a room are actually rate-limited.""" """Tests that invites in a room are actually rate-limited."""
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
@ -878,7 +881,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
@unittest.override_config( @unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
) )
def test_invites_by_users_ratelimit(self): def test_invites_by_users_ratelimit(self) -> None:
"""Tests that invites to a specific user are actually rate-limited.""" """Tests that invites to a specific user are actually rate-limited."""
for _ in range(3): for _ in range(3):
@ -897,7 +900,7 @@ class RoomJoinTestCase(RoomBase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user1 = self.register_user("thomas", "hackme") self.user1 = self.register_user("thomas", "hackme")
self.tok1 = self.login("thomas", "hackme") self.tok1 = self.login("thomas", "hackme")
@ -908,7 +911,7 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
def test_spam_checker_may_join_room(self): def test_spam_checker_may_join_room(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called """Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed. and blocks room joins when needed.
""" """
@ -975,8 +978,8 @@ class RoomJoinRatelimitTestCase(RoomBase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, homeserver) super().prepare(reactor, clock, hs)
# profile changes expect that the user is actually registered # profile changes expect that the user is actually registered
user = UserID.from_string(self.user_id) user = UserID.from_string(self.user_id)
self.get_success(self.register_user(user.localpart, "supersecretpassword")) self.get_success(self.register_user(user.localpart, "supersecretpassword"))
@ -984,7 +987,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config( @unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
) )
def test_join_local_ratelimit(self): def test_join_local_ratelimit(self) -> None:
"""Tests that local joins are actually rate-limited.""" """Tests that local joins are actually rate-limited."""
for _ in range(3): for _ in range(3):
self.helper.create_room_as(self.user_id) self.helper.create_room_as(self.user_id)
@ -994,7 +997,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config( @unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
) )
def test_join_local_ratelimit_profile_change(self): def test_join_local_ratelimit_profile_change(self) -> None:
"""Tests that sending a profile update into all of the user's joined rooms isn't """Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins.""" rate-limited by the rate-limiter on joins."""
@ -1031,7 +1034,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config( @unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
) )
def test_join_local_ratelimit_idempotent(self): def test_join_local_ratelimit_idempotent(self) -> None:
"""Tests that the room join endpoints remain idempotent despite rate-limiting """Tests that the room join endpoints remain idempotent despite rate-limiting
on room joins.""" on room joins."""
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
@ -1056,7 +1059,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
"autocreate_auto_join_rooms": True, "autocreate_auto_join_rooms": True,
}, },
) )
def test_autojoin_rooms(self): def test_autojoin_rooms(self) -> None:
user_id = self.register_user("testuser", "password") user_id = self.register_user("testuser", "password")
# Check that the new user successfully joined the four rooms # Check that the new user successfully joined the four rooms
@ -1071,10 +1074,10 @@ class RoomMessagesTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_invalid_puts(self): def test_invalid_puts(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json # missing keys or invalid json
channel = self.make_request("PUT", path, b"{}") channel = self.make_request("PUT", path, b"{}")
@ -1095,7 +1098,7 @@ class RoomMessagesTestCase(RoomBase):
channel = self.make_request("PUT", path, b"") channel = self.make_request("PUT", path, b"")
self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self): def test_rooms_messages_sent(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}' content = b'{"body":"test","msgtype":{"type":"a"}}'
@ -1119,11 +1122,11 @@ class RoomInitialSyncTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# create the room # create the room
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_initial_sync(self): def test_initial_sync(self) -> None:
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
self.assertEqual(200, channel.code) self.assertEqual(200, channel.code)
@ -1131,7 +1134,7 @@ class RoomInitialSyncTestCase(RoomBase):
self.assertEqual("join", channel.json_body["membership"]) self.assertEqual("join", channel.json_body["membership"])
# Room state is easier to assert on if we unpack it into a dict # Room state is easier to assert on if we unpack it into a dict
state = {} state: JsonDict = {}
for event in channel.json_body["state"]: for event in channel.json_body["state"]:
if "state_key" not in event: if "state_key" not in event:
continue continue
@ -1160,10 +1163,10 @@ class RoomMessageListTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self) -> None:
token = "t1-0_0_0_0_0_0_0_0_0" token = "t1-0_0_0_0_0_0_0_0_0"
channel = self.make_request( channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
@ -1174,7 +1177,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body) self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body) self.assertTrue("end" in channel.json_body)
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
token = "s0_0_0_0_0_0_0_0_0" token = "s0_0_0_0_0_0_0_0_0"
channel = self.make_request( channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
@ -1185,7 +1188,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body) self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body) self.assertTrue("end" in channel.json_body)
def test_room_messages_purge(self): def test_room_messages_purge(self) -> None:
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
pagination_handler = self.hs.get_pagination_handler() pagination_handler = self.hs.get_pagination_handler()
@ -1278,10 +1281,10 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
user_id = True user_id = True
hijack_auth = False hijack_auth = False
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register the user who does the searching # Register the user who does the searching
self.user_id = self.register_user("user", "pass") self.user_id2 = self.register_user("user", "pass")
self.access_token = self.login("user", "pass") self.access_token = self.login("user", "pass")
# Register the user who sends the message # Register the user who sends the message
@ -1289,12 +1292,12 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.other_access_token = self.login("otheruser", "pass") self.other_access_token = self.login("otheruser", "pass")
# Create a room # Create a room
self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.room = self.helper.create_room_as(self.user_id2, tok=self.access_token)
# Invite the other person # Invite the other person
self.helper.invite( self.helper.invite(
room=self.room, room=self.room,
src=self.user_id, src=self.user_id2,
tok=self.access_token, tok=self.access_token,
targ=self.other_user_id, targ=self.other_user_id,
) )
@ -1304,7 +1307,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
room=self.room, user=self.other_user_id, tok=self.other_access_token room=self.room, user=self.other_user_id, tok=self.other_access_token
) )
def test_finds_message(self): def test_finds_message(self) -> None:
""" """
The search functionality will search for content in messages if asked to The search functionality will search for content in messages if asked to
do so. do so.
@ -1333,7 +1336,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
# No context was requested, so we should get none. # No context was requested, so we should get none.
self.assertEqual(results["results"][0]["context"], {}) self.assertEqual(results["results"][0]["context"], {})
def test_include_context(self): def test_include_context(self) -> None:
""" """
When event_context includes include_profile, profile information will be When event_context includes include_profile, profile information will be
included in the search response. included in the search response.
@ -1379,7 +1382,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.url = b"/_matrix/client/r0/publicRooms" self.url = b"/_matrix/client/r0/publicRooms"
@ -1389,11 +1392,11 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
def test_restricted_no_auth(self): def test_restricted_no_auth(self) -> None:
channel = self.make_request("GET", self.url) channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.code, 401, channel.result)
def test_restricted_auth(self): def test_restricted_auth(self) -> None:
self.register_user("user", "pass") self.register_user("user", "pass")
tok = self.login("user", "pass") tok = self.login("user", "pass")
@ -1412,19 +1415,19 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock()) return self.setup_test_homeserver(federation_client=Mock())
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.register_user("user", "pass") self.register_user("user", "pass")
self.token = self.login("user", "pass") self.token = self.login("user", "pass")
self.federation_client = hs.get_federation_client() self.federation_client = hs.get_federation_client()
def test_simple(self): def test_simple(self) -> None:
"Simple test for searching rooms over federation" "Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.side_effect = ( self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
lambda *a, **k: defer.succeed({}) {}
) )
search_filter = {"generic_search_term": "foobar"} search_filter = {"generic_search_term": "foobar"}
@ -1437,7 +1440,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
self.federation_client.get_public_rooms.assert_called_once_with( self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined]
"testserv", "testserv",
limit=100, limit=100,
since_token=None, since_token=None,
@ -1446,12 +1449,12 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
third_party_instance_id=None, third_party_instance_id=None,
) )
def test_fallback(self): def test_fallback(self) -> None:
"Test that searching public rooms over federation falls back if it gets a 404" "Test that searching public rooms over federation falls back if it gets a 404"
# The `get_public_rooms` should be called again if the first call fails # The `get_public_rooms` should be called again if the first call fails
# with a 404, when using search filters. # with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""), HttpResponseException(404, "Not Found", b""),
defer.succeed({}), defer.succeed({}),
) )
@ -1466,7 +1469,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
self.federation_client.get_public_rooms.assert_has_calls( self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined]
[ [
call( call(
"testserv", "testserv",
@ -1497,14 +1500,14 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
profile.register_servlets, profile.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
config["allow_per_room_profiles"] = False config["allow_per_room_profiles"] = False
self.hs = self.setup_test_homeserver(config=config) self.hs = self.setup_test_homeserver(config=config)
return self.hs return self.hs
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("test", "test") self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test") self.tok = self.login("test", "test")
@ -1522,7 +1525,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_per_room_profile_forbidden(self): def test_per_room_profile_forbidden(self) -> None:
data = {"membership": "join", "displayname": "other test user"} data = {"membership": "join", "displayname": "other test user"}
request_data = json.dumps(data) request_data = json.dumps(data)
channel = self.make_request( channel = self.make_request(
@ -1557,7 +1560,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.creator = self.register_user("creator", "test") self.creator = self.register_user("creator", "test")
self.creator_tok = self.login("creator", "test") self.creator_tok = self.login("creator", "test")
@ -1566,7 +1569,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
def test_join_reason(self): def test_join_reason(self) -> None:
reason = "hello" reason = "hello"
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -1578,7 +1581,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason) self._check_for_reason(reason)
def test_leave_reason(self): def test_leave_reason(self) -> None:
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello" reason = "hello"
@ -1592,7 +1595,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason) self._check_for_reason(reason)
def test_kick_reason(self): def test_kick_reason(self) -> None:
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello" reason = "hello"
@ -1606,7 +1609,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason) self._check_for_reason(reason)
def test_ban_reason(self): def test_ban_reason(self) -> None:
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello" reason = "hello"
@ -1620,7 +1623,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason) self._check_for_reason(reason)
def test_unban_reason(self): def test_unban_reason(self) -> None:
reason = "hello" reason = "hello"
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -1632,7 +1635,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason) self._check_for_reason(reason)
def test_invite_reason(self): def test_invite_reason(self) -> None:
reason = "hello" reason = "hello"
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -1644,7 +1647,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason) self._check_for_reason(reason)
def test_reject_invite_reason(self): def test_reject_invite_reason(self) -> None:
self.helper.invite( self.helper.invite(
self.room_id, self.room_id,
src=self.creator, src=self.creator,
@ -1663,7 +1666,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason) self._check_for_reason(reason)
def _check_for_reason(self, reason): def _check_for_reason(self, reason: str) -> None:
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
@ -1704,12 +1707,12 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"org.matrix.not_labels": ["#notfun"], "org.matrix.not_labels": ["#notfun"],
} }
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("test", "test") self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test") self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_context_filter_labels(self): def test_context_filter_labels(self) -> None:
"""Test that we can filter by a label on a /context request.""" """Test that we can filter by a label on a /context request."""
event_id = self._send_labelled_messages_in_room() event_id = self._send_labelled_messages_in_room()
@ -1739,7 +1742,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events_after[0]["content"]["body"], "with right label", events_after[0] events_after[0]["content"]["body"], "with right label", events_after[0]
) )
def test_context_filter_not_labels(self): def test_context_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /context request.""" """Test that we can filter by the absence of a label on a /context request."""
event_id = self._send_labelled_messages_in_room() event_id = self._send_labelled_messages_in_room()
@ -1772,7 +1775,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events_after[1]["content"]["body"], "with two wrong labels", events_after[1] events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
) )
def test_context_filter_labels_not_labels(self): def test_context_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label on a """Test that we can filter by both a label and the absence of another label on a
/context request. /context request.
""" """
@ -1801,7 +1804,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events_after[0]["content"]["body"], "with wrong label", events_after[0] events_after[0]["content"]["body"], "with wrong label", events_after[0]
) )
def test_messages_filter_labels(self): def test_messages_filter_labels(self) -> None:
"""Test that we can filter by a label on a /messages request.""" """Test that we can filter by a label on a /messages request."""
self._send_labelled_messages_in_room() self._send_labelled_messages_in_room()
@ -1818,7 +1821,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
def test_messages_filter_not_labels(self): def test_messages_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /messages request.""" """Test that we can filter by the absence of a label on a /messages request."""
self._send_labelled_messages_in_room() self._send_labelled_messages_in_room()
@ -1839,7 +1842,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events[3]["content"]["body"], "with two wrong labels", events[3] events[3]["content"]["body"], "with two wrong labels", events[3]
) )
def test_messages_filter_labels_not_labels(self): def test_messages_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label on a """Test that we can filter by both a label and the absence of another label on a
/messages request. /messages request.
""" """
@ -1862,7 +1865,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(len(events), 1, [event["content"] for event in events])
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
def test_search_filter_labels(self): def test_search_filter_labels(self) -> None:
"""Test that we can filter by a label on a /search request.""" """Test that we can filter by a label on a /search request."""
request_data = json.dumps( request_data = json.dumps(
{ {
@ -1899,7 +1902,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results[1]["result"]["content"]["body"], results[1]["result"]["content"]["body"],
) )
def test_search_filter_not_labels(self): def test_search_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /search request.""" """Test that we can filter by the absence of a label on a /search request."""
request_data = json.dumps( request_data = json.dumps(
{ {
@ -1946,7 +1949,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results[3]["result"]["content"]["body"], results[3]["result"]["content"]["body"],
) )
def test_search_filter_labels_not_labels(self): def test_search_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label on a """Test that we can filter by both a label and the absence of another label on a
/search request. /search request.
""" """
@ -1980,7 +1983,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results[0]["result"]["content"]["body"], results[0]["result"]["content"]["body"],
) )
def _send_labelled_messages_in_room(self): def _send_labelled_messages_in_room(self) -> str:
"""Sends several messages to a room with different labels (or without any) to test """Sends several messages to a room with different labels (or without any) to test
filtering by label. filtering by label.
Returns: Returns:
@ -2056,12 +2059,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["experimental_features"] = {"msc3440_enabled": True} config["experimental_features"] = {"msc3440_enabled": True}
return config return config
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("test", "test") self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test") self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@ -2136,7 +2139,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return channel.json_body["chunk"] return channel.json_body["chunk"]
def test_filter_relation_senders(self): def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to. # Messages which second user reacted to.
filter = {"io.element.relation_senders": [self.second_user_id]} filter = {"io.element.relation_senders": [self.second_user_id]}
chunk = self._filter_messages(filter) chunk = self._filter_messages(filter)
@ -2159,7 +2162,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
) )
def test_filter_relation_type(self): def test_filter_relation_type(self) -> None:
# Messages which have annotations. # Messages which have annotations.
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter) chunk = self._filter_messages(filter)
@ -2185,7 +2188,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
) )
def test_filter_relation_senders_and_type(self): def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to. # Messages which second user reacted to.
filter = { filter = {
"io.element.relation_senders": [self.second_user_id], "io.element.relation_senders": [self.second_user_id],
@ -2205,7 +2208,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
account.register_servlets, account.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "password") self.user_id = self.register_user("user", "password")
self.tok = self.login("user", "password") self.tok = self.login("user", "password")
self.room_id = self.helper.create_room_as( self.room_id = self.helper.create_room_as(
@ -2218,7 +2221,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok)
self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok)
def test_erased_sender(self): def test_erased_sender(self) -> None:
"""Test that an erasure request results in the requester's events being hidden """Test that an erasure request results in the requester's events being hidden
from any new member of the room. from any new member of the room.
""" """
@ -2332,7 +2335,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_owner = self.register_user("room_owner", "test") self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test")
@ -2340,17 +2343,17 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
self.room_owner, tok=self.room_owner_tok self.room_owner, tok=self.room_owner_tok
) )
def test_no_aliases(self): def test_no_aliases(self) -> None:
res = self._get_aliases(self.room_owner_tok) res = self._get_aliases(self.room_owner_tok)
self.assertEqual(res["aliases"], []) self.assertEqual(res["aliases"], [])
def test_not_in_room(self): def test_not_in_room(self) -> None:
self.register_user("user", "test") self.register_user("user", "test")
user_tok = self.login("user", "test") user_tok = self.login("user", "test")
res = self._get_aliases(user_tok, expected_code=403) res = self._get_aliases(user_tok, expected_code=403)
self.assertEqual(res["errcode"], "M_FORBIDDEN") self.assertEqual(res["errcode"], "M_FORBIDDEN")
def test_admin_user(self): def test_admin_user(self) -> None:
alias1 = self._random_alias() alias1 = self._random_alias()
self._set_alias_via_directory(alias1) self._set_alias_via_directory(alias1)
@ -2360,7 +2363,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
res = self._get_aliases(user_tok) res = self._get_aliases(user_tok)
self.assertEqual(res["aliases"], [alias1]) self.assertEqual(res["aliases"], [alias1])
def test_with_aliases(self): def test_with_aliases(self) -> None:
alias1 = self._random_alias() alias1 = self._random_alias()
alias2 = self._random_alias() alias2 = self._random_alias()
@ -2370,7 +2373,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
res = self._get_aliases(self.room_owner_tok) res = self._get_aliases(self.room_owner_tok)
self.assertEqual(set(res["aliases"]), {alias1, alias2}) self.assertEqual(set(res["aliases"]), {alias1, alias2})
def test_peekable_room(self): def test_peekable_room(self) -> None:
alias1 = self._random_alias() alias1 = self._random_alias()
self._set_alias_via_directory(alias1) self._set_alias_via_directory(alias1)
@ -2404,7 +2407,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _random_alias(self) -> str: def _random_alias(self) -> str:
return RoomAlias(random_string(5), self.hs.hostname).to_string() return RoomAlias(random_string(5), self.hs.hostname).to_string()
def _set_alias_via_directory(self, alias: str, expected_code: int = 200): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias url = "/_matrix/client/r0/directory/room/" + alias
data = {"room_id": self.room_id} data = {"room_id": self.room_id}
request_data = json.dumps(data) request_data = json.dumps(data)
@ -2423,7 +2426,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_owner = self.register_user("room_owner", "test") self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test")
@ -2434,7 +2437,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self.alias = "#alias:test" self.alias = "#alias:test"
self._set_alias_via_directory(self.alias) self._set_alias_via_directory(self.alias)
def _set_alias_via_directory(self, alias: str, expected_code: int = 200): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias url = "/_matrix/client/r0/directory/room/" + alias
data = {"room_id": self.room_id} data = {"room_id": self.room_id}
request_data = json.dumps(data) request_data = json.dumps(data)
@ -2456,7 +2459,9 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(res, dict) self.assertIsInstance(res, dict)
return res return res
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: def _set_canonical_alias(
self, content: JsonDict, expected_code: int = 200
) -> JsonDict:
"""Calls the endpoint under test. returns the json response object.""" """Calls the endpoint under test. returns the json response object."""
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
@ -2469,7 +2474,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(res, dict) self.assertIsInstance(res, dict)
return res return res
def test_canonical_alias(self): def test_canonical_alias(self) -> None:
"""Test a basic alias message.""" """Test a basic alias message."""
# There is no canonical alias to start with. # There is no canonical alias to start with.
self._get_canonical_alias(expected_code=404) self._get_canonical_alias(expected_code=404)
@ -2488,7 +2493,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias() res = self._get_canonical_alias()
self.assertEqual(res, {}) self.assertEqual(res, {})
def test_alt_aliases(self): def test_alt_aliases(self) -> None:
"""Test a canonical alias message with alt_aliases.""" """Test a canonical alias message with alt_aliases."""
# Create an alias. # Create an alias.
self._set_canonical_alias({"alt_aliases": [self.alias]}) self._set_canonical_alias({"alt_aliases": [self.alias]})
@ -2504,7 +2509,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias() res = self._get_canonical_alias()
self.assertEqual(res, {}) self.assertEqual(res, {})
def test_alias_alt_aliases(self): def test_alias_alt_aliases(self) -> None:
"""Test a canonical alias message with an alias and alt_aliases.""" """Test a canonical alias message with an alias and alt_aliases."""
# Create an alias. # Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
@ -2520,7 +2525,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias() res = self._get_canonical_alias()
self.assertEqual(res, {}) self.assertEqual(res, {})
def test_partial_modify(self): def test_partial_modify(self) -> None:
"""Test removing only the alt_aliases.""" """Test removing only the alt_aliases."""
# Create an alias. # Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
@ -2536,7 +2541,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias() res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias}) self.assertEqual(res, {"alias": self.alias})
def test_add_alias(self): def test_add_alias(self) -> None:
"""Test removing only the alt_aliases.""" """Test removing only the alt_aliases."""
# Create an additional alias. # Create an additional alias.
second_alias = "#second:test" second_alias = "#second:test"
@ -2556,7 +2561,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
) )
def test_bad_data(self): def test_bad_data(self) -> None:
"""Invalid data for alt_aliases should cause errors.""" """Invalid data for alt_aliases should cause errors."""
self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": None}, expected_code=400) self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
@ -2566,7 +2571,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self._set_canonical_alias({"alt_aliases": True}, expected_code=400) self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
def test_bad_alias(self): def test_bad_alias(self) -> None:
"""An alias which does not point to the room raises a SynapseError.""" """An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
@ -2580,13 +2585,13 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("thomas", "hackme") self.user_id = self.register_user("thomas", "hackme")
self.tok = self.login("thomas", "hackme") self.tok = self.login("thomas", "hackme")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_threepid_invite_spamcheck(self): def test_threepid_invite_spamcheck(self) -> None:
# Mock a few functions to prevent the test from failing due to failing to talk to # Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
# can check its call_count later on during the test. # can check its call_count later on during the test.

View File

@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import threading import threading
from typing import TYPE_CHECKING, Dict, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import account, login, profile, room from synapse.rest.client import account, login, profile, room
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, StateMap from synapse.types import JsonDict, Requester, StateMap
from synapse.util import Clock
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from tests import unittest from tests import unittest
@ -34,7 +40,7 @@ thread_local = threading.local()
class LegacyThirdPartyRulesTestModule: class LegacyThirdPartyRulesTestModule:
def __init__(self, config: Dict, module_api: "ModuleApi"): def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
# keep a record of the "current" rules module, so that the test can patch # keep a record of the "current" rules module, so that the test can patch
# it if desired. # it if desired.
thread_local.rules_module = self thread_local.rules_module = self
@ -42,32 +48,36 @@ class LegacyThirdPartyRulesTestModule:
async def on_create_room( async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool self, requester: Requester, config: dict, is_requester_admin: bool
): ) -> bool:
return True return True
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): async def check_event_allowed(
self, event: EventBase, state: StateMap[EventBase]
) -> Union[bool, dict]:
return True return True
@staticmethod @staticmethod
def parse_config(config): def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config return config
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: "ModuleApi"): def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api) super().__init__(config, module_api)
def on_create_room( async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool self, requester: Requester, config: dict, is_requester_admin: bool
): ) -> bool:
return False return False
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: "ModuleApi"): def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api) super().__init__(config, module_api)
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): async def check_event_allowed(
self, event: EventBase, state: StateMap[EventBase]
) -> JsonDict:
d = event.get_dict() d = event.get_dict()
content = unfreeze(event.content) content = unfreeze(event.content)
content["foo"] = "bar" content["foo"] = "bar"
@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
account.register_servlets, account.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()
load_legacy_third_party_event_rules(hs) load_legacy_third_party_event_rules(hs)
@ -94,22 +104,30 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Note that these checks are not relevant to this test case. # Note that these checks are not relevant to this test case.
# Have this homeserver auto-approve all event signature checking. # Have this homeserver auto-approve all event signature checking.
async def approve_all_signature_checking(_, pdu): async def approve_all_signature_checking(
_: RoomVersion, pdu: EventBase
) -> EventBase:
return pdu return pdu
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment]
# Have this homeserver skip event auth checks. This is necessary due to # Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver. # event auth checks ensuring that events were signed by the sender's homeserver.
async def _check_event_auth(origin, event, context, *args, **kwargs): async def _check_event_auth(
origin: str,
event: EventBase,
context: EventContext,
*args: Any,
**kwargs: Any,
) -> EventContext:
return context return context
hs.get_federation_event_handler()._check_event_auth = _check_event_auth hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
return hs return hs
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, homeserver) super().prepare(reactor, clock, hs)
# Create some users and a room to play with during the tests # Create some users and a room to play with during the tests
self.user_id = self.register_user("kermit", "monkey") self.user_id = self.register_user("kermit", "monkey")
self.invitee = self.register_user("invitee", "hackme") self.invitee = self.register_user("invitee", "hackme")
@ -121,13 +139,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
except Exception: except Exception:
pass pass
def test_third_party_rules(self): def test_third_party_rules(self) -> None:
"""Tests that a forbidden event is forbidden from being sent, but an allowed one """Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent. can be sent.
""" """
# patch the rules module with a Mock which will return False for some event # patch the rules module with a Mock which will return False for some event
# types # types
async def check(ev, state): async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
return ev.type != "foo.bar.forbidden", None return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check) callback = Mock(spec=[], side_effect=check)
@ -161,7 +181,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
) )
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
def test_third_party_rules_workaround_synapse_errors_pass_through(self): def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
""" """
Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042 Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
is functional: that SynapseErrors are passed through from check_event_allowed is functional: that SynapseErrors are passed through from check_event_allowed
@ -172,7 +192,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
""" """
class NastyHackException(SynapseError): class NastyHackException(SynapseError):
def error_dict(self): def error_dict(self) -> JsonDict:
""" """
This overrides SynapseError's `error_dict` to nastily inject This overrides SynapseError's `error_dict` to nastily inject
JSON into the error response. JSON into the error response.
@ -182,7 +202,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
return result return result
# add a callback that will raise our hacky exception # add a callback that will raise our hacky exception
async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]: async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
raise NastyHackException(429, "message") raise NastyHackException(429, "message")
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
@ -202,11 +224,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"}, {"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
) )
def test_cannot_modify_event(self): def test_cannot_modify_event(self) -> None:
"""cannot accidentally modify an event before it is persisted""" """cannot accidentally modify an event before it is persisted"""
# first patch the event checker so that it will try to modify the event # first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state): async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
ev.content = {"x": "y"} ev.content = {"x": "y"}
return True, None return True, None
@ -223,10 +247,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# 500 Internal Server Error # 500 Internal Server Error
self.assertEqual(channel.code, 500, channel.result) self.assertEqual(channel.code, 500, channel.result)
def test_modify_event(self): def test_modify_event(self) -> None:
"""The module can return a modified version of the event""" """The module can return a modified version of the event"""
# first patch the event checker so that it will modify the event # first patch the event checker so that it will modify the event
async def check(ev: EventBase, state): async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict() d = ev.get_dict()
d["content"] = {"x": "y"} d["content"] = {"x": "y"}
return True, d return True, d
@ -253,10 +279,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y") self.assertEqual(ev["content"]["x"], "y")
def test_message_edit(self): def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages.""" """Ensure that the module doesn't cause issues with edited messages."""
# first patch the event checker so that it will modify the event # first patch the event checker so that it will modify the event
async def check(ev: EventBase, state): async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict() d = ev.get_dict()
d["content"] = { d["content"] = {
"msgtype": "m.text", "msgtype": "m.text",
@ -315,7 +343,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY") self.assertEqual(ev["content"]["body"], "EDITED BODY")
def test_send_event(self): def test_send_event(self) -> None:
"""Tests that a module can send an event into a room via the module api""" """Tests that a module can send an event into a room via the module api"""
content = { content = {
"msgtype": "m.text", "msgtype": "m.text",
@ -344,7 +372,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
} }
} }
) )
def test_legacy_check_event_allowed(self): def test_legacy_check_event_allowed(self) -> None:
"""Tests that the wrapper for legacy check_event_allowed callbacks works """Tests that the wrapper for legacy check_event_allowed callbacks works
correctly. correctly.
""" """
@ -379,13 +407,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
} }
} }
) )
def test_legacy_on_create_room(self): def test_legacy_on_create_room(self) -> None:
"""Tests that the wrapper for legacy on_create_room callbacks works """Tests that the wrapper for legacy on_create_room callbacks works
correctly. correctly.
""" """
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
def test_sent_event_end_up_in_room_state(self): def test_sent_event_end_up_in_room_state(self) -> None:
"""Tests that a state event sent by a module while processing another state event """Tests that a state event sent by a module while processing another state event
doesn't get dropped from the state of the room. This is to guard against a bug doesn't get dropped from the state of the room. This is to guard against a bug
where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830 where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
@ -400,7 +428,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
api = self.hs.get_module_api() api = self.hs.get_module_api()
# Define a callback that sends a custom event on power levels update. # Define a callback that sends a custom event on power levels update.
async def test_fn(event: EventBase, state_events): async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
if event.is_state and event.type == EventTypes.PowerLevels: if event.is_state and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room( await api.create_and_send_event_into_room(
{ {
@ -436,7 +466,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["i"], i) self.assertEqual(channel.json_body["i"], i)
def test_on_new_event(self): def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events""" """Test that the on_new_event callback is called on new events"""
on_new_event = Mock(make_awaitable(None)) on_new_event = Mock(make_awaitable(None))
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append( self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
@ -501,7 +531,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
def _update_power_levels(self, event_default: int = 0): def _update_power_levels(self, event_default: int = 0) -> None:
"""Updates the room's power levels. """Updates the room's power levels.
Args: Args:
@ -533,7 +563,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
tok=self.tok, tok=self.tok,
) )
def test_on_profile_update(self): def test_on_profile_update(self) -> None:
"""Tests that the on_profile_update module callback is correctly called on """Tests that the on_profile_update module callback is correctly called on
profile updates. profile updates.
""" """
@ -592,7 +622,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(profile_info.display_name, displayname) self.assertEqual(profile_info.display_name, displayname)
self.assertEqual(profile_info.avatar_url, avatar_url) self.assertEqual(profile_info.avatar_url, avatar_url)
def test_on_profile_update_admin(self): def test_on_profile_update_admin(self) -> None:
"""Tests that the on_profile_update module callback is correctly called on """Tests that the on_profile_update module callback is correctly called on
profile updates triggered by a server admin. profile updates triggered by a server admin.
""" """
@ -634,7 +664,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(profile_info.display_name, displayname) self.assertEqual(profile_info.display_name, displayname)
self.assertEqual(profile_info.avatar_url, avatar_url) self.assertEqual(profile_info.avatar_url, avatar_url)
def test_on_user_deactivation_status_changed(self): def test_on_user_deactivation_status_changed(self) -> None:
"""Tests that the on_user_deactivation_status_changed module callback is called """Tests that the on_user_deactivation_status_changed module callback is called
correctly when processing a user's deactivation. correctly when processing a user's deactivation.
""" """
@ -691,7 +721,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
args = profile_mock.call_args[0] args = profile_mock.call_args[0]
self.assertTrue(args[3]) self.assertTrue(args[3])
def test_on_user_deactivation_status_changed_admin(self): def test_on_user_deactivation_status_changed_admin(self) -> None:
"""Tests that the on_user_deactivation_status_changed module callback is called """Tests that the on_user_deactivation_status_changed module callback is called
correctly when processing a user's deactivation triggered by a server admin as correctly when processing a user's deactivation triggered by a server admin as
well as a reactivation. well as a reactivation.

View File

@ -14,11 +14,16 @@
# limitations under the License. # limitations under the License.
"""Tests REST events for /rooms paths.""" """Tests REST events for /rooms paths."""
from typing import Any
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.client import room from synapse.rest.client import room
from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID from synapse.types import UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -33,7 +38,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
servlets = [room.register_servlets] servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"red", "red",
@ -43,30 +48,34 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.event_source = hs.get_event_sources().sources.typing self.event_source = hs.get_event_sources().sources.typing
hs.get_federation_handler = Mock() hs.get_federation_handler = Mock() # type: ignore[assignment]
async def get_user_by_access_token(token=None, allow_guest=False): async def get_user_by_access_token(
return { token: str,
"user": UserID.from_string(self.auth_user_id), rights: str = "access",
"token_id": 1, allow_expired: bool = False,
"is_guest": False, ) -> TokenLookupResult:
} return TokenLookupResult(
user_id=self.user_id,
is_guest=False,
token_id=1,
)
hs.get_auth().get_user_by_access_token = get_user_by_access_token hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
async def _insert_client_ip(*args, **kwargs): async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
return None return None
hs.get_datastores().main.insert_client_ip = _insert_client_ip hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment]
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
# Need another user to make notifications actually work # Need another user to make notifications actually work
self.helper.join(self.room_id, user="@jim:red") self.helper.join(self.room_id, user="@jim:red")
def test_set_typing(self): def test_set_typing(self) -> None:
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id), "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
@ -95,7 +104,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_set_not_typing(self): def test_set_not_typing(self) -> None:
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id), "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
@ -103,7 +112,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(200, channel.code) self.assertEqual(200, channel.code)
def test_typing_timeout(self): def test_typing_timeout(self) -> None:
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id), "/rooms/%s/typing/%s" % (self.room_id, self.user_id),