Replace simple_async_mock with AsyncMock (#16180)

Python 3.8 has a native AsyncMock, use it instead of a custom
implementation.
This commit is contained in:
Patrick Cloke 2023-08-25 09:27:21 -04:00 committed by GitHub
parent 5c9402b9fd
commit a8a46b1336
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 140 additions and 160 deletions

1
changelog.d/16180.misc Normal file
View File

@ -0,0 +1 @@
Use `AsyncMock` instead of custom code.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
import pymacaroons
@ -35,7 +35,6 @@ from synapse.types import Requester, UserID
from synapse.util import Clock
from tests import unittest
from tests.test_utils import simple_async_mock
from tests.unittest import override_config
from tests.utils import mock_getRawHeaders
@ -60,16 +59,16 @@ class AuthTestCase(unittest.HomeserverTestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.insert_client_ip = simple_async_mock(None)
self.store.is_support_user = simple_async_mock(False)
self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.is_support_user = AsyncMock(return_value=False)
def test_get_user_by_req_user_valid_token(self) -> None:
user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
self.store.get_user_by_access_token = simple_async_mock(user_info)
self.store.mark_access_token_as_used = simple_async_mock(None)
self.store.get_user_locked_status = simple_async_mock(False)
self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
self.store.get_user_locked_status = AsyncMock(return_value=False)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@ -78,7 +77,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@ -91,7 +90,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_get_user_by_req_user_missing_token(self) -> None:
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = simple_async_mock(user_info)
self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@ -106,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
@ -125,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "192.168.10.10"
@ -144,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "131.111.8.42"
@ -158,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_get_user_by_req_appservice_bad_token(self) -> None:
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@ -172,7 +171,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_get_user_by_req_appservice_missing_token(self) -> None:
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@ -190,8 +189,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
@ -210,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
@ -234,10 +233,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
self.store.get_user_by_access_token = AsyncMock(return_value=None)
# This also needs to just return a truth-y value
self.store.get_device = simple_async_mock({"hidden": False})
self.store.get_device = AsyncMock(return_value={"hidden": False})
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
@ -266,10 +265,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
self.store.get_user_by_access_token = AsyncMock(return_value=None)
# This also needs to just return a falsey value
self.store.get_device = simple_async_mock(None)
self.store.get_device = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
@ -283,8 +282,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(
self.store.get_user_by_access_token = AsyncMock(
return_value=TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
token_id=5,
@ -292,9 +291,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
self.store.get_user_locked_status = simple_async_mock(False)
self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
self.store.get_user_locked_status = AsyncMock(return_value=False)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@ -304,8 +303,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
self.auth._track_puppeted_user_ips = True
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(
self.store.get_user_by_access_token = AsyncMock(
return_value=TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
token_id=5,
@ -313,9 +312,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
token_used=True,
)
)
self.store.get_user_locked_status = simple_async_mock(False)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
self.store.get_user_locked_status = AsyncMock(return_value=False)
self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@ -324,7 +323,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.store.insert_client_ip.call_count, 2)
def test_get_user_from_macaroon(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@ -342,8 +341,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
self.store.get_user_by_access_token = AsyncMock(return_value=None)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@ -373,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
e = self.get_failure(
self.auth_blocking.check_auth_blocking(), ResourceLimitError
@ -383,25 +382,27 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(e.value.code, 403)
# Ensure does not throw an error
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
self.store.get_monthly_active_count = AsyncMock(
return_value=small_number_of_users
)
self.get_success(self.auth_blocking.check_auth_blocking())
def test_blocking_mau__depending_on_user_type(self) -> None:
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = simple_async_mock(100)
self.store.get_monthly_active_count = AsyncMock(return_value=100)
# Support users allowed
self.get_success(
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
)
self.store.get_monthly_active_count = simple_async_mock(100)
self.store.get_monthly_active_count = AsyncMock(return_value=100)
# Bots not allowed
self.get_failure(
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
ResourceLimitError,
)
self.store.get_monthly_active_count = simple_async_mock(100)
self.store.get_monthly_active_count = AsyncMock(return_value=100)
# Real users not allowed
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
@ -412,9 +413,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = False
self.store.get_monthly_active_count = simple_async_mock(100)
self.store.user_last_seen_monthly_active = simple_async_mock()
self.store.is_trial_user = simple_async_mock()
self.store.get_monthly_active_count = AsyncMock(return_value=100)
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
self.store.is_trial_user = AsyncMock(return_value=False)
appservice = ApplicationService(
"abcd",
@ -443,9 +444,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = True
self.store.get_monthly_active_count = simple_async_mock(100)
self.store.user_last_seen_monthly_active = simple_async_mock()
self.store.is_trial_user = simple_async_mock()
self.store.get_monthly_active_count = AsyncMock(return_value=100)
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
self.store.is_trial_user = AsyncMock(return_value=False)
appservice = ApplicationService(
"abcd",
@ -473,7 +474,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_reserved_threepid(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = simple_async_mock(2)
self.store.get_monthly_active_count = AsyncMock(return_value=2)
threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid]

View File

@ -13,14 +13,13 @@
# limitations under the License.
import re
from typing import Any, Generator
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.internet import defer
from synapse.appservice import ApplicationService, Namespace
from tests import unittest
from tests.test_utils import simple_async_mock
def _regex(regex: str, exclusive: bool = True) -> Namespace:
@ -43,8 +42,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.store = Mock()
self.store.get_aliases_for_room = simple_async_mock([])
self.store.get_local_users_in_room = simple_async_mock([])
self.store.get_aliases_for_room = AsyncMock(return_value=[])
self.store.get_local_users_in_room = AsyncMock(return_value=[])
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(
@ -127,10 +126,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
self.store.get_aliases_for_room = simple_async_mock(
["#irc_foobar:matrix.org", "#athing:matrix.org"]
self.store.get_aliases_for_room = AsyncMock(
return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"]
)
self.store.get_local_users_in_room = simple_async_mock([])
self.store.get_local_users_in_room = AsyncMock(return_value=[])
self.assertTrue(
(
yield self.service.is_interested_in_event(
@ -182,10 +181,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
self.store.get_aliases_for_room = simple_async_mock(
["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
self.store.get_aliases_for_room = AsyncMock(
return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
)
self.store.get_local_users_in_room = simple_async_mock([])
self.store.get_local_users_in_room = AsyncMock(return_value=[])
self.assertFalse(
(
yield defer.ensureDeferred(
@ -205,8 +204,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"])
self.store.get_local_users_in_room = simple_async_mock([])
self.store.get_aliases_for_room = AsyncMock(
return_value=["#irc_barfoo:matrix.org"]
)
self.store.get_local_users_in_room = AsyncMock(return_value=[])
self.assertTrue(
(
yield self.service.is_interested_in_event(
@ -235,10 +236,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user.
self.store.get_local_users_in_room = simple_async_mock(
["@alice:here", "@irc_fo:here", "@bob:here"]
self.store.get_local_users_in_room = AsyncMock(
return_value=["@alice:here", "@irc_fo:here", "@bob:here"]
)
self.store.get_aliases_for_room = simple_async_mock([])
self.store.get_aliases_for_room = AsyncMock(return_value=[])
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from typing_extensions import TypeAlias
@ -37,7 +37,6 @@ from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import simple_async_mock
from ..utils import MockClock
@ -62,10 +61,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
txn = Mock(id=txn_id, service=service, events=events)
# mock methods
self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
txn.send = simple_async_mock(True)
txn.complete = simple_async_mock(True)
self.store.create_appservice_txn = simple_async_mock(txn)
self.store.get_appservice_state = AsyncMock(
return_value=ApplicationServiceState.UP
)
txn.send = AsyncMock(return_value=True)
txn.complete = AsyncMock(return_value=True)
self.store.create_appservice_txn = AsyncMock(return_value=txn)
# actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@ -89,10 +90,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events = [Mock(), Mock()]
txn = Mock(id="idhere", service=service, events=events)
self.store.get_appservice_state = simple_async_mock(
ApplicationServiceState.DOWN
self.store.get_appservice_state = AsyncMock(
return_value=ApplicationServiceState.DOWN
)
self.store.create_appservice_txn = simple_async_mock(txn)
self.store.create_appservice_txn = AsyncMock(return_value=txn)
# actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@ -118,10 +119,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
txn = Mock(id=txn_id, service=service, events=events)
# mock methods
self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
self.store.set_appservice_state = simple_async_mock(True)
txn.send = simple_async_mock(False) # fails to send
self.store.create_appservice_txn = simple_async_mock(txn)
self.store.get_appservice_state = AsyncMock(
return_value=ApplicationServiceState.UP
)
self.store.set_appservice_state = AsyncMock(return_value=True)
txn.send = AsyncMock(return_value=False) # fails to send
self.store.create_appservice_txn = AsyncMock(return_value=txn)
# actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@ -150,7 +153,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.as_api = Mock()
self.store = Mock()
self.service = Mock()
self.callback = simple_async_mock()
self.callback = AsyncMock()
self.recoverer = _Recoverer(
clock=cast(Clock, self.clock),
as_api=self.as_api,
@ -174,8 +177,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = simple_async_mock(True)
txn.complete = simple_async_mock(None)
txn.send = AsyncMock(return_value=True)
txn.complete = AsyncMock(return_value=None)
# wait for exp backoff
self.clock.advance_time(2)
self.assertEqual(1, txn.send.call_count)
@ -202,8 +205,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = simple_async_mock(False)
txn.complete = simple_async_mock(None)
txn.send = AsyncMock(return_value=False)
txn.complete = AsyncMock(return_value=None)
self.clock.advance_time(2)
self.assertEqual(1, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count)
@ -216,7 +219,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEqual(3, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count)
txn.send = simple_async_mock(True) # successfully send the txn
txn.send = AsyncMock(return_value=True) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16)
self.assertEqual(1, txn.send.call_count) # new mock reset call count
@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None:
self.scheduler = ApplicationServiceScheduler(hs)
self.txn_ctrl = Mock()
self.txn_ctrl.send = simple_async_mock()
self.txn_ctrl.send = AsyncMock()
# Replace instantiated _TransactionController instances with our Mock
self.scheduler.txn_ctrl = self.txn_ctrl

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
import attr
@ -30,7 +30,6 @@ from synapse.types import JsonDict, StreamToken, create_requester
from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
from tests.unittest import (
FederatingHomeserverTestCase,
HomeserverTestCase,
@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
self.fed_transport_client = Mock(spec=["send_transaction"])
self.fed_transport_client.send_transaction = simple_async_mock({})
self.fed_transport_client.send_transaction = AsyncMock(return_value={})
hs = self.setup_test_homeserver(
federation_transport_client=self.fed_transport_client,

View File

@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
from tests.test_utils import event_injection, simple_async_mock
from tests.test_utils import event_injection
from tests.unittest import override_config
from tests.utils import MockClock
@ -399,7 +399,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
self.send_mock = simple_async_mock()
self.send_mock = AsyncMock()
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests
@ -897,7 +897,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
# will be sent over the wire
self.put_json = simple_async_mock()
self.put_json = AsyncMock()
hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests
@ -1003,7 +1003,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track what's going out
self.send_mock = simple_async_mock()
self.send_mock = AsyncMock()
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method.
# Define an application service for the tests

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse
from synapse.server import HomeServer
from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
cas_response = CasResponse("test_user", {})
request = _mock_request()
@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# Map a user via SSO.
cas_response = CasResponse("test_user", {})
@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
cas_response = CasResponse("föö", {})
request = _mock_request()
@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})

View File

@ -39,7 +39,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.test_utils import FakeResponse, get_awaitable_result
from tests.unittest import HomeserverTestCase, skip_unless
from tests.utils import mock_getRawHeaders
@ -147,7 +147,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_inactive_token(self) -> None:
"""The handler should return a 403 where the token is inactive."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": False},
@ -166,7 +166,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_no_scope(self) -> None:
"""The handler should return a 403 where no scope is given."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": True},
@ -185,7 +185,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user_no_subject(self) -> None:
"""The handler should return a 500 when no subject is present."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
@ -204,7 +204,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_no_user_scope(self) -> None:
"""The handler should return a 500 when no subject is present."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -227,7 +227,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin_not_user(self) -> None:
"""The handler should raise when the scope has admin right but not user."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -251,7 +251,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin(self) -> None:
"""The handler should return a requester with admin rights."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -281,7 +281,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin_highest_privilege(self) -> None:
"""The handler should resolve to the most permissive scope."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -313,7 +313,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user(self) -> None:
"""The handler should return a requester with normal user rights."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -344,7 +344,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
"""The handler should return a requester with normal user rights
and an user ID matching the one specified in query param `user_id`"""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -378,7 +378,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user_with_device(self) -> None:
"""The handler should return a requester with normal user rights and a device ID."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -408,7 +408,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_multiple_devices(self) -> None:
"""The handler should raise an error if multiple devices are found in the scope."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -433,7 +433,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_guest_not_allowed(self) -> None:
"""The handler should return an insufficient scope error."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -463,7 +463,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_guest_allowed(self) -> None:
"""The handler should return a requester with guest user rights and a device ID."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -499,19 +499,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
# The introspection endpoint is returning an error.
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse(code=500, body=b"Internal Server Error")
)
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
# The introspection endpoint request fails.
self.http_client.request = simple_async_mock(raises=Exception())
self.http_client.request = AsyncMock(side_effect=Exception())
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
# The introspection endpoint does not return a JSON object.
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200, payload=["this is an array", "not an object"]
)
@ -520,7 +520,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.assertEqual(error.value.code, 503)
# The introspection endpoint does not return valid JSON.
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse(code=200, body=b"this is not valid JSON")
)
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
@ -528,7 +528,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_introspection_token_cache(self) -> None:
access_token = "open_sesame"
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": "true", "scope": "guest", "jti": access_token},
@ -559,7 +559,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
# token with a soon-to-expire `exp` field to the cache
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@ -640,7 +640,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_cross_signing(self) -> None:
"""Try uploading device keys with OAuth delegation enabled."""
self.http_client.request = simple_async_mock(
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={

View File

@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
from unittest.mock import ANY, Mock, patch
from unittest.mock import ANY, AsyncMock, Mock, patch
from urllib.parse import parse_qs, urlparse
import pymacaroons
@ -28,7 +28,7 @@ from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.test_utils import FakeResponse, get_awaitable_result
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config
@ -164,7 +164,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = hs.get_auth_handler()
# Mock the complete SSO login method.
self.complete_sso_login = simple_async_mock()
self.complete_sso_login = AsyncMock()
auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment]
return hs

View File

@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Optional, Set, Tuple
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
import attr
@ -25,7 +25,6 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests.
@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# Map a user via SSO.
saml_response = FakeAuthnResponse(
@ -206,7 +205,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
@ -227,7 +226,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
@ -33,7 +33,6 @@ from synapse.util import Clock
from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import simple_async_mock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
@ -70,7 +69,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
self.fed_transport_client = Mock(spec=["send_transaction"])
self.fed_transport_client.send_transaction = simple_async_mock({})
self.fed_transport_client.send_transaction = AsyncMock(return_value={})
return self.setup_test_homeserver(
federation_transport_client=self.fed_transport_client,
@ -579,9 +578,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
"""Test that the module API can join a remote room."""
# Necessary to fake a remote join.
fake_stream_id = 1
mocked_remote_join = simple_async_mock(
return_value=("fake-event-id", fake_stream_id)
)
mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id))
self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
fake_remote_host = f"{self.module_api.server_name}-remote"

View File

@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Optional
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
from parameterized import parameterized
@ -28,7 +28,6 @@ from synapse.server import HomeServer
from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@ -191,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Mock the method which calculates push rules -- we do this instead of
# e.g. checking the results in the database because we want to ensure
# that code isn't even running.
bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment]
bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[assignment]
# Ensure no actions are generated!
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -20,7 +20,6 @@ from synapse.rest.client import login, notifications, receipts, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase
@ -45,7 +44,7 @@ class HTTPPusherTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({})
fed_transport_client.send_transaction = AsyncMock(return_value={})
return self.setup_test_homeserver(
federation_transport_client=fed_transport_client,

View File

@ -32,7 +32,6 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import simple_async_mock
from tests.unittest import override_config
@ -348,8 +347,8 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
# Mock out the AsyncContextManager
class MockCM:
__aenter__ = simple_async_mock(return_value=None)
__aexit__ = simple_async_mock(return_value=None)
__aenter__ = AsyncMock(return_value=None)
__aexit__ = AsyncMock(return_value=None)
self._update_ctx_manager = MockCM

View File

@ -19,8 +19,7 @@ import json
import sys
import warnings
from binascii import unhexlify
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock
from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar
import attr
import zope.interface
@ -62,10 +61,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
"""
warnings.simplefilter("error", RuntimeWarning)
# unraisablehook was added in Python 3.8.
if not hasattr(sys, "unraisablehook"):
return lambda: None
# State shared between unraisablehook and check_for_unraisable_exceptions.
unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook
@ -88,18 +83,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
return cleanup
def simple_async_mock(
return_value: Optional[TV] = None, raises: Optional[Exception] = None
) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises:
raise raises
return return_value
return Mock(side_effect=cb)
# Type ignore: it does not fully implement IResponse, but is good enough for tests
@zope.interface.implementer(IResponse)
@attr.s(slots=True, frozen=True, auto_attribs=True)