Add most of the missing type hints to `synapse.federation`. (#11483)
This skips a few methods which are difficult to type.
This commit is contained in:
parent
b50e39df57
commit
d2279f471b
|
@ -0,0 +1 @@
|
||||||
|
Add missing type hints to `synapse.federation`.
|
6
mypy.ini
6
mypy.ini
|
@ -158,6 +158,12 @@ disallow_untyped_defs = True
|
||||||
[mypy-synapse.events.*]
|
[mypy-synapse.events.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.federation.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.federation.transport.client]
|
||||||
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
[mypy-synapse.handlers.*]
|
[mypy-synapse.handlers.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -128,7 +128,7 @@ class FederationClient(FederationBase):
|
||||||
reset_expiry_on_get=False,
|
reset_expiry_on_get=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _clear_tried_cache(self):
|
def _clear_tried_cache(self) -> None:
|
||||||
"""Clear pdu_destination_tried cache"""
|
"""Clear pdu_destination_tried cache"""
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
|
|
||||||
|
@ -800,7 +800,7 @@ class FederationClient(FederationBase):
|
||||||
no servers successfully handle the request.
|
no servers successfully handle the request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def send_request(destination) -> SendJoinResult:
|
async def send_request(destination: str) -> SendJoinResult:
|
||||||
response = await self._do_send_join(room_version, destination, pdu)
|
response = await self._do_send_join(room_version, destination, pdu)
|
||||||
|
|
||||||
# If an event was returned (and expected to be returned):
|
# If an event was returned (and expected to be returned):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
# Copyright 2018 New Vector Ltd
|
# Copyright 2018 New Vector Ltd
|
||||||
# Copyright 2019 Matrix.org Federation C.I.C
|
# Copyright 2019-2021 Matrix.org Federation C.I.C
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -450,7 +450,7 @@ class FederationServer(FederationBase):
|
||||||
# require callouts to other servers to fetch missing events), but
|
# require callouts to other servers to fetch missing events), but
|
||||||
# impose a limit to avoid going too crazy with ram/cpu.
|
# impose a limit to avoid going too crazy with ram/cpu.
|
||||||
|
|
||||||
async def process_pdus_for_room(room_id: str):
|
async def process_pdus_for_room(room_id: str) -> None:
|
||||||
with nested_logging_context(room_id):
|
with nested_logging_context(room_id):
|
||||||
logger.debug("Processing PDUs for %s", room_id)
|
logger.debug("Processing PDUs for %s", room_id)
|
||||||
|
|
||||||
|
@ -547,7 +547,7 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
async def on_state_ids_request(
|
async def on_state_ids_request(
|
||||||
self, origin: str, room_id: str, event_id: str
|
self, origin: str, room_id: str, event_id: str
|
||||||
) -> Tuple[int, Dict[str, Any]]:
|
) -> Tuple[int, JsonDict]:
|
||||||
if not event_id:
|
if not event_id:
|
||||||
raise NotImplementedError("Specify an event")
|
raise NotImplementedError("Specify an event")
|
||||||
|
|
||||||
|
@ -567,7 +567,9 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
return 200, resp
|
return 200, resp
|
||||||
|
|
||||||
async def _on_state_ids_request_compute(self, room_id, event_id):
|
async def _on_state_ids_request_compute(
|
||||||
|
self, room_id: str, event_id: str
|
||||||
|
) -> JsonDict:
|
||||||
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
||||||
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
|
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
|
||||||
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -23,6 +24,7 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
from synapse.federation.units import Transaction
|
from synapse.federation.units import Transaction
|
||||||
from synapse.logging.utils import log_function
|
from synapse.logging.utils import log_function
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -31,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||||
class TransactionActions:
|
class TransactionActions:
|
||||||
"""Defines persistence actions that relate to handling Transactions."""
|
"""Defines persistence actions that relate to handling Transactions."""
|
||||||
|
|
||||||
def __init__(self, datastore):
|
def __init__(self, datastore: DataStore):
|
||||||
self.store = datastore
|
self.store = datastore
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -350,7 +351,7 @@ class BaseFederationRow:
|
||||||
TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
|
TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_data(data):
|
def from_data(data: JsonDict) -> "BaseFederationRow":
|
||||||
"""Parse the data from the federation stream into a row.
|
"""Parse the data from the federation stream into a row.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -359,7 +360,7 @@ class BaseFederationRow:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def to_data(self):
|
def to_data(self) -> JsonDict:
|
||||||
"""Serialize this row to be sent over the federation stream.
|
"""Serialize this row to be sent over the federation stream.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -368,7 +369,7 @@ class BaseFederationRow:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def add_to_buffer(self, buff):
|
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
|
||||||
"""Add this row to the appropriate field in the buffer ready for this
|
"""Add this row to the appropriate field in the buffer ready for this
|
||||||
to be sent over federation.
|
to be sent over federation.
|
||||||
|
|
||||||
|
@ -391,15 +392,15 @@ class PresenceDestinationsRow(
|
||||||
TypeId = "pd"
|
TypeId = "pd"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_data(data):
|
def from_data(data: JsonDict) -> "PresenceDestinationsRow":
|
||||||
return PresenceDestinationsRow(
|
return PresenceDestinationsRow(
|
||||||
state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
|
state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_data(self):
|
def to_data(self) -> JsonDict:
|
||||||
return {"state": self.state.as_dict(), "dests": self.destinations}
|
return {"state": self.state.as_dict(), "dests": self.destinations}
|
||||||
|
|
||||||
def add_to_buffer(self, buff):
|
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
|
||||||
buff.presence_destinations.append((self.state, self.destinations))
|
buff.presence_destinations.append((self.state, self.destinations))
|
||||||
|
|
||||||
|
|
||||||
|
@ -417,13 +418,13 @@ class KeyedEduRow(
|
||||||
TypeId = "k"
|
TypeId = "k"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_data(data):
|
def from_data(data: JsonDict) -> "KeyedEduRow":
|
||||||
return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
|
return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
|
||||||
|
|
||||||
def to_data(self):
|
def to_data(self) -> JsonDict:
|
||||||
return {"key": self.key, "edu": self.edu.get_internal_dict()}
|
return {"key": self.key, "edu": self.edu.get_internal_dict()}
|
||||||
|
|
||||||
def add_to_buffer(self, buff):
|
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
|
||||||
buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
|
buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
|
||||||
|
|
||||||
|
|
||||||
|
@ -433,13 +434,13 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
|
||||||
TypeId = "e"
|
TypeId = "e"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_data(data):
|
def from_data(data: JsonDict) -> "EduRow":
|
||||||
return EduRow(Edu(**data))
|
return EduRow(Edu(**data))
|
||||||
|
|
||||||
def to_data(self):
|
def to_data(self) -> JsonDict:
|
||||||
return self.edu.get_internal_dict()
|
return self.edu.get_internal_dict()
|
||||||
|
|
||||||
def add_to_buffer(self, buff):
|
def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
|
||||||
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
|
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
# Copyright 2019 New Vector Ltd
|
# Copyright 2019 New Vector Ltd
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -14,7 +15,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
|
from types import TracebackType
|
||||||
|
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
@ -213,7 +215,7 @@ class PerDestinationQueue:
|
||||||
self._pending_edus_keyed[(edu.edu_type, key)] = edu
|
self._pending_edus_keyed[(edu.edu_type, key)] = edu
|
||||||
self.attempt_new_transaction()
|
self.attempt_new_transaction()
|
||||||
|
|
||||||
def send_edu(self, edu) -> None:
|
def send_edu(self, edu: Edu) -> None:
|
||||||
self._pending_edus.append(edu)
|
self._pending_edus.append(edu)
|
||||||
self.attempt_new_transaction()
|
self.attempt_new_transaction()
|
||||||
|
|
||||||
|
@ -701,7 +703,12 @@ class _TransactionQueueManager:
|
||||||
|
|
||||||
return self._pdus, pending_edus
|
return self._pdus, pending_edus
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc: Optional[BaseException],
|
||||||
|
tb: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
if exc_type is not None:
|
if exc_type is not None:
|
||||||
# Failed to send transaction, so we bail out.
|
# Failed to send transaction, so we bail out.
|
||||||
return
|
return
|
||||||
|
|
|
@ -21,6 +21,7 @@ from typing import (
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
|
@ -235,11 +236,16 @@ class TransportLayerClient:
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
async def make_query(
|
async def make_query(
|
||||||
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
|
self,
|
||||||
):
|
destination: str,
|
||||||
|
query_type: str,
|
||||||
|
args: dict,
|
||||||
|
retry_on_dns_fail: bool,
|
||||||
|
ignore_backoff: bool = False,
|
||||||
|
) -> JsonDict:
|
||||||
path = _create_v1_path("/query/%s", query_type)
|
path = _create_v1_path("/query/%s", query_type)
|
||||||
|
|
||||||
content = await self.client.get_json(
|
return await self.client.get_json(
|
||||||
destination=destination,
|
destination=destination,
|
||||||
path=path,
|
path=path,
|
||||||
args=args,
|
args=args,
|
||||||
|
@ -248,8 +254,6 @@ class TransportLayerClient:
|
||||||
ignore_backoff=ignore_backoff,
|
ignore_backoff=ignore_backoff,
|
||||||
)
|
)
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
async def make_membership_event(
|
async def make_membership_event(
|
||||||
self,
|
self,
|
||||||
|
@ -1317,7 +1321,7 @@ class SendJoinResponse:
|
||||||
|
|
||||||
|
|
||||||
@ijson.coroutine
|
@ijson.coroutine
|
||||||
def _event_parser(event_dict: JsonDict):
|
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
|
||||||
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
|
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
|
||||||
to add them to a given dictionary.
|
to add them to a given dictionary.
|
||||||
"""
|
"""
|
||||||
|
@ -1328,7 +1332,9 @@ def _event_parser(event_dict: JsonDict):
|
||||||
|
|
||||||
|
|
||||||
@ijson.coroutine
|
@ijson.coroutine
|
||||||
def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
|
def _event_list_parser(
|
||||||
|
room_version: RoomVersion, events: List[EventBase]
|
||||||
|
) -> Generator[None, JsonDict, None]:
|
||||||
"""Helper function for use with `ijson.items_coro` to parse an array of
|
"""Helper function for use with `ijson.items_coro` to parse an array of
|
||||||
events and add them to the given list.
|
events and add them to the given list.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -302,7 +302,7 @@ def register_servlets(
|
||||||
authenticator: Authenticator,
|
authenticator: Authenticator,
|
||||||
ratelimiter: FederationRateLimiter,
|
ratelimiter: FederationRateLimiter,
|
||||||
servlet_groups: Optional[Iterable[str]] = None,
|
servlet_groups: Optional[Iterable[str]] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Initialize and register servlet classes.
|
"""Initialize and register servlet classes.
|
||||||
|
|
||||||
Will by default register all servlets. For custom behaviour, pass in
|
Will by default register all servlets. For custom behaviour, pass in
|
||||||
|
|
|
@ -15,10 +15,13 @@
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from typing import Any, Awaitable, Callable, Optional, Tuple, cast
|
||||||
|
|
||||||
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
|
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
|
||||||
from synapse.api.urls import FEDERATION_V1_PREFIX
|
from synapse.api.urls import FEDERATION_V1_PREFIX
|
||||||
|
from synapse.http.server import HttpServer, ServletCallback
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
from synapse.logging.context import run_in_background
|
from synapse.logging.context import run_in_background
|
||||||
from synapse.logging.opentracing import (
|
from synapse.logging.opentracing import (
|
||||||
|
@ -29,6 +32,7 @@ from synapse.logging.opentracing import (
|
||||||
whitelisted_homeserver,
|
whitelisted_homeserver,
|
||||||
)
|
)
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
from synapse.util.stringutils import parse_and_validate_server_name
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
|
|
||||||
|
@ -59,9 +63,11 @@ class Authenticator:
|
||||||
self.replication_client = hs.get_tcp_replication()
|
self.replication_client = hs.get_tcp_replication()
|
||||||
|
|
||||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||||
async def authenticate_request(self, request, content):
|
async def authenticate_request(
|
||||||
|
self, request: SynapseRequest, content: Optional[JsonDict]
|
||||||
|
) -> str:
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
json_request = {
|
json_request: JsonDict = {
|
||||||
"method": request.method.decode("ascii"),
|
"method": request.method.decode("ascii"),
|
||||||
"uri": request.uri.decode("ascii"),
|
"uri": request.uri.decode("ascii"),
|
||||||
"destination": self.server_name,
|
"destination": self.server_name,
|
||||||
|
@ -114,7 +120,7 @@ class Authenticator:
|
||||||
|
|
||||||
return origin
|
return origin
|
||||||
|
|
||||||
async def _reset_retry_timings(self, origin):
|
async def _reset_retry_timings(self, origin: str) -> None:
|
||||||
try:
|
try:
|
||||||
logger.info("Marking origin %r as up", origin)
|
logger.info("Marking origin %r as up", origin)
|
||||||
await self.store.set_destination_retry_timings(origin, None, 0, 0)
|
await self.store.set_destination_retry_timings(origin, None, 0, 0)
|
||||||
|
@ -133,14 +139,14 @@ class Authenticator:
|
||||||
logger.exception("Error resetting retry timings on %s", origin)
|
logger.exception("Error resetting retry timings on %s", origin)
|
||||||
|
|
||||||
|
|
||||||
def _parse_auth_header(header_bytes):
|
def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
|
||||||
"""Parse an X-Matrix auth header
|
"""Parse an X-Matrix auth header
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
header_bytes (bytes): header value
|
header_bytes: header value
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, str, str]: origin, key id, signature.
|
origin, key id, signature.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthenticationError if the header could not be parsed
|
AuthenticationError if the header could not be parsed
|
||||||
|
@ -148,9 +154,9 @@ def _parse_auth_header(header_bytes):
|
||||||
try:
|
try:
|
||||||
header_str = header_bytes.decode("utf-8")
|
header_str = header_bytes.decode("utf-8")
|
||||||
params = header_str.split(" ")[1].split(",")
|
params = header_str.split(" ")[1].split(",")
|
||||||
param_dict = dict(kv.split("=") for kv in params)
|
param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}
|
||||||
|
|
||||||
def strip_quotes(value):
|
def strip_quotes(value: str) -> str:
|
||||||
if value.startswith('"'):
|
if value.startswith('"'):
|
||||||
return value[1:-1]
|
return value[1:-1]
|
||||||
else:
|
else:
|
||||||
|
@ -233,23 +239,25 @@ class BaseFederationServlet:
|
||||||
self.ratelimiter = ratelimiter
|
self.ratelimiter = ratelimiter
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
|
|
||||||
def _wrap(self, func):
|
def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
|
||||||
authenticator = self.authenticator
|
authenticator = self.authenticator
|
||||||
ratelimiter = self.ratelimiter
|
ratelimiter = self.ratelimiter
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def new_func(request, *args, **kwargs):
|
async def new_func(
|
||||||
|
request: SynapseRequest, *args: Any, **kwargs: str
|
||||||
|
) -> Optional[Tuple[int, Any]]:
|
||||||
"""A callback which can be passed to HttpServer.RegisterPaths
|
"""A callback which can be passed to HttpServer.RegisterPaths
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request (twisted.web.http.Request):
|
request:
|
||||||
*args: unused?
|
*args: unused?
|
||||||
**kwargs (dict[unicode, unicode]): the dict mapping keys to path
|
**kwargs: the dict mapping keys to path components as specified
|
||||||
components as specified in the path match regexp.
|
in the path match regexp.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[int, object]|None: (response code, response object) as returned by
|
(response code, response object) as returned by the callback method.
|
||||||
the callback method. None if the request has already been handled.
|
None if the request has already been handled.
|
||||||
"""
|
"""
|
||||||
content = None
|
content = None
|
||||||
if request.method in [b"PUT", b"POST"]:
|
if request.method in [b"PUT", b"POST"]:
|
||||||
|
@ -257,7 +265,9 @@ class BaseFederationServlet:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
origin = await authenticator.authenticate_request(request, content)
|
origin: Optional[str] = await authenticator.authenticate_request(
|
||||||
|
request, content
|
||||||
|
)
|
||||||
except NoAuthenticationError:
|
except NoAuthenticationError:
|
||||||
origin = None
|
origin = None
|
||||||
if self.REQUIRE_AUTH:
|
if self.REQUIRE_AUTH:
|
||||||
|
@ -301,7 +311,7 @@ class BaseFederationServlet:
|
||||||
"client disconnected before we started processing "
|
"client disconnected before we started processing "
|
||||||
"request"
|
"request"
|
||||||
)
|
)
|
||||||
return -1, None
|
return None
|
||||||
response = await func(
|
response = await func(
|
||||||
origin, content, request.args, *args, **kwargs
|
origin, content, request.args, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
@ -312,9 +322,9 @@ class BaseFederationServlet:
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
return new_func
|
return cast(ServletCallback, new_func)
|
||||||
|
|
||||||
def register(self, server):
|
def register(self, server: HttpServer) -> None:
|
||||||
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
|
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
|
||||||
|
|
||||||
for method in ("GET", "PUT", "POST"):
|
for method in ("GET", "PUT", "POST"):
|
||||||
|
|
Loading…
Reference in New Issue