Add most of the missing type hints to `synapse.federation`. (#11483)

This skips a few methods which are difficult to type.
This commit is contained in:
Patrick Cloke 2021-12-02 11:18:10 -05:00 committed by GitHub
parent b50e39df57
commit d2279f471b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 84 additions and 49 deletions

1
changelog.d/11483.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to `synapse.federation`.

View File

@ -158,6 +158,12 @@ disallow_untyped_defs = True
[mypy-synapse.events.*]
disallow_untyped_defs = True
[mypy-synapse.federation.*]
disallow_untyped_defs = True
[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False
[mypy-synapse.handlers.*]
disallow_untyped_defs = True

View File

@ -128,7 +128,7 @@ class FederationClient(FederationBase):
reset_expiry_on_get=False,
)
def _clear_tried_cache(self):
def _clear_tried_cache(self) -> None:
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
@ -800,7 +800,7 @@ class FederationClient(FederationBase):
no servers successfully handle the request.
"""
async def send_request(destination) -> SendJoinResult:
async def send_request(destination: str) -> SendJoinResult:
response = await self._do_send_join(room_version, destination, pdu)
# If an event was returned (and expected to be returned):

View File

@ -1,6 +1,6 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019 Matrix.org Federation C.I.C
# Copyright 2019-2021 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -450,7 +450,7 @@ class FederationServer(FederationBase):
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id: str):
async def process_pdus_for_room(room_id: str) -> None:
with nested_logging_context(room_id):
logger.debug("Processing PDUs for %s", room_id)
@ -547,7 +547,7 @@ class FederationServer(FederationBase):
async def on_state_ids_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
) -> Tuple[int, JsonDict]:
if not event_id:
raise NotImplementedError("Specify an event")
@ -567,7 +567,9 @@ class FederationServer(FederationBase):
return 200, resp
async def _on_state_ids_request_compute(self, room_id, event_id):
async def _on_state_ids_request_compute(
self, room_id: str, event_id: str
) -> JsonDict:
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}

View File

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -23,6 +24,7 @@ from typing import Optional, Tuple
from synapse.federation.units import Transaction
from synapse.logging.utils import log_function
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -31,7 +33,7 @@ logger = logging.getLogger(__name__)
class TransactionActions:
"""Defines persistence actions that relate to handling Transactions."""
def __init__(self, datastore):
def __init__(self, datastore: DataStore):
self.store = datastore
@log_function

View File

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -350,7 +351,7 @@ class BaseFederationRow:
TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
@staticmethod
def from_data(data):
def from_data(data: JsonDict) -> "BaseFederationRow":
"""Parse the data from the federation stream into a row.
Args:
@ -359,7 +360,7 @@ class BaseFederationRow:
"""
raise NotImplementedError()
def to_data(self):
def to_data(self) -> JsonDict:
"""Serialize this row to be sent over the federation stream.
Returns:
@ -368,7 +369,7 @@ class BaseFederationRow:
"""
raise NotImplementedError()
def add_to_buffer(self, buff):
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
"""Add this row to the appropriate field in the buffer ready for this
to be sent over federation.
@ -391,15 +392,15 @@ class PresenceDestinationsRow(
TypeId = "pd"
@staticmethod
def from_data(data):
def from_data(data: JsonDict) -> "PresenceDestinationsRow":
return PresenceDestinationsRow(
state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
)
def to_data(self):
def to_data(self) -> JsonDict:
return {"state": self.state.as_dict(), "dests": self.destinations}
def add_to_buffer(self, buff):
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
buff.presence_destinations.append((self.state, self.destinations))
@ -417,13 +418,13 @@ class KeyedEduRow(
TypeId = "k"
@staticmethod
def from_data(data):
def from_data(data: JsonDict) -> "KeyedEduRow":
return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
def to_data(self):
def to_data(self) -> JsonDict:
return {"key": self.key, "edu": self.edu.get_internal_dict()}
def add_to_buffer(self, buff):
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
@ -433,13 +434,13 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
TypeId = "e"
@staticmethod
def from_data(data):
def from_data(data: JsonDict) -> "EduRow":
return EduRow(Edu(**data))
def to_data(self):
def to_data(self) -> JsonDict:
return self.edu.get_internal_dict()
def add_to_buffer(self, buff):
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
buff.edus.setdefault(self.edu.destination, []).append(self.edu)

View File

@ -1,5 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,7 +15,8 @@
# limitations under the License.
import datetime
import logging
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type
import attr
from prometheus_client import Counter
@ -213,7 +215,7 @@ class PerDestinationQueue:
self._pending_edus_keyed[(edu.edu_type, key)] = edu
self.attempt_new_transaction()
def send_edu(self, edu) -> None:
def send_edu(self, edu: Edu) -> None:
self._pending_edus.append(edu)
self.attempt_new_transaction()
@ -701,7 +703,12 @@ class _TransactionQueueManager:
return self._pdus, pending_edus
async def __aexit__(self, exc_type, exc, tb):
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
if exc_type is not None:
# Failed to send transaction, so we bail out.
return

View File

@ -21,6 +21,7 @@ from typing import (
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Mapping,
@ -235,11 +236,16 @@ class TransportLayerClient:
@log_function
async def make_query(
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
):
self,
destination: str,
query_type: str,
args: dict,
retry_on_dns_fail: bool,
ignore_backoff: bool = False,
) -> JsonDict:
path = _create_v1_path("/query/%s", query_type)
content = await self.client.get_json(
return await self.client.get_json(
destination=destination,
path=path,
args=args,
@ -248,8 +254,6 @@ class TransportLayerClient:
ignore_backoff=ignore_backoff,
)
return content
@log_function
async def make_membership_event(
self,
@ -1317,7 +1321,7 @@ class SendJoinResponse:
@ijson.coroutine
def _event_parser(event_dict: JsonDict):
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
to add them to a given dictionary.
"""
@ -1328,7 +1332,9 @@ def _event_parser(event_dict: JsonDict):
@ijson.coroutine
def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
def _event_list_parser(
room_version: RoomVersion, events: List[EventBase]
) -> Generator[None, JsonDict, None]:
"""Helper function for use with `ijson.items_coro` to parse an array of
events and add them to the given list.
"""

View File

@ -302,7 +302,7 @@ def register_servlets(
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
servlet_groups: Optional[Iterable[str]] = None,
):
) -> None:
"""Initialize and register servlet classes.
Will by default register all servlets. For custom behaviour, pass in

View File

@ -15,10 +15,13 @@
import functools
import logging
import re
from typing import Any, Awaitable, Callable, Optional, Tuple, cast
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX
from synapse.http.server import HttpServer, ServletCallback
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@ -29,6 +32,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import parse_and_validate_server_name
@ -59,9 +63,11 @@ class Authenticator:
self.replication_client = hs.get_tcp_replication()
# A method just so we can pass 'self' as the authenticator to the Servlets
async def authenticate_request(self, request, content):
async def authenticate_request(
self, request: SynapseRequest, content: Optional[JsonDict]
) -> str:
now = self._clock.time_msec()
json_request = {
json_request: JsonDict = {
"method": request.method.decode("ascii"),
"uri": request.uri.decode("ascii"),
"destination": self.server_name,
@ -114,7 +120,7 @@ class Authenticator:
return origin
async def _reset_retry_timings(self, origin):
async def _reset_retry_timings(self, origin: str) -> None:
try:
logger.info("Marking origin %r as up", origin)
await self.store.set_destination_retry_timings(origin, None, 0, 0)
@ -133,14 +139,14 @@ class Authenticator:
logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes):
def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
"""Parse an X-Matrix auth header
Args:
header_bytes (bytes): header value
header_bytes: header value
Returns:
Tuple[str, str, str]: origin, key id, signature.
origin, key id, signature.
Raises:
AuthenticationError if the header could not be parsed
@ -148,9 +154,9 @@ def _parse_auth_header(header_bytes):
try:
header_str = header_bytes.decode("utf-8")
params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}
def strip_quotes(value):
def strip_quotes(value: str) -> str:
if value.startswith('"'):
return value[1:-1]
else:
@ -233,23 +239,25 @@ class BaseFederationServlet:
self.ratelimiter = ratelimiter
self.server_name = server_name
def _wrap(self, func):
def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@functools.wraps(func)
async def new_func(request, *args, **kwargs):
async def new_func(
request: SynapseRequest, *args: Any, **kwargs: str
) -> Optional[Tuple[int, Any]]:
"""A callback which can be passed to HttpServer.RegisterPaths
Args:
request (twisted.web.http.Request):
request:
*args: unused?
**kwargs (dict[unicode, unicode]): the dict mapping keys to path
components as specified in the path match regexp.
**kwargs: the dict mapping keys to path components as specified
in the path match regexp.
Returns:
Tuple[int, object]|None: (response code, response object) as returned by
the callback method. None if the request has already been handled.
(response code, response object) as returned by the callback method.
None if the request has already been handled.
"""
content = None
if request.method in [b"PUT", b"POST"]:
@ -257,7 +265,9 @@ class BaseFederationServlet:
content = parse_json_object_from_request(request)
try:
origin = await authenticator.authenticate_request(request, content)
origin: Optional[str] = await authenticator.authenticate_request(
request, content
)
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
@ -301,7 +311,7 @@ class BaseFederationServlet:
"client disconnected before we started processing "
"request"
)
return -1, None
return None
response = await func(
origin, content, request.args, *args, **kwargs
)
@ -312,9 +322,9 @@ class BaseFederationServlet:
return response
return new_func
return cast(ServletCallback, new_func)
def register(self, server):
def register(self, server: HttpServer) -> None:
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
for method in ("GET", "PUT", "POST"):