Type hints for tests.federation (#14991)

* Make tests.federation pass mypy

* Untyped defs in tests.federation.transport

* test methods return None

* Remaining type hints in tests.federation

* Changelog

* Avoid an uncessary type-ignore
This commit is contained in:
David Robertson 2023-02-06 16:05:06 +00:00 committed by GitHub
parent 156cd88eef
commit 0f34abed7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 127 additions and 94 deletions

1
changelog.d/14991.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -32,8 +32,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/schema/ |synapse/storage/schema/
|tests/federation/test_federation_catch_up.py
|tests/federation/test_federation_sender.py
|tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py |tests/http/federation/test_srv_resolver.py
|tests/http/test_proxyagent.py |tests/http/test_proxyagent.py
@ -89,7 +87,7 @@ disallow_untyped_defs = True
[mypy-tests.events.*] [mypy-tests.events.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client] [mypy-tests.federation.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.handlers.*] [mypy-tests.handlers.*]

View File

@ -17,7 +17,7 @@ from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.types import UserID from synapse.types import JsonDict, UserID
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -31,12 +31,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def default_config(self): def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
return config return config
def test_complexity_simple(self): def test_complexity_simple(self) -> None:
u1 = self.register_user("u1", "pass") u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass") u1_token = self.login("u1", "pass")
@ -66,7 +66,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
complexity = channel.json_body["v1"] complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23) self.assertEqual(complexity, 1.23)
def test_join_too_large(self): def test_join_too_large(self) -> None:
u1 = self.register_user("u1", "pass") u1 = self.register_user("u1", "pass")
@ -95,7 +95,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(f.value.code, 400, f.value) self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_admin(self): def test_join_too_large_admin(self) -> None:
# Check whether an admin can join if option "admins_can_join" is undefined, # Check whether an admin can join if option "admins_can_join" is undefined,
# this option defaults to false, so the join should fail. # this option defaults to false, so the join should fail.
@ -126,7 +126,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(f.value.code, 400, f.value) self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_once_joined(self): def test_join_too_large_once_joined(self) -> None:
u1 = self.register_user("u1", "pass") u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass") u1_token = self.login("u1", "pass")
@ -180,7 +180,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def default_config(self): def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
config["limit_remote_rooms"] = { config["limit_remote_rooms"] = {
"enabled": True, "enabled": True,
@ -189,7 +189,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
} }
return config return config
def test_join_too_large_no_admin(self): def test_join_too_large_no_admin(self) -> None:
# A user which is not an admin should not be able to join a remote room # A user which is not an admin should not be able to join a remote room
# which is too complex. # which is too complex.
@ -220,7 +220,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(f.value.code, 400, f.value) self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_admin(self): def test_join_too_large_admin(self) -> None:
# An admin should be able to join rooms where a complexity check fails. # An admin should be able to join rooms where a complexity check fails.
u1 = self.register_user("u1", "pass", admin=True) u1 = self.register_user("u1", "pass", admin=True)

View File

@ -1,13 +1,17 @@
from typing import List, Tuple from typing import Callable, List, Optional, Tuple
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu from synapse.federation.units import Edu, Transaction
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable from tests.test_utils import event_injection, make_awaitable
@ -28,23 +32,25 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver( return self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]), federation_transport_client=Mock(spec=["send_transaction"]),
) )
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# stub out get_current_hosts_in_room # stub out get_current_hosts_in_room
state_handler = hs.get_state_handler() state_storage_controller = hs.get_storage_controllers().state
# This mock is crucial for destination_rooms to be populated. # This mock is crucial for destination_rooms to be populated.
state_handler.get_current_hosts_in_room = Mock( # TODO: this seems to no longer be the case---tests pass with this mock
return_value=make_awaitable(["test", "host2"]) # commented out.
state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment]
return_value=make_awaitable({"test", "host2"})
) )
# whenever send_transaction is called, record the pdu data # whenever send_transaction is called, record the pdu data
self.pdus = [] self.pdus: List[JsonDict] = []
self.failed_pdus = [] self.failed_pdus: List[JsonDict] = []
self.is_online = True self.is_online = True
self.hs.get_federation_transport_client().send_transaction.side_effect = ( self.hs.get_federation_transport_client().send_transaction.side_effect = (
self.record_transaction self.record_transaction
@ -55,8 +61,13 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
config["federation_sender_instances"] = None config["federation_sender_instances"] = None
return config return config
async def record_transaction(self, txn, json_cb): async def record_transaction(
if self.is_online: self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]]
) -> JsonDict:
if json_cb is None:
# The tests seem to expect that this method raises in this situation.
raise Exception("Blank json_cb")
elif self.is_online:
data = json_cb() data = json_cb()
self.pdus.extend(data["pdus"]) self.pdus.extend(data["pdus"])
return {} return {}
@ -92,7 +103,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
)[0] )[0]
return {"event_id": event_id, "stream_ordering": stream_ordering} return {"event_id": event_id, "stream_ordering": stream_ordering}
def test_catch_up_destination_rooms_tracking(self): def test_catch_up_destination_rooms_tracking(self) -> None:
""" """
Tests that we populate the `destination_rooms` table as needed. Tests that we populate the `destination_rooms` table as needed.
""" """
@ -117,7 +128,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertEqual(row_2["event_id"], event_id_2) self.assertEqual(row_2["event_id"], event_id_2)
self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1) self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
def test_catch_up_last_successful_stream_ordering_tracking(self): def test_catch_up_last_successful_stream_ordering_tracking(self) -> None:
""" """
Tests that we populate the `destination_rooms` table as needed. Tests that we populate the `destination_rooms` table as needed.
""" """
@ -174,7 +185,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
"Send succeeded but not marked as last_successful_stream_ordering", "Send succeeded but not marked as last_successful_stream_ordering",
) )
def test_catch_up_from_blank_state(self): def test_catch_up_from_blank_state(self) -> None:
""" """
Runs an overall test of federation catch-up from scratch. Runs an overall test of federation catch-up from scratch.
Further tests will focus on more narrow aspects and edge-cases, but I Further tests will focus on more narrow aspects and edge-cases, but I
@ -261,16 +272,15 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
destination_tm: str, destination_tm: str,
pending_pdus: List[EventBase], pending_pdus: List[EventBase],
_pending_edus: List[Edu], _pending_edus: List[Edu],
) -> bool: ) -> None:
assert destination == destination_tm assert destination == destination_tm
results_list.extend(pending_pdus) results_list.extend(pending_pdus)
return True # success!
transaction_manager.send_new_transaction = fake_send transaction_manager.send_new_transaction = fake_send # type: ignore[assignment]
return per_dest_queue, results_list return per_dest_queue, results_list
def test_catch_up_loop(self): def test_catch_up_loop(self) -> None:
""" """
Tests the behaviour of _catch_up_transmission_loop. Tests the behaviour of _catch_up_transmission_loop.
""" """
@ -334,7 +344,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
event_5.internal_metadata.stream_ordering, event_5.internal_metadata.stream_ordering,
) )
def test_catch_up_on_synapse_startup(self): def test_catch_up_on_synapse_startup(self) -> None:
""" """
Tests the behaviour of get_catch_up_outstanding_destinations and Tests the behaviour of get_catch_up_outstanding_destinations and
_wake_destinations_needing_catchup. _wake_destinations_needing_catchup.
@ -412,7 +422,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# patch wake_destination to just count the destinations instead # patch wake_destination to just count the destinations instead
woken = [] woken = []
def wake_destination_track(destination): def wake_destination_track(destination: str) -> None:
woken.append(destination) woken.append(destination)
self.hs.get_federation_sender().wake_destination = wake_destination_track self.hs.get_federation_sender().wake_destination = wake_destination_track
@ -432,7 +442,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# - all destinations are woken exactly once; they appear once in woken. # - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1]) self.assertCountEqual(woken, server_names[:-1])
def test_not_latest_event(self): def test_not_latest_event(self) -> None:
"""Test that we send the latest event in the room even if its not ours.""" """Test that we send the latest event in the room even if its not ours."""
per_dest_queue, sent_pdus = self.make_fake_destination_queue() per_dest_queue, sent_pdus = self.make_fake_destination_queue()

View File

@ -36,7 +36,9 @@ class FederationClientTest(FederatingHomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
super().prepare(reactor, clock, homeserver) super().prepare(reactor, clock, homeserver)
# mock out the Agent used by the federation client, which is easier than # mock out the Agent used by the federation client, which is easier than
@ -51,7 +53,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
self.creator = f"@creator:{self.OTHER_SERVER_NAME}" self.creator = f"@creator:{self.OTHER_SERVER_NAME}"
self.test_room_id = "!room_id" self.test_room_id = "!room_id"
def test_get_room_state(self): def test_get_room_state(self) -> None:
# mock up some events to use in the response. # mock up some events to use in the response.
# In real life, these would have things in `prev_events` and `auth_events`, but that's # In real life, these would have things in `prev_events` and `auth_events`, but that's
# a bit annoying to mock up, and the code under test doesn't care, so we don't bother. # a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
@ -140,7 +142,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
["m.room.create", "m.room.member", "m.room.power_levels"], ["m.room.create", "m.room.member", "m.room.power_levels"],
) )
def test_get_pdu_returns_nothing_when_event_does_not_exist(self): def test_get_pdu_returns_nothing_when_event_does_not_exist(self) -> None:
"""No event should be returned when the event does not exist""" """No event should be returned when the event does not exist"""
pulled_pdu_info = self.get_success( pulled_pdu_info = self.get_success(
self.hs.get_federation_client().get_pdu( self.hs.get_federation_client().get_pdu(
@ -151,11 +153,11 @@ class FederationClientTest(FederatingHomeserverTestCase):
) )
self.assertEqual(pulled_pdu_info, None) self.assertEqual(pulled_pdu_info, None)
def test_get_pdu(self): def test_get_pdu(self) -> None:
"""Test to make sure an event is returned by `get_pdu()`""" """Test to make sure an event is returned by `get_pdu()`"""
self._get_pdu_once() self._get_pdu_once()
def test_get_pdu_event_from_cache_is_pristine(self): def test_get_pdu_event_from_cache_is_pristine(self) -> None:
"""Test that modifications made to events returned by `get_pdu()` """Test that modifications made to events returned by `get_pdu()`
do not propagate back to to the internal cache (events returned should do not propagate back to to the internal cache (events returned should
be a copy). be a copy).

View File

@ -11,18 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Callable, FrozenSet, List, Optional, Set
from unittest.mock import Mock from unittest.mock import Mock
from signedjson import key, sign from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey from signedjson.types import BaseKey, SigningKey
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.federation.units import Transaction
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.types import JsonDict, ReadReceipt from synapse.types import JsonDict, ReadReceipt
from synapse.util import Clock
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -36,16 +40,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
re-enabled for the main process. re-enabled for the main process.
""" """
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]), federation_transport_client=Mock(spec=["send_transaction"]),
) )
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
return_value=make_awaitable({"test", "host2"}) return_value=make_awaitable({"test", "host2"})
) )
hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment]
hs.get_storage_controllers().state.get_current_hosts_in_room hs.get_storage_controllers().state.get_current_hosts_in_room
) )
@ -56,7 +60,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
config["federation_sender_instances"] = None config["federation_sender_instances"] = None
return config return config
def test_send_receipts(self): def test_send_receipts(self) -> None:
mock_send_transaction = ( mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction self.hs.get_federation_transport_client().send_transaction
) )
@ -98,7 +102,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
], ],
) )
def test_send_receipts_thread(self): def test_send_receipts_thread(self) -> None:
mock_send_transaction = ( mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction self.hs.get_federation_transport_client().send_transaction
) )
@ -174,7 +178,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
], ],
) )
def test_send_receipts_with_backoff(self): def test_send_receipts_with_backoff(self) -> None:
"""Send two receipts in quick succession; the second should be flushed, but """Send two receipts in quick succession; the second should be flushed, but
only after 20ms""" only after 20ms"""
mock_send_transaction = ( mock_send_transaction = (
@ -272,51 +276,55 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver( return self.setup_test_homeserver(
federation_transport_client=Mock( federation_transport_client=Mock(
spec=["send_transaction", "query_user_devices"] spec=["send_transaction", "query_user_devices"]
), ),
) )
def default_config(self): def default_config(self) -> JsonDict:
c = super().default_config() c = super().default_config()
# Enable federation sending on the main process. # Enable federation sending on the main process.
c["federation_sender_instances"] = None c["federation_sender_instances"] = None
return c return c
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
test_room_id = "!room:host1" test_room_id = "!room:host1"
# stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
# server thinks the user shares a room with `@user2:host2` # server thinks the user shares a room with `@user2:host2`
def get_rooms_for_user(user_id): def get_rooms_for_user(user_id: str) -> "defer.Deferred[FrozenSet[str]]":
return defer.succeed({test_room_id}) return defer.succeed(frozenset({test_room_id}))
hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user # type: ignore[assignment]
async def get_current_hosts_in_room(room_id): async def get_current_hosts_in_room(room_id: str) -> Set[str]:
if room_id == test_room_id: if room_id == test_room_id:
return ["host2"] return {"host2"}
else:
# TODO: We should fail the test when we encounter an unxpected room ID. # TODO: We should fail the test when we encounter an unxpected room ID.
# We can't just use `self.fail(...)` here because the app code is greedy # We can't just use `self.fail(...)` here because the app code is greedy
# with `Exception` and will catch it before the test can see it. # with `Exception` and will catch it before the test can see it.
return set()
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
# whenever send_transaction is called, record the edu data # whenever send_transaction is called, record the edu data
self.edus = [] self.edus: List[JsonDict] = []
self.hs.get_federation_transport_client().send_transaction.side_effect = ( self.hs.get_federation_transport_client().send_transaction.side_effect = (
self.record_transaction self.record_transaction
) )
def record_transaction(self, txn, json_cb): def record_transaction(
self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
) -> "defer.Deferred[JsonDict]":
assert json_cb is not None
data = json_cb() data = json_cb()
self.edus.extend(data["edus"]) self.edus.extend(data["edus"])
return defer.succeed({}) return defer.succeed({})
def test_send_device_updates(self): def test_send_device_updates(self) -> None:
"""Basic case: each device update should result in an EDU""" """Basic case: each device update should result in an EDU"""
# create a device # create a device
u1 = self.register_user("user", "pass") u1 = self.register_user("user", "pass")
@ -340,7 +348,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(len(self.edus), 1) self.assertEqual(len(self.edus), 1)
self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
def test_dont_send_device_updates_for_remote_users(self): def test_dont_send_device_updates_for_remote_users(self) -> None:
"""Check that we don't send device updates for remote users""" """Check that we don't send device updates for remote users"""
# Send the server a device list EDU for the other user, this will cause # Send the server a device list EDU for the other user, this will cause
@ -379,7 +387,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
) )
self.assertIn("D1", devices) self.assertIn("D1", devices)
def test_upload_signatures(self): def test_upload_signatures(self) -> None:
"""Uploading signatures on some devices should produce updates for that user""" """Uploading signatures on some devices should produce updates for that user"""
e2e_handler = self.hs.get_e2e_keys_handler() e2e_handler = self.hs.get_e2e_keys_handler()
@ -391,7 +399,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# expect two edus # expect two edus
self.assertEqual(len(self.edus), 2) self.assertEqual(len(self.edus), 2)
stream_id = None stream_id: Optional[int] = None
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
@ -473,13 +481,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"] c = edu["content"]
if stream_id is not None: if stream_id is not None:
self.assertEqual(c["prev_id"], [stream_id]) self.assertEqual(c["prev_id"], [stream_id]) # type: ignore[unreachable]
self.assertGreaterEqual(c["stream_id"], stream_id) self.assertGreaterEqual(c["stream_id"], stream_id)
stream_id = c["stream_id"] stream_id = c["stream_id"]
devices = {edu["content"]["device_id"] for edu in self.edus} devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2"}, devices) self.assertEqual({"D1", "D2"}, devices)
def test_delete_devices(self): def test_delete_devices(self) -> None:
"""If devices are deleted, that should result in EDUs too""" """If devices are deleted, that should result in EDUs too"""
# create devices # create devices
@ -521,7 +529,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
devices = {edu["content"]["device_id"] for edu in self.edus} devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2", "D3"}, devices) self.assertEqual({"D1", "D2", "D3"}, devices)
def test_unreachable_server(self): def test_unreachable_server(self) -> None:
"""If the destination server is unreachable, all the updates should get sent on """If the destination server is unreachable, all the updates should get sent on
recovery recovery
""" """
@ -555,7 +563,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# for each device, there should be a single update # for each device, there should be a single update
self.assertEqual(len(self.edus), 3) self.assertEqual(len(self.edus), 3)
stream_id = None stream_id: Optional[int] = None
for edu in self.edus: for edu in self.edus:
self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"] c = edu["content"]
@ -566,7 +574,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
devices = {edu["content"]["device_id"] for edu in self.edus} devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2", "D3"}, devices) self.assertEqual({"D1", "D2", "D3"}, devices)
def test_prune_outbound_device_pokes1(self): def test_prune_outbound_device_pokes1(self) -> None:
"""If a destination is unreachable, and the updates are pruned, we should get """If a destination is unreachable, and the updates are pruned, we should get
a single update. a single update.
@ -615,7 +623,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# synapse uses an empty prev_id list to indicate "needs a full resync". # synapse uses an empty prev_id list to indicate "needs a full resync".
self.assertEqual(c["prev_id"], []) self.assertEqual(c["prev_id"], [])
def test_prune_outbound_device_pokes2(self): def test_prune_outbound_device_pokes2(self) -> None:
"""If a destination is unreachable, and the updates are pruned, we should get """If a destination is unreachable, and the updates are pruned, we should get
a single update. a single update.
@ -741,7 +749,7 @@ def encode_pubkey(sk: SigningKey) -> str:
return key.encode_verify_key_base64(key.get_verify_key(sk)) return key.encode_verify_key_base64(key.get_verify_key(sk))
def build_device_dict(user_id: str, device_id: str, sk: SigningKey): def build_device_dict(user_id: str, device_id: str, sk: SigningKey) -> JsonDict:
"""Build a dict representing the given device""" """Build a dict representing the given device"""
return { return {
"user_id": user_id, "user_id": user_id,

View File

@ -21,7 +21,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.events import make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
@ -42,7 +42,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
] ]
@parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)]) @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)])
def test_bad_request(self, query_content): def test_bad_request(self, query_content: bytes) -> None:
""" """
Querying with bad data returns a reasonable error code. Querying with bad data returns a reasonable error code.
""" """
@ -64,7 +64,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
class ServerACLsTestCase(unittest.TestCase): class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self): def test_blacklisted_server(self) -> None:
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
logging.info("ACL event: %s", e.content) logging.info("ACL event: %s", e.content)
@ -74,7 +74,7 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertTrue(server_matches_acl_event("evil.com.au", e)) self.assertTrue(server_matches_acl_event("evil.com.au", e))
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
def test_block_ip_literals(self): def test_block_ip_literals(self) -> None:
e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]}) e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
logging.info("ACL event: %s", e.content) logging.info("ACL event: %s", e.content)
@ -83,7 +83,7 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertFalse(server_matches_acl_event("[1:2::]", e)) self.assertFalse(server_matches_acl_event("[1:2::]", e))
self.assertTrue(server_matches_acl_event("1:2:3:4", e)) self.assertTrue(server_matches_acl_event("1:2:3:4", e))
def test_wildcard_matching(self): def test_wildcard_matching(self) -> None:
e = _create_acl_event({"allow": ["good*.com"]}) e = _create_acl_event({"allow": ["good*.com"]})
self.assertTrue( self.assertTrue(
server_matches_acl_event("good.com", e), server_matches_acl_event("good.com", e),
@ -110,7 +110,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def test_needs_to_be_in_room(self): def test_needs_to_be_in_room(self) -> None:
"""/v1/state/<room_id> requires the server to be in the room""" """/v1/state/<room_id> requires the server to be in the room"""
u1 = self.register_user("u1", "pass") u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass") u1_token = self.login("u1", "pass")
@ -131,7 +131,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs) super().prepare(reactor, clock, hs)
self._storage_controllers = hs.get_storage_controllers() self._storage_controllers = hs.get_storage_controllers()
@ -157,7 +157,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body return channel.json_body
def test_send_join(self): def test_send_join(self) -> None:
"""happy-path test of send_join""" """happy-path test of send_join"""
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
join_result = self._make_join(joining_user) join_result = self._make_join(joining_user)
@ -324,7 +324,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# is probably sufficient to reassure that the bucket is updated. # is probably sufficient to reassure that the bucket is updated.
def _create_acl_event(content): def _create_acl_event(content: JsonDict) -> EventBase:
return make_event_from_dict( return make_event_from_dict(
{ {
"room_id": "!a:b", "room_id": "!a:b",

View File

@ -15,6 +15,8 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from twisted.web.resource import Resource
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.federation.transport.server import BaseFederationServlet from synapse.federation.transport.server import BaseFederationServlet
from synapse.federation.transport.server._base import Authenticator, _parse_auth_header from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
@ -62,7 +64,7 @@ class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCa
path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}" path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}"
def create_test_resource(self): def create_test_resource(self) -> Resource:
"""Overrides `HomeserverTestCase.create_test_resource`.""" """Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs) resource = JsonResource(self.hs)

View File

@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List from typing import Any, Dict, List, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import builder from synapse.events import EventBase, builder
from synapse.events.snapshot import EventContext
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import RoomAlias from synapse.types import RoomAlias
from synapse.util import Clock
from tests.test_utils import event_injection from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase
@ -197,7 +201,9 @@ class FederationKnockingTestCase(
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
# We're not going to be properly signing events as our remote homeserver is fake, # We're not going to be properly signing events as our remote homeserver is fake,
@ -205,23 +211,29 @@ class FederationKnockingTestCase(
# Note that these checks are not relevant to this test case. # Note that these checks are not relevant to this test case.
# Have this homeserver auto-approve all event signature checking. # Have this homeserver auto-approve all event signature checking.
async def approve_all_signature_checking(_, pdu): async def approve_all_signature_checking(
room_version: RoomVersion,
pdu: EventBase,
record_failure_callback: Any = None,
) -> EventBase:
return pdu return pdu
homeserver.get_federation_server()._check_sigs_and_hash = ( homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[assignment]
approve_all_signature_checking approve_all_signature_checking
) )
# Have this homeserver skip event auth checks. This is necessary due to # Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver. # event auth checks ensuring that events were signed by the sender's homeserver.
async def _check_event_auth(origin, event, context, *args, **kwargs): async def _check_event_auth(
return context origin: Optional[str], event: EventBase, context: EventContext
) -> None:
pass
homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
return super().prepare(reactor, clock, homeserver) return super().prepare(reactor, clock, homeserver)
def test_room_state_returned_when_knocking(self): def test_room_state_returned_when_knocking(self) -> None:
""" """
Tests that specific, stripped state events from a room are returned after Tests that specific, stripped state events from a room are returned after
a remote homeserver successfully knocks on a local room. a remote homeserver successfully knocks on a local room.

View File

@ -20,7 +20,7 @@ from tests.unittest import DEBUG, override_config
class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
@override_config({"allow_public_rooms_over_federation": False}) @override_config({"allow_public_rooms_over_federation": False})
def test_blocked_public_room_list_over_federation(self): def test_blocked_public_room_list_over_federation(self) -> None:
"""Test that unauthenticated requests to the public rooms directory 403 when """Test that unauthenticated requests to the public rooms directory 403 when
allow_public_rooms_over_federation is False. allow_public_rooms_over_federation is False.
""" """
@ -31,7 +31,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(403, channel.code) self.assertEqual(403, channel.code)
@override_config({"allow_public_rooms_over_federation": True}) @override_config({"allow_public_rooms_over_federation": True})
def test_open_public_room_list_over_federation(self): def test_open_public_room_list_over_federation(self) -> None:
"""Test that unauthenticated requests to the public rooms directory 200 when """Test that unauthenticated requests to the public rooms directory 200 when
allow_public_rooms_over_federation is True. allow_public_rooms_over_federation is True.
""" """
@ -42,7 +42,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(200, channel.code) self.assertEqual(200, channel.code)
@DEBUG @DEBUG
def test_edu_debugging_doesnt_explode(self): def test_edu_debugging_doesnt_explode(self) -> None:
"""Sanity check incoming federation succeeds with `synapse.debug_8631` enabled. """Sanity check incoming federation succeeds with `synapse.debug_8631` enabled.
Remove this when we strip out issue_8631_logger. Remove this when we strip out issue_8631_logger.