Add most of the missing type hints to `synapse.federation`. (#11483)

This skips a few methods which are difficult to type.
This commit is contained in:
Patrick Cloke 2021-12-02 11:18:10 -05:00 committed by GitHub
parent b50e39df57
commit d2279f471b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 84 additions and 49 deletions

1
changelog.d/11483.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to `synapse.federation`.

View File

@ -158,6 +158,12 @@ disallow_untyped_defs = True
[mypy-synapse.events.*] [mypy-synapse.events.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.federation.*]
disallow_untyped_defs = True
[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False
[mypy-synapse.handlers.*] [mypy-synapse.handlers.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -128,7 +128,7 @@ class FederationClient(FederationBase):
reset_expiry_on_get=False, reset_expiry_on_get=False,
) )
def _clear_tried_cache(self): def _clear_tried_cache(self) -> None:
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""
now = self._clock.time_msec() now = self._clock.time_msec()
@ -800,7 +800,7 @@ class FederationClient(FederationBase):
no servers successfully handle the request. no servers successfully handle the request.
""" """
async def send_request(destination) -> SendJoinResult: async def send_request(destination: str) -> SendJoinResult:
response = await self._do_send_join(room_version, destination, pdu) response = await self._do_send_join(room_version, destination, pdu)
# If an event was returned (and expected to be returned): # If an event was returned (and expected to be returned):

View File

@ -1,6 +1,6 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd # Copyright 2018 New Vector Ltd
# Copyright 2019 Matrix.org Federation C.I.C # Copyright 2019-2021 Matrix.org Federation C.I.C
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -450,7 +450,7 @@ class FederationServer(FederationBase):
# require callouts to other servers to fetch missing events), but # require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu. # impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id: str): async def process_pdus_for_room(room_id: str) -> None:
with nested_logging_context(room_id): with nested_logging_context(room_id):
logger.debug("Processing PDUs for %s", room_id) logger.debug("Processing PDUs for %s", room_id)
@ -547,7 +547,7 @@ class FederationServer(FederationBase):
async def on_state_ids_request( async def on_state_ids_request(
self, origin: str, room_id: str, event_id: str self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]: ) -> Tuple[int, JsonDict]:
if not event_id: if not event_id:
raise NotImplementedError("Specify an event") raise NotImplementedError("Specify an event")
@ -567,7 +567,9 @@ class FederationServer(FederationBase):
return 200, resp return 200, resp
async def _on_state_ids_request_compute(self, room_id, event_id): async def _on_state_ids_request_compute(
self, room_id: str, event_id: str
) -> JsonDict:
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}

View File

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -23,6 +24,7 @@ from typing import Optional, Tuple
from synapse.federation.units import Transaction from synapse.federation.units import Transaction
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,7 +33,7 @@ logger = logging.getLogger(__name__)
class TransactionActions: class TransactionActions:
"""Defines persistence actions that relate to handling Transactions.""" """Defines persistence actions that relate to handling Transactions."""
def __init__(self, datastore): def __init__(self, datastore: DataStore):
self.store = datastore self.store = datastore
@log_function @log_function

View File

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -350,7 +351,7 @@ class BaseFederationRow:
TypeId = "" # Unique string that ids the type. Must be overridden in sub classes. TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
@staticmethod @staticmethod
def from_data(data): def from_data(data: JsonDict) -> "BaseFederationRow":
"""Parse the data from the federation stream into a row. """Parse the data from the federation stream into a row.
Args: Args:
@ -359,7 +360,7 @@ class BaseFederationRow:
""" """
raise NotImplementedError() raise NotImplementedError()
def to_data(self): def to_data(self) -> JsonDict:
"""Serialize this row to be sent over the federation stream. """Serialize this row to be sent over the federation stream.
Returns: Returns:
@ -368,7 +369,7 @@ class BaseFederationRow:
""" """
raise NotImplementedError() raise NotImplementedError()
def add_to_buffer(self, buff): def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
"""Add this row to the appropriate field in the buffer ready for this """Add this row to the appropriate field in the buffer ready for this
to be sent over federation. to be sent over federation.
@ -391,15 +392,15 @@ class PresenceDestinationsRow(
TypeId = "pd" TypeId = "pd"
@staticmethod @staticmethod
def from_data(data): def from_data(data: JsonDict) -> "PresenceDestinationsRow":
return PresenceDestinationsRow( return PresenceDestinationsRow(
state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
) )
def to_data(self): def to_data(self) -> JsonDict:
return {"state": self.state.as_dict(), "dests": self.destinations} return {"state": self.state.as_dict(), "dests": self.destinations}
def add_to_buffer(self, buff): def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
buff.presence_destinations.append((self.state, self.destinations)) buff.presence_destinations.append((self.state, self.destinations))
@ -417,13 +418,13 @@ class KeyedEduRow(
TypeId = "k" TypeId = "k"
@staticmethod @staticmethod
def from_data(data): def from_data(data: JsonDict) -> "KeyedEduRow":
return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"])) return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
def to_data(self): def to_data(self) -> JsonDict:
return {"key": self.key, "edu": self.edu.get_internal_dict()} return {"key": self.key, "edu": self.edu.get_internal_dict()}
def add_to_buffer(self, buff): def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
@ -433,13 +434,13 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
TypeId = "e" TypeId = "e"
@staticmethod @staticmethod
def from_data(data): def from_data(data: JsonDict) -> "EduRow":
return EduRow(Edu(**data)) return EduRow(Edu(**data))
def to_data(self): def to_data(self) -> JsonDict:
return self.edu.get_internal_dict() return self.edu.get_internal_dict()
def add_to_buffer(self, buff): def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
buff.edus.setdefault(self.edu.destination, []).append(self.edu) buff.edus.setdefault(self.edu.destination, []).append(self.edu)

View File

@ -1,5 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd # Copyright 2019 New Vector Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,7 +15,8 @@
# limitations under the License. # limitations under the License.
import datetime import datetime
import logging import logging
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple from types import TracebackType
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@ -213,7 +215,7 @@ class PerDestinationQueue:
self._pending_edus_keyed[(edu.edu_type, key)] = edu self._pending_edus_keyed[(edu.edu_type, key)] = edu
self.attempt_new_transaction() self.attempt_new_transaction()
def send_edu(self, edu) -> None: def send_edu(self, edu: Edu) -> None:
self._pending_edus.append(edu) self._pending_edus.append(edu)
self.attempt_new_transaction() self.attempt_new_transaction()
@ -701,7 +703,12 @@ class _TransactionQueueManager:
return self._pdus, pending_edus return self._pdus, pending_edus
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
if exc_type is not None: if exc_type is not None:
# Failed to send transaction, so we bail out. # Failed to send transaction, so we bail out.
return return

View File

@ -21,6 +21,7 @@ from typing import (
Callable, Callable,
Collection, Collection,
Dict, Dict,
Generator,
Iterable, Iterable,
List, List,
Mapping, Mapping,
@ -235,11 +236,16 @@ class TransportLayerClient:
@log_function @log_function
async def make_query( async def make_query(
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False self,
): destination: str,
query_type: str,
args: dict,
retry_on_dns_fail: bool,
ignore_backoff: bool = False,
) -> JsonDict:
path = _create_v1_path("/query/%s", query_type) path = _create_v1_path("/query/%s", query_type)
content = await self.client.get_json( return await self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
args=args, args=args,
@ -248,8 +254,6 @@ class TransportLayerClient:
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,
) )
return content
@log_function @log_function
async def make_membership_event( async def make_membership_event(
self, self,
@ -1317,7 +1321,7 @@ class SendJoinResponse:
@ijson.coroutine @ijson.coroutine
def _event_parser(event_dict: JsonDict): def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
to add them to a given dictionary. to add them to a given dictionary.
""" """
@ -1328,7 +1332,9 @@ def _event_parser(event_dict: JsonDict):
@ijson.coroutine @ijson.coroutine
def _event_list_parser(room_version: RoomVersion, events: List[EventBase]): def _event_list_parser(
room_version: RoomVersion, events: List[EventBase]
) -> Generator[None, JsonDict, None]:
"""Helper function for use with `ijson.items_coro` to parse an array of """Helper function for use with `ijson.items_coro` to parse an array of
events and add them to the given list. events and add them to the given list.
""" """

View File

@ -302,7 +302,7 @@ def register_servlets(
authenticator: Authenticator, authenticator: Authenticator,
ratelimiter: FederationRateLimiter, ratelimiter: FederationRateLimiter,
servlet_groups: Optional[Iterable[str]] = None, servlet_groups: Optional[Iterable[str]] = None,
): ) -> None:
"""Initialize and register servlet classes. """Initialize and register servlet classes.
Will by default register all servlets. For custom behaviour, pass in Will by default register all servlets. For custom behaviour, pass in

View File

@ -15,10 +15,13 @@
import functools import functools
import logging import logging
import re import re
from typing import Any, Awaitable, Callable, Optional, Tuple, cast
from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX from synapse.api.urls import FEDERATION_V1_PREFIX
from synapse.http.server import HttpServer, ServletCallback
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
@ -29,6 +32,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver, whitelisted_homeserver,
) )
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.stringutils import parse_and_validate_server_name
@ -59,9 +63,11 @@ class Authenticator:
self.replication_client = hs.get_tcp_replication() self.replication_client = hs.get_tcp_replication()
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
async def authenticate_request(self, request, content): async def authenticate_request(
self, request: SynapseRequest, content: Optional[JsonDict]
) -> str:
now = self._clock.time_msec() now = self._clock.time_msec()
json_request = { json_request: JsonDict = {
"method": request.method.decode("ascii"), "method": request.method.decode("ascii"),
"uri": request.uri.decode("ascii"), "uri": request.uri.decode("ascii"),
"destination": self.server_name, "destination": self.server_name,
@ -114,7 +120,7 @@ class Authenticator:
return origin return origin
async def _reset_retry_timings(self, origin): async def _reset_retry_timings(self, origin: str) -> None:
try: try:
logger.info("Marking origin %r as up", origin) logger.info("Marking origin %r as up", origin)
await self.store.set_destination_retry_timings(origin, None, 0, 0) await self.store.set_destination_retry_timings(origin, None, 0, 0)
@ -133,14 +139,14 @@ class Authenticator:
logger.exception("Error resetting retry timings on %s", origin) logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes): def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
"""Parse an X-Matrix auth header """Parse an X-Matrix auth header
Args: Args:
header_bytes (bytes): header value header_bytes: header value
Returns: Returns:
Tuple[str, str, str]: origin, key id, signature. origin, key id, signature.
Raises: Raises:
AuthenticationError if the header could not be parsed AuthenticationError if the header could not be parsed
@ -148,9 +154,9 @@ def _parse_auth_header(header_bytes):
try: try:
header_str = header_bytes.decode("utf-8") header_str = header_bytes.decode("utf-8")
params = header_str.split(" ")[1].split(",") params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params) param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}
def strip_quotes(value): def strip_quotes(value: str) -> str:
if value.startswith('"'): if value.startswith('"'):
return value[1:-1] return value[1:-1]
else: else:
@ -233,23 +239,25 @@ class BaseFederationServlet:
self.ratelimiter = ratelimiter self.ratelimiter = ratelimiter
self.server_name = server_name self.server_name = server_name
def _wrap(self, func): def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
authenticator = self.authenticator authenticator = self.authenticator
ratelimiter = self.ratelimiter ratelimiter = self.ratelimiter
@functools.wraps(func) @functools.wraps(func)
async def new_func(request, *args, **kwargs): async def new_func(
request: SynapseRequest, *args: Any, **kwargs: str
) -> Optional[Tuple[int, Any]]:
"""A callback which can be passed to HttpServer.RegisterPaths """A callback which can be passed to HttpServer.RegisterPaths
Args: Args:
request (twisted.web.http.Request): request:
*args: unused? *args: unused?
**kwargs (dict[unicode, unicode]): the dict mapping keys to path **kwargs: the dict mapping keys to path components as specified
components as specified in the path match regexp. in the path match regexp.
Returns: Returns:
Tuple[int, object]|None: (response code, response object) as returned by (response code, response object) as returned by the callback method.
the callback method. None if the request has already been handled. None if the request has already been handled.
""" """
content = None content = None
if request.method in [b"PUT", b"POST"]: if request.method in [b"PUT", b"POST"]:
@ -257,7 +265,9 @@ class BaseFederationServlet:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
try: try:
origin = await authenticator.authenticate_request(request, content) origin: Optional[str] = await authenticator.authenticate_request(
request, content
)
except NoAuthenticationError: except NoAuthenticationError:
origin = None origin = None
if self.REQUIRE_AUTH: if self.REQUIRE_AUTH:
@ -301,7 +311,7 @@ class BaseFederationServlet:
"client disconnected before we started processing " "client disconnected before we started processing "
"request" "request"
) )
return -1, None return None
response = await func( response = await func(
origin, content, request.args, *args, **kwargs origin, content, request.args, *args, **kwargs
) )
@ -312,9 +322,9 @@ class BaseFederationServlet:
return response return response
return new_func return cast(ServletCallback, new_func)
def register(self, server): def register(self, server: HttpServer) -> None:
pattern = re.compile("^" + self.PREFIX + self.PATH + "$") pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
for method in ("GET", "PUT", "POST"): for method in ("GET", "PUT", "POST"):