Proper types for tests.module_api (#15031)

* -> None for test methods

* A first batch of type fixes

* Introduce common parent test case

* Fixup that big test method

* tests.module_api passes mypy

* Changelog
This commit is contained in:
David Robertson 2023-02-09 00:23:35 +00:00 committed by GitHub
parent 30509a1010
commit 7081bb56e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 54 deletions

1
changelog.d/15031.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -32,7 +32,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/schema/ |synapse/storage/schema/
|tests/module_api/test_api.py
|tests/server.py |tests/server.py
)$ )$

View File

@ -31,7 +31,11 @@ from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock from tests.test_utils import simple_async_mock
from tests.unittest import FederatingHomeserverTestCase, override_config from tests.unittest import (
FederatingHomeserverTestCase,
HomeserverTestCase,
override_config,
)
@attr.s @attr.s
@ -470,7 +474,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def send_presence_update( def send_presence_update(
testcase: FederatingHomeserverTestCase, testcase: HomeserverTestCase,
user_id: str, user_id: str,
access_token: str, access_token: str,
presence_state: str, presence_state: str,
@ -491,7 +495,7 @@ def send_presence_update(
def sync_presence( def sync_presence(
testcase: FederatingHomeserverTestCase, testcase: HomeserverTestCase,
user_id: str, user_id: str,
since_token: Optional[StreamToken] = None, since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]: ) -> Tuple[List[UserPresenceState], StreamToken]:

View File

@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, EventTypes from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import NotFoundError from synapse.api.errors import NotFoundError
@ -21,9 +23,12 @@ from synapse.events import EventBase
from synapse.federation.units import Transaction from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException from synapse.handlers.push_rules import InvalidRuleException
from synapse.module_api import ModuleApi
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, notifications, presence, profile, room from synapse.rest.client import login, notifications, presence, profile, room
from synapse.types import create_requester from synapse.server import HomeServer
from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests.events.test_presence_router import send_presence_update, sync_presence from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
class ModuleApiTestCase(HomeserverTestCase): class BaseModuleApiTestCase(HomeserverTestCase):
"""Common properties of the two test case classes."""
module_api: ModuleApi
# These are all written by _test_sending_local_online_presence_to_local_user.
presence_receiver_id: str
presence_receiver_tok: str
presence_sender_id: str
presence_sender_tok: str
class ModuleApiTestCase(BaseModuleApiTestCase):
servlets = [ servlets = [
admin.register_servlets, admin.register_servlets,
login.register_servlets, login.register_servlets,
@ -42,14 +59,14 @@ class ModuleApiTestCase(HomeserverTestCase):
notifications.register_servlets, notifications.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = homeserver.get_datastores().main self.store = hs.get_datastores().main
self.module_api = homeserver.get_module_api() self.module_api = hs.get_module_api()
self.event_creation_handler = homeserver.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.sync_handler = homeserver.get_sync_handler() self.sync_handler = hs.get_sync_handler()
self.auth_handler = homeserver.get_auth_handler() self.auth_handler = hs.get_auth_handler()
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation. # Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"]) fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({}) fed_transport_client.send_transaction = simple_async_mock({})
@ -58,7 +75,7 @@ class ModuleApiTestCase(HomeserverTestCase):
federation_transport_client=fed_transport_client, federation_transport_client=fed_transport_client,
) )
def test_can_register_user(self): def test_can_register_user(self) -> None:
"""Tests that an external module can register a user""" """Tests that an external module can register a user"""
# Register a new user # Register a new user
user_id, access_token = self.get_success( user_id, access_token = self.get_success(
@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob")) displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino") self.assertEqual(displayname, "Bobberino")
def test_can_register_admin_user(self): def test_can_register_admin_user(self) -> None:
user_id = self.register_user( user_id = self.register_user(
"bob_module_admin", "1234", displayname="Bobberino Admin", admin=True "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
) )
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id) self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True) self.assertIdentical(found_user.is_admin, True)
def test_can_set_admin(self): def test_can_set_admin(self) -> None:
user_id = self.register_user( user_id = self.register_user(
"alice_wants_admin", "alice_wants_admin",
"1234", "1234",
@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase):
self.get_success(self.module_api.set_user_admin(user_id, True)) self.get_success(self.module_api.set_user_admin(user_id, True))
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id) self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True) self.assertIdentical(found_user.is_admin, True)
def test_can_set_displayname(self): def test_can_set_displayname(self) -> None:
localpart = "alice_wants_a_new_displayname" localpart = "alice_wants_a_new_displayname"
user_id = self.register_user( user_id = self.register_user(
localpart, "1234", displayname="Alice", admin=False localpart, "1234", displayname="Alice", admin=False
) )
found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id)) found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
assert found_userinfo is not None
self.get_success( self.get_success(
self.module_api.set_displayname( self.module_api.set_displayname(
found_userinfo.user_id, "Bob", deactivation=False found_userinfo.user_id, "Bob", deactivation=False
@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(found_profile.display_name, "Bob") self.assertEqual(found_profile.display_name, "Bob")
def test_get_userinfo_by_id(self): def test_get_userinfo_by_id(self) -> None:
user_id = self.register_user("alice", "1234") user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id) self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, False) self.assertIdentical(found_user.is_admin, False)
def test_get_userinfo_by_id__no_user_found(self): def test_get_userinfo_by_id__no_user_found(self) -> None:
found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
self.assertIsNone(found_user) self.assertIsNone(found_user)
def test_get_user_ip_and_agents(self): def test_get_user_ip_and_agents(self) -> None:
user_id = self.register_user("test_get_user_ip_and_agents_user", "1234") user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
# Initially, we should have no ip/agent for our user. # Initially, we should have no ip/agent for our user.
@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# we should only find the second ip, agent. # we should only find the second ip, agent.
info = self.get_success( info = self.get_success(
self.module_api.get_user_ip_and_agents( self.module_api.get_user_ip_and_agents(
user_id, (last_seen_1 + last_seen_2) / 2 user_id, (last_seen_1 + last_seen_2) // 2
) )
) )
self.assertEqual(len(info), 1) self.assertEqual(len(info), 1)
@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase):
) )
self.assertEqual(info, []) self.assertEqual(info, [])
def test_get_user_ip_and_agents__no_user_found(self): def test_get_user_ip_and_agents__no_user_found(self) -> None:
info = self.get_success( info = self.get_success(
self.module_api.get_user_ip_and_agents( self.module_api.get_user_ip_and_agents(
"@test_get_user_ip_and_agents_user_nonexistent:example.com" "@test_get_user_ip_and_agents_user_nonexistent:example.com"
@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase):
) )
self.assertEqual(info, []) self.assertEqual(info, [])
def test_sending_events_into_room(self): def test_sending_events_into_room(self) -> None:
"""Tests that a module can send events into a room""" """Tests that a module can send events into a room"""
# Mock out create_and_send_nonmember_event to check whether events are being sent # Mock out create_and_send_nonmember_event to check whether events are being sent
self.event_creation_handler.create_and_send_nonmember_event = Mock( self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment]
spec=[], spec=[],
side_effect=self.event_creation_handler.create_and_send_nonmember_event, side_effect=self.event_creation_handler.create_and_send_nonmember_event,
) )
@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room_id = self.helper.create_room_as(user_id, tok=tok) room_id = self.helper.create_room_as(user_id, tok=tok)
# Create and send a non-state event # Create and send a non-state event
content = {"body": "I am a puppet", "msgtype": "m.text"} content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"}
event_dict = { event_dict = {
"room_id": room_id, "room_id": room_id,
"type": "m.room.message", "type": "m.room.message",
@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase):
"sender": user_id, "sender": user_id,
"state_key": "", "state_key": "",
} }
event: EventBase = self.get_success( event = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict) self.module_api.create_and_send_event_into_room(event_dict)
) )
self.assertEqual(event.sender, user_id) self.assertEqual(event.sender, user_id)
@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.create_and_send_event_into_room(event_dict), Exception self.module_api.create_and_send_event_into_room(event_dict), Exception
) )
def test_public_rooms(self): def test_public_rooms(self) -> None:
"""Tests that a room can be added and removed from the public rooms list, """Tests that a room can be added and removed from the public rooms list,
as well as have its public rooms directory state queried. as well as have its public rooms directory state queried.
""" """
@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase):
) )
self.assertFalse(is_in_public_rooms) self.assertFalse(is_in_public_rooms)
def test_send_local_online_presence_to(self): def test_send_local_online_presence_to(self) -> None:
# Test sending local online presence to users from the main process # Test sending local online presence to users from the main process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False) _test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
# Enable federation sending on the main process. # Enable federation sending on the main process.
@override_config({"federation_sender_instances": None}) @override_config({"federation_sender_instances": None})
def test_send_local_online_presence_to_federation(self): def test_send_local_online_presence_to_federation(self) -> None:
"""Tests that send_local_presence_to_users sends local online presence to remote users.""" """Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates # Create a user who will send presence updates
self.presence_sender_id = self.register_user("presence_sender1", "monkey") self.presence_sender_id = self.register_user("presence_sender1", "monkey")
@ -431,7 +451,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertTrue(found_update) self.assertTrue(found_update)
def test_update_membership(self): def test_update_membership(self) -> None:
"""Tests that the module API can update the membership of a user in a room.""" """Tests that the module API can update the membership of a user in a room."""
peter = self.register_user("peter", "hackme") peter = self.register_user("peter", "hackme")
lesley = self.register_user("lesley", "hackme") lesley = self.register_user("lesley", "hackme")
@ -554,7 +574,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(res["displayname"], "simone") self.assertEqual(res["displayname"], "simone")
self.assertIsNone(res["avatar_url"]) self.assertIsNone(res["avatar_url"])
def test_update_room_membership_remote_join(self): def test_update_room_membership_remote_join(self) -> None:
"""Test that the module API can join a remote room.""" """Test that the module API can join a remote room."""
# Necessary to fake a remote join. # Necessary to fake a remote join.
fake_stream_id = 1 fake_stream_id = 1
@ -582,7 +602,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that a remote join was attempted. # Check that a remote join was attempted.
self.assertEqual(mocked_remote_join.call_count, 1) self.assertEqual(mocked_remote_join.call_count, 1)
def test_get_room_state(self): def test_get_room_state(self) -> None:
"""Tests that a module can retrieve the state of a room through the module API.""" """Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme") user_id = self.register_user("peter", "hackme")
tok = self.login("peter", "hackme") tok = self.login("peter", "hackme")
@ -677,7 +697,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.check_push_rule_actions(["foo"]) self.module_api.check_push_rule_actions(["foo"])
with self.assertRaises(InvalidRuleException): with self.assertRaises(InvalidRuleException):
self.module_api.check_push_rule_actions({"foo": "bar"}) self.module_api.check_push_rule_actions([{"foo": "bar"}])
self.module_api.check_push_rule_actions(["notify"]) self.module_api.check_push_rule_actions(["notify"])
@ -756,7 +776,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertIsNone(room_alias) self.assertIsNone(room_alias)
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup""" """For testing ModuleApi functionality in a multi-worker setup"""
servlets = [ servlets = [
@ -766,7 +786,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
presence.register_servlets, presence.register_servlets,
] ]
def default_config(self): def default_config(self) -> Dict[str, Any]:
conf = super().default_config() conf = super().default_config()
conf["stream_writers"] = {"presence": ["presence_writer"]} conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = { conf["instance_map"] = {
@ -774,18 +794,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
} }
return conf return conf
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.module_api = homeserver.get_module_api() self.module_api = hs.get_module_api()
self.sync_handler = homeserver.get_sync_handler() self.sync_handler = hs.get_sync_handler()
def test_send_local_online_presence_to_workers(self): def test_send_local_online_presence_to_workers(self) -> None:
# Test sending local online presence to users from a worker process # Test sending local online presence to users from a worker process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=True) _test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
def _test_sending_local_online_presence_to_local_user( def _test_sending_local_online_presence_to_local_user(
test_case: HomeserverTestCase, test_with_workers: bool = False test_case: BaseModuleApiTestCase, test_with_workers: bool = False
): ) -> None:
"""Tests that send_local_presence_to_users sends local online presence to local users. """Tests that send_local_presence_to_users sends local online presence to local users.
This simultaneously tests two different usecases: This simultaneously tests two different usecases:
@ -852,6 +872,7 @@ def _test_sending_local_online_presence_to_local_user(
# Replicate the current sync presence token from the main process to the worker process. # Replicate the current sync presence token from the main process to the worker process.
# We need to do this so that the worker process knows the current presence stream ID to # We need to do this so that the worker process knows the current presence stream ID to
# insert into the database when we call ModuleApi.send_local_online_presence_to. # insert into the database when we call ModuleApi.send_local_online_presence_to.
assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
test_case.replicate() test_case.replicate()
# Syncing again should result in no presence updates # Syncing again should result in no presence updates
@ -868,6 +889,7 @@ def _test_sending_local_online_presence_to_local_user(
# Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on # Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
if test_with_workers: if test_with_workers:
assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
module_api_to_use = worker_hs.get_module_api() module_api_to_use = worker_hs.get_module_api()
else: else:
module_api_to_use = test_case.module_api module_api_to_use = test_case.module_api
@ -875,12 +897,11 @@ def _test_sending_local_online_presence_to_local_user(
# Trigger sending local online presence. We expect this information # Trigger sending local online presence. We expect this information
# to be saved to the database where all processes can access it. # to be saved to the database where all processes can access it.
# Note that we're syncing via the master. # Note that we're syncing via the master.
d = module_api_to_use.send_local_online_presence_to( d = defer.ensureDeferred(
[ module_api_to_use.send_local_online_presence_to(
test_case.presence_receiver_id, [test_case.presence_receiver_id],
] )
) )
d = defer.ensureDeferred(d)
if test_with_workers: if test_with_workers:
# In order for the required presence_set_state replication request to occur between the # In order for the required presence_set_state replication request to occur between the
@ -897,7 +918,7 @@ def _test_sending_local_online_presence_to_local_user(
) )
test_case.assertEqual(len(presence_updates), 1) test_case.assertEqual(len(presence_updates), 1)
presence_update: UserPresenceState = presence_updates[0] presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online") test_case.assertEqual(presence_update.state, "online")
@ -908,7 +929,7 @@ def _test_sending_local_online_presence_to_local_user(
) )
test_case.assertEqual(len(presence_updates), 1) test_case.assertEqual(len(presence_updates), 1)
presence_update: UserPresenceState = presence_updates[0] presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online") test_case.assertEqual(presence_update.state, "online")
@ -936,12 +957,13 @@ def _test_sending_local_online_presence_to_local_user(
test_case.assertEqual(len(presence_updates), 1) test_case.assertEqual(len(presence_updates), 1)
# Now trigger sending local online presence. # Now trigger sending local online presence.
d = module_api_to_use.send_local_online_presence_to( d = defer.ensureDeferred(
[ module_api_to_use.send_local_online_presence_to(
test_case.presence_receiver_id, [
] test_case.presence_receiver_id,
]
)
) )
d = defer.ensureDeferred(d)
if test_with_workers: if test_with_workers:
# In order for the required presence_set_state replication request to occur between the # In order for the required presence_set_state replication request to occur between the