Add most of the missing type hints to `synapse.federation`. (#11483)
This skips a few methods which are difficult to type.
This commit is contained in:
parent
b50e39df57
commit
d2279f471b
|
@ -0,0 +1 @@
|
|||
Add missing type hints to `synapse.federation`.
|
6
mypy.ini
6
mypy.ini
|
@ -158,6 +158,12 @@ disallow_untyped_defs = True
|
|||
[mypy-synapse.events.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.federation.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.federation.transport.client]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-synapse.handlers.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
|
|
@ -128,7 +128,7 @@ class FederationClient(FederationBase):
|
|||
reset_expiry_on_get=False,
|
||||
)
|
||||
|
||||
def _clear_tried_cache(self):
|
||||
def _clear_tried_cache(self) -> None:
|
||||
"""Clear pdu_destination_tried cache"""
|
||||
now = self._clock.time_msec()
|
||||
|
||||
|
@ -800,7 +800,7 @@ class FederationClient(FederationBase):
|
|||
no servers successfully handle the request.
|
||||
"""
|
||||
|
||||
async def send_request(destination) -> SendJoinResult:
|
||||
async def send_request(destination: str) -> SendJoinResult:
|
||||
response = await self._do_send_join(room_version, destination, pdu)
|
||||
|
||||
# If an event was returned (and expected to be returned):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
# Copyright 2019 Matrix.org Federation C.I.C
|
||||
# Copyright 2019-2021 Matrix.org Federation C.I.C
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -450,7 +450,7 @@ class FederationServer(FederationBase):
|
|||
# require callouts to other servers to fetch missing events), but
|
||||
# impose a limit to avoid going too crazy with ram/cpu.
|
||||
|
||||
async def process_pdus_for_room(room_id: str):
|
||||
async def process_pdus_for_room(room_id: str) -> None:
|
||||
with nested_logging_context(room_id):
|
||||
logger.debug("Processing PDUs for %s", room_id)
|
||||
|
||||
|
@ -547,7 +547,7 @@ class FederationServer(FederationBase):
|
|||
|
||||
async def on_state_ids_request(
|
||||
self, origin: str, room_id: str, event_id: str
|
||||
) -> Tuple[int, Dict[str, Any]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if not event_id:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
|
@ -567,7 +567,9 @@ class FederationServer(FederationBase):
|
|||
|
||||
return 200, resp
|
||||
|
||||
async def _on_state_ids_request_compute(self, room_id, event_id):
|
||||
async def _on_state_ids_request_compute(
|
||||
self, room_id: str, event_id: str
|
||||
) -> JsonDict:
|
||||
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
||||
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
|
||||
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -23,6 +24,7 @@ from typing import Optional, Tuple
|
|||
|
||||
from synapse.federation.units import Transaction
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import JsonDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -31,7 +33,7 @@ logger = logging.getLogger(__name__)
|
|||
class TransactionActions:
|
||||
"""Defines persistence actions that relate to handling Transactions."""
|
||||
|
||||
def __init__(self, datastore):
|
||||
def __init__(self, datastore: DataStore):
|
||||
self.store = datastore
|
||||
|
||||
@log_function
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -350,7 +351,7 @@ class BaseFederationRow:
|
|||
TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
|
||||
|
||||
@staticmethod
|
||||
def from_data(data):
|
||||
def from_data(data: JsonDict) -> "BaseFederationRow":
|
||||
"""Parse the data from the federation stream into a row.
|
||||
|
||||
Args:
|
||||
|
@ -359,7 +360,7 @@ class BaseFederationRow:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def to_data(self):
|
||||
def to_data(self) -> JsonDict:
|
||||
"""Serialize this row to be sent over the federation stream.
|
||||
|
||||
Returns:
|
||||
|
@ -368,7 +369,7 @@ class BaseFederationRow:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def add_to_buffer(self, buff):
|
||||
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
|
||||
"""Add this row to the appropriate field in the buffer ready for this
|
||||
to be sent over federation.
|
||||
|
||||
|
@ -391,15 +392,15 @@ class PresenceDestinationsRow(
|
|||
TypeId = "pd"
|
||||
|
||||
@staticmethod
|
||||
def from_data(data):
|
||||
def from_data(data: JsonDict) -> "PresenceDestinationsRow":
|
||||
return PresenceDestinationsRow(
|
||||
state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
|
||||
)
|
||||
|
||||
def to_data(self):
|
||||
def to_data(self) -> JsonDict:
|
||||
return {"state": self.state.as_dict(), "dests": self.destinations}
|
||||
|
||||
def add_to_buffer(self, buff):
|
||||
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
|
||||
buff.presence_destinations.append((self.state, self.destinations))
|
||||
|
||||
|
||||
|
@ -417,13 +418,13 @@ class KeyedEduRow(
|
|||
TypeId = "k"
|
||||
|
||||
@staticmethod
|
||||
def from_data(data):
|
||||
def from_data(data: JsonDict) -> "KeyedEduRow":
|
||||
return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
|
||||
|
||||
def to_data(self):
|
||||
def to_data(self) -> JsonDict:
|
||||
return {"key": self.key, "edu": self.edu.get_internal_dict()}
|
||||
|
||||
def add_to_buffer(self, buff):
|
||||
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
|
||||
buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
|
||||
|
||||
|
||||
|
@ -433,13 +434,13 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
|
|||
TypeId = "e"
|
||||
|
||||
@staticmethod
|
||||
def from_data(data):
|
||||
def from_data(data: JsonDict) -> "EduRow":
|
||||
return EduRow(Edu(**data))
|
||||
|
||||
def to_data(self):
|
||||
def to_data(self) -> JsonDict:
|
||||
return self.edu.get_internal_dict()
|
||||
|
||||
def add_to_buffer(self, buff):
|
||||
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
|
||||
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,7 +15,8 @@
|
|||
# limitations under the License.
|
||||
import datetime
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
@ -213,7 +215,7 @@ class PerDestinationQueue:
|
|||
self._pending_edus_keyed[(edu.edu_type, key)] = edu
|
||||
self.attempt_new_transaction()
|
||||
|
||||
def send_edu(self, edu) -> None:
|
||||
def send_edu(self, edu: Edu) -> None:
|
||||
self._pending_edus.append(edu)
|
||||
self.attempt_new_transaction()
|
||||
|
||||
|
@ -701,7 +703,12 @@ class _TransactionQueueManager:
|
|||
|
||||
return self._pdus, pending_edus
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
if exc_type is not None:
|
||||
# Failed to send transaction, so we bail out.
|
||||
return
|
||||
|
|
|
@ -21,6 +21,7 @@ from typing import (
|
|||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
|
@ -235,11 +236,16 @@ class TransportLayerClient:
|
|||
|
||||
@log_function
|
||||
async def make_query(
|
||||
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
|
||||
):
|
||||
self,
|
||||
destination: str,
|
||||
query_type: str,
|
||||
args: dict,
|
||||
retry_on_dns_fail: bool,
|
||||
ignore_backoff: bool = False,
|
||||
) -> JsonDict:
|
||||
path = _create_v1_path("/query/%s", query_type)
|
||||
|
||||
content = await self.client.get_json(
|
||||
return await self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args=args,
|
||||
|
@ -248,8 +254,6 @@ class TransportLayerClient:
|
|||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
@log_function
|
||||
async def make_membership_event(
|
||||
self,
|
||||
|
@ -1317,7 +1321,7 @@ class SendJoinResponse:
|
|||
|
||||
|
||||
@ijson.coroutine
|
||||
def _event_parser(event_dict: JsonDict):
|
||||
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
|
||||
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
|
||||
to add them to a given dictionary.
|
||||
"""
|
||||
|
@ -1328,7 +1332,9 @@ def _event_parser(event_dict: JsonDict):
|
|||
|
||||
|
||||
@ijson.coroutine
|
||||
def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
|
||||
def _event_list_parser(
|
||||
room_version: RoomVersion, events: List[EventBase]
|
||||
) -> Generator[None, JsonDict, None]:
|
||||
"""Helper function for use with `ijson.items_coro` to parse an array of
|
||||
events and add them to the given list.
|
||||
"""
|
||||
|
|
|
@ -302,7 +302,7 @@ def register_servlets(
|
|||
authenticator: Authenticator,
|
||||
ratelimiter: FederationRateLimiter,
|
||||
servlet_groups: Optional[Iterable[str]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize and register servlet classes.
|
||||
|
||||
Will by default register all servlets. For custom behaviour, pass in
|
||||
|
|
|
@ -15,10 +15,13 @@
|
|||
import functools
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Awaitable, Callable, Optional, Tuple, cast
|
||||
|
||||
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
|
||||
from synapse.api.urls import FEDERATION_V1_PREFIX
|
||||
from synapse.http.server import HttpServer, ServletCallback
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging import opentracing
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.logging.opentracing import (
|
||||
|
@ -29,6 +32,7 @@ from synapse.logging.opentracing import (
|
|||
whitelisted_homeserver,
|
||||
)
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
|
||||
|
@ -59,9 +63,11 @@ class Authenticator:
|
|||
self.replication_client = hs.get_tcp_replication()
|
||||
|
||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||
async def authenticate_request(self, request, content):
|
||||
async def authenticate_request(
|
||||
self, request: SynapseRequest, content: Optional[JsonDict]
|
||||
) -> str:
|
||||
now = self._clock.time_msec()
|
||||
json_request = {
|
||||
json_request: JsonDict = {
|
||||
"method": request.method.decode("ascii"),
|
||||
"uri": request.uri.decode("ascii"),
|
||||
"destination": self.server_name,
|
||||
|
@ -114,7 +120,7 @@ class Authenticator:
|
|||
|
||||
return origin
|
||||
|
||||
async def _reset_retry_timings(self, origin):
|
||||
async def _reset_retry_timings(self, origin: str) -> None:
|
||||
try:
|
||||
logger.info("Marking origin %r as up", origin)
|
||||
await self.store.set_destination_retry_timings(origin, None, 0, 0)
|
||||
|
@ -133,14 +139,14 @@ class Authenticator:
|
|||
logger.exception("Error resetting retry timings on %s", origin)
|
||||
|
||||
|
||||
def _parse_auth_header(header_bytes):
|
||||
def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
|
||||
"""Parse an X-Matrix auth header
|
||||
|
||||
Args:
|
||||
header_bytes (bytes): header value
|
||||
header_bytes: header value
|
||||
|
||||
Returns:
|
||||
Tuple[str, str, str]: origin, key id, signature.
|
||||
origin, key id, signature.
|
||||
|
||||
Raises:
|
||||
AuthenticationError if the header could not be parsed
|
||||
|
@ -148,9 +154,9 @@ def _parse_auth_header(header_bytes):
|
|||
try:
|
||||
header_str = header_bytes.decode("utf-8")
|
||||
params = header_str.split(" ")[1].split(",")
|
||||
param_dict = dict(kv.split("=") for kv in params)
|
||||
param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}
|
||||
|
||||
def strip_quotes(value):
|
||||
def strip_quotes(value: str) -> str:
|
||||
if value.startswith('"'):
|
||||
return value[1:-1]
|
||||
else:
|
||||
|
@ -233,23 +239,25 @@ class BaseFederationServlet:
|
|||
self.ratelimiter = ratelimiter
|
||||
self.server_name = server_name
|
||||
|
||||
def _wrap(self, func):
|
||||
def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
|
||||
authenticator = self.authenticator
|
||||
ratelimiter = self.ratelimiter
|
||||
|
||||
@functools.wraps(func)
|
||||
async def new_func(request, *args, **kwargs):
|
||||
async def new_func(
|
||||
request: SynapseRequest, *args: Any, **kwargs: str
|
||||
) -> Optional[Tuple[int, Any]]:
|
||||
"""A callback which can be passed to HttpServer.RegisterPaths
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request):
|
||||
request:
|
||||
*args: unused?
|
||||
**kwargs (dict[unicode, unicode]): the dict mapping keys to path
|
||||
components as specified in the path match regexp.
|
||||
**kwargs: the dict mapping keys to path components as specified
|
||||
in the path match regexp.
|
||||
|
||||
Returns:
|
||||
Tuple[int, object]|None: (response code, response object) as returned by
|
||||
the callback method. None if the request has already been handled.
|
||||
(response code, response object) as returned by the callback method.
|
||||
None if the request has already been handled.
|
||||
"""
|
||||
content = None
|
||||
if request.method in [b"PUT", b"POST"]:
|
||||
|
@ -257,7 +265,9 @@ class BaseFederationServlet:
|
|||
content = parse_json_object_from_request(request)
|
||||
|
||||
try:
|
||||
origin = await authenticator.authenticate_request(request, content)
|
||||
origin: Optional[str] = await authenticator.authenticate_request(
|
||||
request, content
|
||||
)
|
||||
except NoAuthenticationError:
|
||||
origin = None
|
||||
if self.REQUIRE_AUTH:
|
||||
|
@ -301,7 +311,7 @@ class BaseFederationServlet:
|
|||
"client disconnected before we started processing "
|
||||
"request"
|
||||
)
|
||||
return -1, None
|
||||
return None
|
||||
response = await func(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
|
@ -312,9 +322,9 @@ class BaseFederationServlet:
|
|||
|
||||
return response
|
||||
|
||||
return new_func
|
||||
return cast(ServletCallback, new_func)
|
||||
|
||||
def register(self, server):
|
||||
def register(self, server: HttpServer) -> None:
|
||||
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
|
||||
|
||||
for method in ("GET", "PUT", "POST"):
|
||||
|
|
Loading…
Reference in New Issue