From a8a46b13360b8bd07cbca48798791098ef6d6d3c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 25 Aug 2023 09:27:21 -0400 Subject: [PATCH] Replace simple_async_mock with AsyncMock (#16180) Python 3.8 has a native AsyncMock, use it instead of a custom implementation. --- changelog.d/16180.misc | 1 + tests/api/test_auth.py | 97 +++++++++++---------- tests/appservice/test_appservice.py | 31 +++---- tests/appservice/test_scheduler.py | 43 ++++----- tests/events/test_presence_router.py | 5 +- tests/handlers/test_appservice.py | 8 +- tests/handlers/test_cas.py | 11 ++- tests/handlers/test_oauth_delegation.py | 42 ++++----- tests/handlers/test_oidc.py | 6 +- tests/handlers/test_saml.py | 13 ++- tests/module_api/test_api.py | 9 +- tests/push/test_bulk_push_rule_evaluator.py | 5 +- tests/rest/client/test_notifications.py | 5 +- tests/storage/test_background_update.py | 5 +- tests/test_utils/__init__.py | 19 +--- 15 files changed, 140 insertions(+), 160 deletions(-) create mode 100644 changelog.d/16180.misc diff --git a/changelog.d/16180.misc b/changelog.d/16180.misc new file mode 100644 index 0000000000..8d04954ab9 --- /dev/null +++ b/changelog.d/16180.misc @@ -0,0 +1 @@ +Use `AsyncMock` instead of custom code. diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ce96574915..dcd01d5688 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pymacaroons @@ -35,7 +35,6 @@ from synapse.types import Requester, UserID from synapse.util import Clock from tests import unittest -from tests.test_utils import simple_async_mock from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -60,16 +59,16 @@ class AuthTestCase(unittest.HomeserverTestCase): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) - self.store.insert_client_ip = simple_async_mock(None) - self.store.is_support_user = simple_async_mock(False) + self.store.insert_client_ip = AsyncMock(return_value=None) + self.store.is_support_user = AsyncMock(return_value=False) def test_get_user_by_req_user_valid_token(self) -> None: user_info = TokenLookupResult( user_id=self.test_user, token_id=5, device_id="device" ) - self.store.get_user_by_access_token = simple_async_mock(user_info) - self.store.mark_access_token_as_used = simple_async_mock(None) - self.store.get_user_locked_status = simple_async_mock(False) + self.store.get_user_by_access_token = AsyncMock(return_value=user_info) + self.store.mark_access_token_as_used = AsyncMock(return_value=None) + self.store.get_user_locked_status = AsyncMock(return_value=False) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -78,7 +77,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self) -> None: - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -91,7 +90,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_get_user_by_req_user_missing_token(self) -> None: user_info = TokenLookupResult(user_id=self.test_user, token_id=5) - self.store.get_user_by_access_token = simple_async_mock(user_info) + self.store.get_user_by_access_token = AsyncMock(return_value=user_info) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -106,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase): token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -125,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "192.168.10.10" @@ -144,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "131.111.8.42" @@ -158,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_get_user_by_req_appservice_bad_token(self) -> None: self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -172,7 +171,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_get_user_by_req_appservice_missing_token(self) -> None: app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -190,8 +189,8 @@ class AuthTestCase(unittest.HomeserverTestCase): app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = simple_async_mock({"is_guest": False}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -210,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -234,10 +233,10 @@ class AuthTestCase(unittest.HomeserverTestCase): app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = simple_async_mock({"is_guest": False}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) # This also needs to just return a truth-y value - self.store.get_device = simple_async_mock({"hidden": False}) + self.store.get_device = AsyncMock(return_value={"hidden": False}) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -266,10 +265,10 @@ class AuthTestCase(unittest.HomeserverTestCase): app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = simple_async_mock({"is_guest": False}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) # This also needs to just return a falsey value - self.store.get_device = simple_async_mock(None) + self.store.get_device = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -283,8 +282,8 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: - self.store.get_user_by_access_token = simple_async_mock( - TokenLookupResult( + self.store.get_user_by_access_token = AsyncMock( + return_value=TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", token_id=5, @@ -292,9 +291,9 @@ class AuthTestCase(unittest.HomeserverTestCase): token_used=True, ) ) - self.store.insert_client_ip = simple_async_mock(None) - self.store.mark_access_token_as_used = simple_async_mock(None) - self.store.get_user_locked_status = simple_async_mock(False) + self.store.insert_client_ip = AsyncMock(return_value=None) + self.store.mark_access_token_as_used = AsyncMock(return_value=None) + self.store.get_user_locked_status = AsyncMock(return_value=False) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -304,8 +303,8 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: self.auth._track_puppeted_user_ips = True - self.store.get_user_by_access_token = simple_async_mock( - TokenLookupResult( + self.store.get_user_by_access_token = AsyncMock( + return_value=TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", token_id=5, @@ -313,9 +312,9 @@ class AuthTestCase(unittest.HomeserverTestCase): token_used=True, ) ) - self.store.get_user_locked_status = simple_async_mock(False) - self.store.insert_client_ip = simple_async_mock(None) - self.store.mark_access_token_as_used = simple_async_mock(None) + self.store.get_user_locked_status = AsyncMock(return_value=False) + self.store.insert_client_ip = AsyncMock(return_value=None) + self.store.mark_access_token_as_used = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -324,7 +323,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(self.store.insert_client_ip.call_count, 2) def test_get_user_from_macaroon(self) -> None: - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -342,8 +341,8 @@ class AuthTestCase(unittest.HomeserverTestCase): ) def test_get_guest_user_from_macaroon(self) -> None: - self.store.get_user_by_id = simple_async_mock({"is_guest": True}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -373,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = simple_async_mock(lots_of_users) + self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) e = self.get_failure( self.auth_blocking.check_auth_blocking(), ResourceLimitError @@ -383,25 +382,27 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(e.value.code, 403) # Ensure does not throw an error - self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) + self.store.get_monthly_active_count = AsyncMock( + return_value=small_number_of_users + ) self.get_success(self.auth_blocking.check_auth_blocking()) def test_blocking_mau__depending_on_user_type(self) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = simple_async_mock(100) + self.store.get_monthly_active_count = AsyncMock(return_value=100) # Support users allowed self.get_success( self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) ) - self.store.get_monthly_active_count = simple_async_mock(100) + self.store.get_monthly_active_count = AsyncMock(return_value=100) # Bots not allowed self.get_failure( self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError, ) - self.store.get_monthly_active_count = simple_async_mock(100) + self.store.get_monthly_active_count = AsyncMock(return_value=100) # Real users not allowed self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) @@ -412,9 +413,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = False - self.store.get_monthly_active_count = simple_async_mock(100) - self.store.user_last_seen_monthly_active = simple_async_mock() - self.store.is_trial_user = simple_async_mock() + self.store.get_monthly_active_count = AsyncMock(return_value=100) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) + self.store.is_trial_user = AsyncMock(return_value=False) appservice = ApplicationService( "abcd", @@ -443,9 +444,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = True - self.store.get_monthly_active_count = simple_async_mock(100) - self.store.user_last_seen_monthly_active = simple_async_mock() - self.store.is_trial_user = simple_async_mock() + self.store.get_monthly_active_count = AsyncMock(return_value=100) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) + self.store.is_trial_user = AsyncMock(return_value=False) appservice = ApplicationService( "abcd", @@ -473,7 +474,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_reserved_threepid(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._max_mau_value = 1 - self.store.get_monthly_active_count = simple_async_mock(2) + self.store.get_monthly_active_count = AsyncMock(return_value=2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.auth_blocking._mau_limits_reserved_threepids = [threepid] diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 66753c60c4..6ac5fc1ae7 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -13,14 +13,13 @@ # limitations under the License. import re from typing import Any, Generator -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.internet import defer from synapse.appservice import ApplicationService, Namespace from tests import unittest -from tests.test_utils import simple_async_mock def _regex(regex: str, exclusive: bool = True) -> Namespace: @@ -43,8 +42,8 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.store = Mock() - self.store.get_aliases_for_room = simple_async_mock([]) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_aliases_for_room = AsyncMock(return_value=[]) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) @defer.inlineCallbacks def test_regex_user_id_prefix_match( @@ -127,10 +126,10 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room = simple_async_mock( - ["#irc_foobar:matrix.org", "#athing:matrix.org"] + self.store.get_aliases_for_room = AsyncMock( + return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) self.assertTrue( ( yield self.service.is_interested_in_event( @@ -182,10 +181,10 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room = simple_async_mock( - ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] + self.store.get_aliases_for_room = AsyncMock( + return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) self.assertFalse( ( yield defer.ensureDeferred( @@ -205,8 +204,10 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_aliases_for_room = AsyncMock( + return_value=["#irc_barfoo:matrix.org"] + ) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) self.assertTrue( ( yield self.service.is_interested_in_event( @@ -235,10 +236,10 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. - self.store.get_local_users_in_room = simple_async_mock( - ["@alice:here", "@irc_fo:here", "@bob:here"] + self.store.get_local_users_in_room = AsyncMock( + return_value=["@alice:here", "@irc_fo:here", "@bob:here"] ) - self.store.get_aliases_for_room = simple_async_mock([]) + self.store.get_aliases_for_room = AsyncMock(return_value=[]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index e2a3bad065..445919417e 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Sequence, Tuple, cast -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from typing_extensions import TypeAlias @@ -37,7 +37,6 @@ from synapse.types import DeviceListUpdates, JsonDict from synapse.util import Clock from tests import unittest -from tests.test_utils import simple_async_mock from ..utils import MockClock @@ -62,10 +61,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) - txn.send = simple_async_mock(True) - txn.complete = simple_async_mock(True) - self.store.create_appservice_txn = simple_async_mock(txn) + self.store.get_appservice_state = AsyncMock( + return_value=ApplicationServiceState.UP + ) + txn.send = AsyncMock(return_value=True) + txn.complete = AsyncMock(return_value=True) + self.store.create_appservice_txn = AsyncMock(return_value=txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -89,10 +90,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events = [Mock(), Mock()] txn = Mock(id="idhere", service=service, events=events) - self.store.get_appservice_state = simple_async_mock( - ApplicationServiceState.DOWN + self.store.get_appservice_state = AsyncMock( + return_value=ApplicationServiceState.DOWN ) - self.store.create_appservice_txn = simple_async_mock(txn) + self.store.create_appservice_txn = AsyncMock(return_value=txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -118,10 +119,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) - self.store.set_appservice_state = simple_async_mock(True) - txn.send = simple_async_mock(False) # fails to send - self.store.create_appservice_txn = simple_async_mock(txn) + self.store.get_appservice_state = AsyncMock( + return_value=ApplicationServiceState.UP + ) + self.store.set_appservice_state = AsyncMock(return_value=True) + txn.send = AsyncMock(return_value=False) # fails to send + self.store.create_appservice_txn = AsyncMock(return_value=txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -150,7 +153,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.as_api = Mock() self.store = Mock() self.service = Mock() - self.callback = simple_async_mock() + self.callback = AsyncMock() self.recoverer = _Recoverer( clock=cast(Clock, self.clock), as_api=self.as_api, @@ -174,8 +177,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = simple_async_mock(True) - txn.complete = simple_async_mock(None) + txn.send = AsyncMock(return_value=True) + txn.complete = AsyncMock(return_value=None) # wait for exp backoff self.clock.advance_time(2) self.assertEqual(1, txn.send.call_count) @@ -202,8 +205,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = simple_async_mock(False) - txn.complete = simple_async_mock(None) + txn.send = AsyncMock(return_value=False) + txn.complete = AsyncMock(return_value=None) self.clock.advance_time(2) self.assertEqual(1, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) @@ -216,7 +219,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEqual(3, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, self.callback.call_count) - txn.send = simple_async_mock(True) # successfully send the txn + txn.send = AsyncMock(return_value=True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEqual(1, txn.send.call_count) # new mock reset call count @@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None: self.scheduler = ApplicationServiceScheduler(hs) self.txn_ctrl = Mock() - self.txn_ctrl.send = simple_async_mock() + self.txn_ctrl.send = AsyncMock() # Replace instantiated _TransactionController instances with our Mock self.scheduler.txn_ctrl = self.txn_ctrl diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 6fb1f1bd6e..0fcfe38efa 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, Iterable, List, Optional, Set, Tuple, Union -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import attr @@ -30,7 +30,6 @@ from synapse.types import JsonDict, StreamToken, create_requester from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config -from tests.test_utils import simple_async_mock from tests.unittest import ( FederatingHomeserverTestCase, HomeserverTestCase, @@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. self.fed_transport_client = Mock(spec=["send_transaction"]) - self.fed_transport_client.send_transaction = simple_async_mock({}) + self.fed_transport_client.send_transaction = AsyncMock(return_value={}) hs = self.setup_test_homeserver( federation_transport_client=self.fed_transport_client, diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 5e2ae82cd4..4bd0facd65 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -36,7 +36,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import event_injection, simple_async_mock +from tests.test_utils import event_injection from tests.unittest import override_config from tests.utils import MockClock @@ -399,7 +399,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.hs = hs # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events - self.send_mock = simple_async_mock() + self.send_mock = AsyncMock() hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # Mock out application services, and allow defining our own in tests @@ -897,7 +897,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that # will be sent over the wire - self.put_json = simple_async_mock() + self.put_json = AsyncMock() hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] # Mock out application services, and allow defining our own in tests @@ -1003,7 +1003,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track what's going out - self.send_mock = simple_async_mock() + self.send_mock = AsyncMock() hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. # Define an application service for the tests diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 63aad0d10c..2cb24add20 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse from synapse.server import HomeServer from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. @@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] cas_response = CasResponse("test_user", {}) request = _mock_request() @@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] # Map a user via SSO. cas_response = CasResponse("test_user", {}) @@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] cas_response = CasResponse("föö", {}) request = _mock_request() @@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] # The response doesn't have the proper userGroup or department. cas_response = CasResponse("test_user", {}) diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index b891e84690..9152694653 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -39,7 +39,7 @@ from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils import FakeResponse, get_awaitable_result from tests.unittest import HomeserverTestCase, skip_unless from tests.utils import mock_getRawHeaders @@ -147,7 +147,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_inactive_token(self) -> None: """The handler should return a 403 where the token is inactive.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={"active": False}, @@ -166,7 +166,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_no_scope(self) -> None: """The handler should return a 403 where no scope is given.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={"active": True}, @@ -185,7 +185,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_user_no_subject(self) -> None: """The handler should return a 500 when no subject is present.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}, @@ -204,7 +204,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_no_user_scope(self) -> None: """The handler should return a 500 when no subject is present.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -227,7 +227,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_admin_not_user(self) -> None: """The handler should raise when the scope has admin right but not user.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -251,7 +251,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_admin(self) -> None: """The handler should return a requester with admin rights.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -281,7 +281,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_admin_highest_privilege(self) -> None: """The handler should resolve to the most permissive scope.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -313,7 +313,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_user(self) -> None: """The handler should return a requester with normal user rights.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -344,7 +344,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): """The handler should return a requester with normal user rights and an user ID matching the one specified in query param `user_id`""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -378,7 +378,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_user_with_device(self) -> None: """The handler should return a requester with normal user rights and a device ID.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -408,7 +408,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_multiple_devices(self) -> None: """The handler should raise an error if multiple devices are found in the scope.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -433,7 +433,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_guest_not_allowed(self) -> None: """The handler should return an insufficient scope error.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -463,7 +463,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_active_guest_allowed(self) -> None: """The handler should return a requester with guest user rights and a device ID.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -499,19 +499,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() # The introspection endpoint is returning an error. - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse(code=500, body=b"Internal Server Error") ) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) # The introspection endpoint request fails. - self.http_client.request = simple_async_mock(raises=Exception()) + self.http_client.request = AsyncMock(side_effect=Exception()) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) # The introspection endpoint does not return a JSON object. - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload=["this is an array", "not an object"] ) @@ -520,7 +520,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): self.assertEqual(error.value.code, 503) # The introspection endpoint does not return valid JSON. - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse(code=200, body=b"this is not valid JSON") ) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) @@ -528,7 +528,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_introspection_token_cache(self) -> None: access_token = "open_sesame" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={"active": "true", "scope": "guest", "jti": access_token}, @@ -559,7 +559,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a # token with a soon-to-expire `exp` field to the cache - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -640,7 +640,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_cross_signing(self) -> None: """Try uploading device keys with OAuth delegation enabled.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 0a8bae54fb..9b2c7812cc 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,7 +13,7 @@ # limitations under the License. import os from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple -from unittest.mock import ANY, Mock, patch +from unittest.mock import ANY, AsyncMock, Mock, patch from urllib.parse import parse_qs, urlparse import pymacaroons @@ -28,7 +28,7 @@ from synapse.util import Clock from synapse.util.macaroons import get_value_from_macaroon from synapse.util.stringutils import random_string -from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils import FakeResponse, get_awaitable_result from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config @@ -164,7 +164,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler = hs.get_auth_handler() # Mock the complete SSO login method. - self.complete_sso_login = simple_async_mock() + self.complete_sso_login = AsyncMock() auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] return hs diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index b5c772a7ae..6e666d7bed 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Dict, Optional, Set, Tuple -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import attr @@ -25,7 +25,6 @@ from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. @@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) @@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] # Map a user via SSO. saml_response = FakeAuthnResponse( @@ -206,7 +205,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] # mock out the error renderer too sso_handler = self.hs.get_sso_handler() @@ -227,7 +226,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler and error renderer auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] sso_handler = self.hs.get_sso_handler() sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] @@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] # The response doesn't have the proper userGroup or department. saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index fe631d7ecb..9ce9326190 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -33,7 +33,6 @@ from synapse.util import Clock from tests.events.test_presence_router import send_presence_update, sync_presence from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.test_utils import simple_async_mock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config @@ -70,7 +69,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. self.fed_transport_client = Mock(spec=["send_transaction"]) - self.fed_transport_client.send_transaction = simple_async_mock({}) + self.fed_transport_client.send_transaction = AsyncMock(return_value={}) return self.setup_test_homeserver( federation_transport_client=self.fed_transport_client, @@ -579,9 +578,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): """Test that the module API can join a remote room.""" # Necessary to fake a remote join. fake_stream_id = 1 - mocked_remote_join = simple_async_mock( - return_value=("fake-event-id", fake_stream_id) - ) + mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id)) self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] fake_remote_host = f"{self.module_api.server_name}-remote" diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 937e6ebb7d..a3880ac171 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Optional -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from parameterized import parameterized @@ -28,7 +28,6 @@ from synapse.server import HomeServer from synapse.types import JsonDict, create_requester from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -191,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Mock the method which calculates push rules -- we do this instead of # e.g. checking the results in the database because we want to ensure # that code isn't even running. - bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment] + bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[assignment] # Ensure no actions are generated! self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py index 700f6587a0..41ceb3db51 100644 --- a/tests/rest/client/test_notifications.py +++ b/tests/rest/client/test_notifications.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -20,7 +20,6 @@ from synapse.rest.client import login, notifications, receipts, room from synapse.server import HomeServer from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase @@ -45,7 +44,7 @@ class HTTPPusherTests(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. fed_transport_client = Mock(spec=["send_transaction"]) - fed_transport_client.send_transaction = simple_async_mock({}) + fed_transport_client.send_transaction = AsyncMock(return_value={}) return self.setup_test_homeserver( federation_transport_client=fed_transport_client, diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 2af7280ba3..52beb4e89d 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -32,7 +32,6 @@ from synapse.types import JsonDict from synapse.util import Clock from tests import unittest -from tests.test_utils import simple_async_mock from tests.unittest import override_config @@ -348,8 +347,8 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): # Mock out the AsyncContextManager class MockCM: - __aenter__ = simple_async_mock(return_value=None) - __aexit__ = simple_async_mock(return_value=None) + __aenter__ = AsyncMock(return_value=None) + __aexit__ = AsyncMock(return_value=None) self._update_ctx_manager = MockCM diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 21e112a8b5..fa731426cd 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -19,8 +19,7 @@ import json import sys import warnings from binascii import unhexlify -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar -from unittest.mock import Mock +from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar import attr import zope.interface @@ -62,10 +61,6 @@ def setup_awaitable_errors() -> Callable[[], None]: """ warnings.simplefilter("error", RuntimeWarning) - # unraisablehook was added in Python 3.8. - if not hasattr(sys, "unraisablehook"): - return lambda: None - # State shared between unraisablehook and check_for_unraisable_exceptions. unraisable_exceptions = [] orig_unraisablehook = sys.unraisablehook @@ -88,18 +83,6 @@ def setup_awaitable_errors() -> Callable[[], None]: return cleanup -def simple_async_mock( - return_value: Optional[TV] = None, raises: Optional[Exception] = None -) -> Mock: - # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args: Any, **kwargs: Any) -> Optional[TV]: - if raises: - raise raises - return return_value - - return Mock(side_effect=cb) - - # Type ignore: it does not fully implement IResponse, but is good enough for tests @zope.interface.implementer(IResponse) @attr.s(slots=True, frozen=True, auto_attribs=True)