Run pyupgrade for python 3.7 & 3.8. (#16110)

This commit is contained in:
Patrick Cloke 2023-08-15 08:11:20 -04:00 committed by GitHub
parent 4347473946
commit ad3f43be9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 113 additions and 121 deletions

1
changelog.d/16110.misc Normal file
View File

@ -0,0 +1 @@
Run `pyupgrade` for Python 3.8+.

View File

@ -769,7 +769,7 @@ def main(server_url, identity_server_url, username, token, config_path):
global CONFIG_JSON global CONFIG_JSON
CONFIG_JSON = config_path # bit cheeky, but just overwrite the global CONFIG_JSON = config_path # bit cheeky, but just overwrite the global
try: try:
with open(config_path, "r") as config: with open(config_path) as config:
syn_cmd.config = json.load(config) syn_cmd.config = json.load(config)
try: try:
http_client.verbose = "on" == syn_cmd.config["verbose"] http_client.verbose = "on" == syn_cmd.config["verbose"]

View File

@ -861,7 +861,7 @@ def generate_worker_files(
# Then a worker config file # Then a worker config file
convert( convert(
"/conf/worker.yaml.j2", "/conf/worker.yaml.j2",
"/conf/workers/{name}.yaml".format(name=worker_name), f"/conf/workers/{worker_name}.yaml",
**worker_config, **worker_config,
worker_log_config_filepath=log_config_filepath, worker_log_config_filepath=log_config_filepath,
using_unix_sockets=using_unix_sockets, using_unix_sockets=using_unix_sockets,

View File

@ -82,7 +82,7 @@ def generate_config_from_template(
with open(filename) as handle: with open(filename) as handle:
value = handle.read() value = handle.read()
else: else:
log("Generating a random secret for {}".format(secret)) log(f"Generating a random secret for {secret}")
value = codecs.encode(os.urandom(32), "hex").decode() value = codecs.encode(os.urandom(32), "hex").decode()
with open(filename, "w") as handle: with open(filename, "w") as handle:
handle.write(value) handle.write(value)

View File

@ -47,7 +47,7 @@ can be passed on the commandline for debugging.
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
class Builder(object): class Builder:
def __init__( def __init__(
self, self,
redirect_stdout: bool = False, redirect_stdout: bool = False,

View File

@ -43,7 +43,7 @@ def main(force_colors: bool) -> None:
diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None) diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None)
# Get the schema version of the local file to check against current schema on develop # Get the schema version of the local file to check against current schema on develop
with open("synapse/storage/schema/__init__.py", "r") as file: with open("synapse/storage/schema/__init__.py") as file:
local_schema = file.read() local_schema = file.read()
new_locals: Dict[str, Any] = {} new_locals: Dict[str, Any] = {}
exec(local_schema, new_locals) exec(local_schema, new_locals)

View File

@ -247,7 +247,7 @@ def main() -> None:
def read_args_from_config(args: argparse.Namespace) -> None: def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh: with open(args.config) as fh:
config = yaml.safe_load(fh) config = yaml.safe_load(fh)
if not args.server_name: if not args.server_name:

View File

@ -1,5 +1,4 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright 2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -145,7 +145,7 @@ Example usage:
def read_args_from_config(args: argparse.Namespace) -> None: def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh: with open(args.config) as fh:
config = yaml.safe_load(fh) config = yaml.safe_load(fh)
if not args.server_name: if not args.server_name:
args.server_name = config["server_name"] args.server_name = config["server_name"]

View File

@ -25,7 +25,11 @@ from synapse.util.rust import check_rust_lib_up_to_date
from synapse.util.stringutils import strtobool from synapse.util.stringutils import strtobool
# Check that we're not running on an unsupported Python version. # Check that we're not running on an unsupported Python version.
if sys.version_info < (3, 8): #
# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the
# if-statement completely.
py_version = sys.version_info
if py_version < (3, 8):
print("Synapse requires Python 3.8 or above.") print("Synapse requires Python 3.8 or above.")
sys.exit(1) sys.exit(1)
@ -78,7 +82,7 @@ try:
except ImportError: except ImportError:
pass pass
import synapse.util import synapse.util # noqa: E402
__version__ = synapse.util.SYNAPSE_VERSION __version__ = synapse.util.SYNAPSE_VERSION

View File

@ -1205,10 +1205,10 @@ class CursesProgress(Progress):
self.total_processed = 0 self.total_processed = 0
self.total_remaining = 0 self.total_remaining = 0
super(CursesProgress, self).__init__() super().__init__()
def update(self, table: str, num_done: int) -> None: def update(self, table: str, num_done: int) -> None:
super(CursesProgress, self).update(table, num_done) super().update(table, num_done)
self.total_processed = 0 self.total_processed = 0
self.total_remaining = 0 self.total_remaining = 0
@ -1304,7 +1304,7 @@ class TerminalProgress(Progress):
"""Just prints progress to the terminal""" """Just prints progress to the terminal"""
def update(self, table: str, num_done: int) -> None: def update(self, table: str, num_done: int) -> None:
super(TerminalProgress, self).update(table, num_done) super().update(table, num_done)
data = self.tables[table] data = self.tables[table]

View File

@ -38,7 +38,7 @@ class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore # type: ignore [assignment] DATASTORE_CLASS = DataStore # type: ignore [assignment]
def __init__(self, config: HomeServerConfig): def __init__(self, config: HomeServerConfig):
super(MockHomeserver, self).__init__( super().__init__(
hostname=config.server.server_name, hostname=config.server.server_name,
config=config, config=config,
reactor=reactor, reactor=reactor,

View File

@ -18,8 +18,7 @@
"""Contains constants from the specification.""" """Contains constants from the specification."""
import enum import enum
from typing import Final
from typing_extensions import Final
# the max size of a (canonical-json-encoded) event # the max size of a (canonical-json-encoded) event
MAX_PDU_SIZE = 65536 MAX_PDU_SIZE = 65536

View File

@ -32,6 +32,7 @@ from typing import (
Any, Any,
Callable, Callable,
Collection, Collection,
ContextManager,
Dict, Dict,
Generator, Generator,
Iterable, Iterable,
@ -43,7 +44,6 @@ from typing import (
) )
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import ContextManager
import synapse.metrics import synapse.metrics
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState

View File

@ -24,13 +24,14 @@ from typing import (
Iterable, Iterable,
List, List,
Mapping, Mapping,
NoReturn,
Optional, Optional,
Set, Set,
) )
from urllib.parse import urlencode from urllib.parse import urlencode
import attr import attr
from typing_extensions import NoReturn, Protocol from typing_extensions import Protocol
from twisted.web.iweb import IRequest from twisted.web.iweb import IRequest
from twisted.web.server import Request from twisted.web.server import Request
@ -791,7 +792,7 @@ class SsoHandler:
if code != 200: if code != 200:
raise Exception( raise Exception(
"GET request to download sso avatar image returned {}".format(code) f"GET request to download sso avatar image returned {code}"
) )
# upload name includes hash of the image file's content so that we can # upload name includes hash of the image file's content so that we can

View File

@ -14,9 +14,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import Counter from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple from typing import (
TYPE_CHECKING,
from typing_extensions import Counter as CounterType Any,
Counter as CounterType,
Dict,
Iterable,
Optional,
Tuple,
)
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.metrics import event_processing_positions from synapse.metrics import event_processing_positions

View File

@ -1442,11 +1442,9 @@ class SyncHandler:
# Now we have our list of joined room IDs, exclude as configured and freeze # Now we have our list of joined room IDs, exclude as configured and freeze
joined_room_ids = frozenset( joined_room_ids = frozenset(
( room_id
room_id for room_id in mutable_joined_room_ids
for room_id in mutable_joined_room_ids if room_id not in mutable_rooms_to_exclude
if room_id not in mutable_rooms_to_exclude
)
) )
logger.debug( logger.debug(

View File

@ -18,10 +18,9 @@ import traceback
from collections import deque from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor from math import floor
from typing import Callable, Optional from typing import Callable, Deque, Optional
import attr import attr
from typing_extensions import Deque
from zope.interface import implementer from zope.interface import implementer
from twisted.application.internet import ClientService from twisted.application.internet import ClientService

View File

@ -426,9 +426,7 @@ class SpamCheckerModuleApiCallbacks:
generally discouraged as it doesn't support internationalization. generally discouraged as it doesn't support internationalization.
""" """
for callback in self._check_event_for_spam_callbacks: for callback in self._check_event_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(event)) res = await delay_cancellation(callback(event))
if res is False or res == self.NOT_SPAM: if res is False or res == self.NOT_SPAM:
# This spam-checker accepts the event. # This spam-checker accepts the event.
@ -481,9 +479,7 @@ class SpamCheckerModuleApiCallbacks:
True if the event should be silently dropped True if the event should be silently dropped
""" """
for callback in self._should_drop_federated_event_callbacks: for callback in self._should_drop_federated_event_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res: Union[bool, str] = await delay_cancellation(callback(event)) res: Union[bool, str] = await delay_cancellation(callback(event))
if res: if res:
return res return res
@ -505,9 +501,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise. NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
""" """
for callback in self._user_may_join_room_callbacks: for callback in self._user_may_join_room_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(user_id, room_id, is_invited)) res = await delay_cancellation(callback(user_id, room_id, is_invited))
# Normalize return values to `Codes` or `"NOT_SPAM"`. # Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
@ -546,9 +540,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise. NOT_SPAM if the operation is permitted, Codes otherwise.
""" """
for callback in self._user_may_invite_callbacks: for callback in self._user_may_invite_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation( res = await delay_cancellation(
callback(inviter_userid, invitee_userid, room_id) callback(inviter_userid, invitee_userid, room_id)
) )
@ -593,9 +585,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise. NOT_SPAM if the operation is permitted, Codes otherwise.
""" """
for callback in self._user_may_send_3pid_invite_callbacks: for callback in self._user_may_send_3pid_invite_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation( res = await delay_cancellation(
callback(inviter_userid, medium, address, room_id) callback(inviter_userid, medium, address, room_id)
) )
@ -630,9 +620,7 @@ class SpamCheckerModuleApiCallbacks:
userid: The ID of the user attempting to create a room userid: The ID of the user attempting to create a room
""" """
for callback in self._user_may_create_room_callbacks: for callback in self._user_may_create_room_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(userid)) res = await delay_cancellation(callback(userid))
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
@ -666,9 +654,7 @@ class SpamCheckerModuleApiCallbacks:
""" """
for callback in self._user_may_create_room_alias_callbacks: for callback in self._user_may_create_room_alias_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(userid, room_alias)) res = await delay_cancellation(callback(userid, room_alias))
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
@ -701,9 +687,7 @@ class SpamCheckerModuleApiCallbacks:
room_id: The ID of the room that would be published room_id: The ID of the room that would be published
""" """
for callback in self._user_may_publish_room_callbacks: for callback in self._user_may_publish_room_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(userid, room_id)) res = await delay_cancellation(callback(userid, room_id))
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
@ -742,9 +726,7 @@ class SpamCheckerModuleApiCallbacks:
True if the user is spammy. True if the user is spammy.
""" """
for callback in self._check_username_for_spam_callbacks: for callback in self._check_username_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
# Make a copy of the user profile object to ensure the spam checker cannot # Make a copy of the user profile object to ensure the spam checker cannot
# modify it. # modify it.
res = await delay_cancellation(callback(user_profile.copy())) res = await delay_cancellation(callback(user_profile.copy()))
@ -776,9 +758,7 @@ class SpamCheckerModuleApiCallbacks:
""" """
for callback in self._check_registration_for_spam_callbacks: for callback in self._check_registration_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
behaviour = await delay_cancellation( behaviour = await delay_cancellation(
callback(email_threepid, username, request_info, auth_provider_id) callback(email_threepid, username, request_info, auth_provider_id)
) )
@ -820,9 +800,7 @@ class SpamCheckerModuleApiCallbacks:
""" """
for callback in self._check_media_file_for_spam_callbacks: for callback in self._check_media_file_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(file_wrapper, file_info)) res = await delay_cancellation(callback(file_wrapper, file_info))
# Normalize return values to `Codes` or `"NOT_SPAM"`. # Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is False or res is self.NOT_SPAM: if res is False or res is self.NOT_SPAM:
@ -869,9 +847,7 @@ class SpamCheckerModuleApiCallbacks:
""" """
for callback in self._check_login_for_spam_callbacks: for callback in self._check_login_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation( res = await delay_cancellation(
callback( callback(
user_id, user_id,

View File

@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Deque,
Dict, Dict,
Iterable, Iterable,
Iterator, Iterator,
@ -29,7 +30,6 @@ from typing import (
) )
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory from twisted.internet.protocol import ReconnectingClientFactory

View File

@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from typing_extensions import TYPE_CHECKING
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json

View File

@ -188,7 +188,7 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of # invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one # _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id). # param, which is itself the 2-tuple (server_name, key_id).
self._get_server_keys_json.invalidate((((server_name, key_id),))) self._get_server_keys_json.invalidate(((server_name, key_id),))
@cached() @cached()
def _get_server_keys_json( def _get_server_keys_json(

View File

@ -19,6 +19,7 @@ from itertools import chain
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Counter,
Dict, Dict,
Iterable, Iterable,
List, List,
@ -28,8 +29,6 @@ from typing import (
cast, cast,
) )
from typing_extensions import Counter
from twisted.internet.defer import DeferredLock from twisted.internet.defer import DeferredLock
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership

View File

@ -145,5 +145,5 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
This is not provided by DBAPI2, and so needs engine-specific support. This is not provided by DBAPI2, and so needs engine-specific support.
""" """
with open(filepath, "rt") as f: with open(filepath) as f:
cls.executescript(cursor, f.read()) cls.executescript(cursor, f.read())

View File

@ -16,10 +16,18 @@ import logging
import os import os
import re import re
from collections import Counter from collections import Counter
from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple from typing import (
Collection,
Counter as CounterType,
Generator,
Iterable,
List,
Optional,
TextIO,
Tuple,
)
import attr import attr
from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction

View File

@ -21,6 +21,7 @@ from typing import (
Any, Any,
ClassVar, ClassVar,
Dict, Dict,
Final,
List, List,
Mapping, Mapping,
Match, Match,
@ -38,7 +39,7 @@ import attr
from immutabledict import immutabledict from immutabledict import immutabledict
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey from signedjson.types import VerifyKey
from typing_extensions import Final, TypedDict from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from zope.interface import Interface from zope.interface import Interface

View File

@ -22,6 +22,7 @@ import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import ( from typing import (
Any, Any,
AsyncContextManager,
AsyncIterator, AsyncIterator,
Awaitable, Awaitable,
Callable, Callable,
@ -42,7 +43,7 @@ from typing import (
) )
import attr import attr
from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError

View File

@ -218,7 +218,7 @@ class MacaroonGenerator:
# to avoid validating those as guest tokens, we explicitely verify if # to avoid validating those as guest tokens, we explicitely verify if
# the macaroon includes the "guest = true" caveat. # the macaroon includes the "guest = true" caveat.
is_guest = any( is_guest = any(
(caveat.caveat_id == "guest = true" for caveat in macaroon.caveats) caveat.caveat_id == "guest = true" for caveat in macaroon.caveats
) )
if not is_guest: if not is_guest:

View File

@ -20,6 +20,7 @@ import typing
from typing import ( from typing import (
Any, Any,
Callable, Callable,
ContextManager,
DefaultDict, DefaultDict,
Dict, Dict,
Iterator, Iterator,
@ -33,7 +34,6 @@ from typing import (
from weakref import WeakSet from weakref import WeakSet
from prometheus_client.core import Counter from prometheus_client.core import Counter
from typing_extensions import ContextManager
from twisted.internet import defer from twisted.internet import defer

View File

@ -17,6 +17,7 @@ from enum import Enum, auto
from typing import ( from typing import (
Collection, Collection,
Dict, Dict,
Final,
FrozenSet, FrozenSet,
List, List,
Mapping, Mapping,
@ -27,7 +28,6 @@ from typing import (
) )
import attr import attr
from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase from synapse.events import EventBase

View File

@ -26,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
def make_homeserver( def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer: ) -> HomeServer:
hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) hs = super().make_homeserver(reactor, clock)
# We don't want our tests to actually report statistics, so check # We don't want our tests to actually report statistics, so check
# that it's not enabled # that it's not enabled

View File

@ -312,7 +312,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
[("server9", get_key_id(key1))] [("server9", get_key_id(key1))]
) )
result = self.get_success(d) result = self.get_success(d)
self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0) self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
def test_verify_json_dedupes_key_requests(self) -> None: def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped.""" """Two requests for the same key should be deduped."""

View File

@ -514,7 +514,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
# Send the body # Send the body
request.write('{ "a": 1 }'.encode("ascii")) request.write(b'{ "a": 1 }')
request.finish() request.finish()
self.reactor.pump((0.1,)) self.reactor.pump((0.1,))

View File

@ -757,7 +757,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(channel.json_body["creator"], user_id) self.assertEqual(channel.json_body["creator"], user_id)
# Check room alias. # Check room alias.
self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}") self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}")
# Let's try a room with no alias. # Let's try a room with no alias.
room_id, room_alias = self.get_success( room_id, room_alias = self.get_success(

View File

@ -116,7 +116,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(request.method, b"GET") self.assertEqual(request.method, b"GET")
self.assertEqual( self.assertEqual(
request.path, request.path,
f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"), f"/_matrix/media/r0/download/{target}/{media_id}".encode(),
) )
self.assertEqual( self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]

View File

@ -627,8 +627,8 @@ class RedactionsTestCase(HomeserverTestCase):
redact_event = timeline[-1] redact_event = timeline[-1]
self.assertEqual(redact_event["type"], EventTypes.Redaction) self.assertEqual(redact_event["type"], EventTypes.Redaction)
# The redacts key should be in the content and the redacts keys. # The redacts key should be in the content and the redacts keys.
self.assertEquals(redact_event["content"]["redacts"], event_id) self.assertEqual(redact_event["content"]["redacts"], event_id)
self.assertEquals(redact_event["redacts"], event_id) self.assertEqual(redact_event["redacts"], event_id)
# But it isn't actually part of the event. # But it isn't actually part of the event.
def get_event(txn: LoggingTransaction) -> JsonDict: def get_event(txn: LoggingTransaction) -> JsonDict:
@ -642,10 +642,10 @@ class RedactionsTestCase(HomeserverTestCase):
event_json = self.get_success( event_json = self.get_success(
main_datastore.db_pool.runInteraction("get_event", get_event) main_datastore.db_pool.runInteraction("get_event", get_event)
) )
self.assertEquals(event_json["type"], EventTypes.Redaction) self.assertEqual(event_json["type"], EventTypes.Redaction)
if expect_content: if expect_content:
self.assertNotIn("redacts", event_json) self.assertNotIn("redacts", event_json)
self.assertEquals(event_json["content"]["redacts"], event_id) self.assertEqual(event_json["content"]["redacts"], event_id)
else: else:
self.assertEquals(event_json["redacts"], event_id) self.assertEqual(event_json["redacts"], event_id)
self.assertNotIn("redacts", event_json["content"]) self.assertNotIn("redacts", event_json["content"])

View File

@ -129,7 +129,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}", f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
return [ev["event_id"] for ev in channel.json_body["chunk"]] return [ev["event_id"] for ev in channel.json_body["chunk"]]
def _get_bundled_aggregations(self) -> JsonDict: def _get_bundled_aggregations(self) -> JsonDict:
@ -142,7 +142,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}", f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
return channel.json_body["unsigned"].get("m.relations", {}) return channel.json_body["unsigned"].get("m.relations", {})
def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
@ -1602,7 +1602,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
threads = channel.json_body["chunk"] threads = channel.json_body["chunk"]
return [ return [
( (
@ -1634,7 +1634,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
################################################## ##################################################
# Check the test data is configured as expected. # # Check the test data is configured as expected. #
################################################## ##################################################
self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertDictContainsSubset( self.assertDictContainsSubset(
{"count": 3, "current_user_participated": True}, {"count": 3, "current_user_participated": True},
@ -1655,7 +1655,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(thread_replies.pop()) self._redact(thread_replies.pop())
# The thread should still exist, but the latest event should be updated. # The thread should still exist, but the latest event should be updated.
self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertDictContainsSubset( self.assertDictContainsSubset(
{"count": 2, "current_user_participated": True}, {"count": 2, "current_user_participated": True},
@ -1674,7 +1674,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(thread_replies.pop(0)) self._redact(thread_replies.pop(0))
# Nothing should have changed (except the thread count). # Nothing should have changed (except the thread count).
self.assertEquals(self._get_related_events(), thread_replies) self.assertEqual(self._get_related_events(), thread_replies)
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertDictContainsSubset( self.assertDictContainsSubset(
{"count": 1, "current_user_participated": True}, {"count": 1, "current_user_participated": True},
@ -1691,11 +1691,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# Redact the last remaining event. # # Redact the last remaining event. #
#################################### ####################################
self._redact(thread_replies.pop(0)) self._redact(thread_replies.pop(0))
self.assertEquals(thread_replies, []) self.assertEqual(thread_replies, [])
# The event should no longer be considered a thread. # The event should no longer be considered a thread.
self.assertEquals(self._get_related_events(), []) self.assertEqual(self._get_related_events(), [])
self.assertEquals(self._get_bundled_aggregations(), {}) self.assertEqual(self._get_bundled_aggregations(), {})
self.assertEqual(self._get_threads(), []) self.assertEqual(self._get_threads(), [])
def test_redact_parent_edit(self) -> None: def test_redact_parent_edit(self) -> None:
@ -1749,8 +1749,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# The relations are returned. # The relations are returned.
event_ids = self._get_related_events() event_ids = self._get_related_events()
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [related_event_id]) self.assertEqual(event_ids, [related_event_id])
self.assertEquals( self.assertEqual(
relations[RelationTypes.REFERENCE], relations[RelationTypes.REFERENCE],
{"chunk": [{"event_id": related_event_id}]}, {"chunk": [{"event_id": related_event_id}]},
) )
@ -1772,7 +1772,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# The unredacted relation should still exist. # The unredacted relation should still exist.
event_ids = self._get_related_events() event_ids = self._get_related_events()
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertEquals(len(event_ids), 1) self.assertEqual(len(event_ids), 1)
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"count": 1, "count": 1,
@ -1816,7 +1816,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
threads = self._get_threads(channel.json_body) threads = self._get_threads(channel.json_body)
self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)])
@ -1829,7 +1829,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
# Tuple of (thread ID, latest event ID) for each thread. # Tuple of (thread ID, latest event ID) for each thread.
threads = self._get_threads(channel.json_body) threads = self._get_threads(channel.json_body)
self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)])
@ -1850,7 +1850,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2]) self.assertEqual(thread_roots, [thread_2])
@ -1864,7 +1864,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body) self.assertEqual(thread_roots, [thread_1], channel.json_body)
@ -1899,7 +1899,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual( self.assertEqual(
thread_roots, [thread_3, thread_2, thread_1], channel.json_body thread_roots, [thread_3, thread_2, thread_1], channel.json_body
@ -1911,7 +1911,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
@ -1943,6 +1943,6 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body) self.assertEqual(thread_roots, [thread_1], channel.json_body)

View File

@ -1362,7 +1362,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp. # Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id)) res = self.get_success(self.main_store.get_event(event_id))
self.assertEquals(ts, res.origin_server_ts) self.assertEqual(ts, res.origin_server_ts)
def test_send_state_event_ts(self) -> None: def test_send_state_event_ts(self) -> None:
"""Test sending a state event with a custom timestamp.""" """Test sending a state event with a custom timestamp."""
@ -1384,7 +1384,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp. # Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id)) res = self.get_success(self.main_store.get_event(event_id))
self.assertEquals(ts, res.origin_server_ts) self.assertEqual(ts, res.origin_server_ts)
def test_send_membership_event_ts(self) -> None: def test_send_membership_event_ts(self) -> None:
"""Test sending a membership event with a custom timestamp.""" """Test sending a membership event with a custom timestamp."""
@ -1406,7 +1406,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp. # Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id)) res = self.get_success(self.main_store.get_event(event_id))
self.assertEquals(ts, res.origin_server_ts) self.assertEqual(ts, res.origin_server_ts)
class RoomJoinRatelimitTestCase(RoomBase): class RoomJoinRatelimitTestCase(RoomBase):

View File

@ -26,6 +26,7 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Deque,
Dict, Dict,
Iterable, Iterable,
List, List,
@ -41,7 +42,7 @@ from typing import (
from unittest.mock import Mock from unittest.mock import Mock
import attr import attr
from typing_extensions import Deque, ParamSpec from typing_extensions import ParamSpec
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import address, threads, udp from twisted.internet import address, threads, udp

View File

@ -40,7 +40,7 @@ from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(ApplicationServiceStoreTestCase, self).setUp() super().setUp()
self.as_yaml_files: List[str] = [] self.as_yaml_files: List[str] = []
@ -71,7 +71,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
except Exception: except Exception:
pass pass
super(ApplicationServiceStoreTestCase, self).tearDown() super().tearDown()
def _add_appservice( def _add_appservice(
self, as_token: str, id: str, url: str, hs_token: str, sender: str self, as_token: str, id: str, url: str, hs_token: str, sender: str
@ -110,7 +110,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(ApplicationServiceTransactionStoreTestCase, self).setUp() super().setUp()
self.as_yaml_files: List[str] = [] self.as_yaml_files: List[str] = []
self.hs.config.appservice.app_service_config_files = self.as_yaml_files self.hs.config.appservice.app_service_config_files = self.as_yaml_files

View File

@ -20,7 +20,7 @@ from tests import unittest
class DataStoreTestCase(unittest.HomeserverTestCase): class DataStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(DataStoreTestCase, self).setUp() super().setUp()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main

View File

@ -318,14 +318,14 @@ class MessageSearchTest(HomeserverTestCase):
result = self.get_success( result = self.get_success(
store.search_msgs([self.room_id], query, ["content.body"]) store.search_msgs([self.room_id], query, ["content.body"])
) )
self.assertEquals( self.assertEqual(
result["count"], result["count"],
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'" f"expected '{query}' to match '{self.PHRASE}'"
if expect_to_contain if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'", else f"'{query}' unexpectedly matched '{self.PHRASE}'",
) )
self.assertEquals( self.assertEqual(
len(result["results"]), len(result["results"]),
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
"results array length should match count", "results array length should match count",
@ -336,14 +336,14 @@ class MessageSearchTest(HomeserverTestCase):
result = self.get_success( result = self.get_success(
store.search_rooms([self.room_id], query, ["content.body"], 10) store.search_rooms([self.room_id], query, ["content.body"], 10)
) )
self.assertEquals( self.assertEqual(
result["count"], result["count"],
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'" f"expected '{query}' to match '{self.PHRASE}'"
if expect_to_contain if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'", else f"'{query}' unexpectedly matched '{self.PHRASE}'",
) )
self.assertEquals( self.assertEqual(
len(result["results"]), len(result["results"]),
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
"results array length should match count", "results array length should match count",

View File

@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM"
class FilterEventsForServerTestCase(unittest.HomeserverTestCase): class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(FilterEventsForServerTestCase, self).setUp() super().setUp()
self.event_creation_handler = self.hs.get_event_creation_handler() self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory() self.event_builder_factory = self.hs.get_event_builder_factory()
self._storage_controllers = self.hs.get_storage_controllers() self._storage_controllers = self.hs.get_storage_controllers()