Add a primitive helper script for listing worker endpoints. (#15243)

Co-authored-by: Patrick Cloke <patrickc@matrix.org>
This commit is contained in:
reivilibre 2023-03-23 12:11:14 +00:00 committed by GitHub
parent 3b0083c92a
commit 98fd558382
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 424 additions and 12 deletions

View File

@ -0,0 +1 @@
Add a primitive helper script for listing worker endpoints.

View File

@ -0,0 +1,302 @@
#!/usr/bin/env python
# Copyright 2022-2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Pattern, Set, Tuple
import yaml
from synapse.config.homeserver import HomeServerConfig
from synapse.federation.transport.server import (
TransportLayerServer,
register_servlets as register_federation_servlets,
)
from synapse.http.server import HttpServer, ServletCallback
from synapse.rest import ClientRestResource
from synapse.rest.key.v2 import RemoteKey
from synapse.server import HomeServer
from synapse.storage import DataStore
logger = logging.getLogger("generate_workers_map")
class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore # type: ignore
def __init__(self, config: HomeServerConfig, worker_app: Optional[str]) -> None:
super().__init__(config.server.server_name, config=config)
self.config.worker.worker_app = worker_app
GROUP_PATTERN = re.compile(r"\(\?P<[^>]+?>(.+?)\)")
@dataclass
class EndpointDescription:
"""
Describes an endpoint and how it should be routed.
"""
# The servlet class that handles this endpoint
servlet_class: object
# The category of this endpoint. Is read from the `CATEGORY` constant in the servlet
# class.
category: Optional[str]
# TODO:
# - does it need to be routed based on a stream writer config?
# - does it benefit from any optimised, but optional, routing?
# - what 'opinionated synapse worker class' (event_creator, synchrotron, etc) does
# it go in?
class EnumerationResource(HttpServer):
"""
Accepts servlet registrations for the purposes of building up a description of
all endpoints.
"""
def __init__(self, is_worker: bool) -> None:
self.registrations: Dict[Tuple[str, str], EndpointDescription] = {}
self._is_worker = is_worker
def register_paths(
self,
method: str,
path_patterns: Iterable[Pattern],
callback: ServletCallback,
servlet_classname: str,
) -> None:
# federation servlet callbacks are wrapped, so unwrap them.
callback = getattr(callback, "__wrapped__", callback)
# fish out the servlet class
servlet_class = callback.__self__.__class__ # type: ignore
if self._is_worker and method in getattr(
servlet_class, "WORKERS_DENIED_METHODS", ()
):
# This endpoint would cause an error if called on a worker, so pretend it
# was never registered!
return
sd = EndpointDescription(
servlet_class=servlet_class,
category=getattr(servlet_class, "CATEGORY", None),
)
for pat in path_patterns:
self.registrations[(method, pat.pattern)] = sd
def get_registered_paths_for_hs(
hs: HomeServer,
) -> Dict[Tuple[str, str], EndpointDescription]:
"""
Given a homeserver, get all registered endpoints and their descriptions.
"""
enumerator = EnumerationResource(is_worker=hs.config.worker.worker_app is not None)
ClientRestResource.register_servlets(enumerator, hs)
federation_server = TransportLayerServer(hs)
# we can't use `federation_server.register_servlets` but this line does the
# same thing, only it uses this enumerator
register_federation_servlets(
federation_server.hs,
resource=enumerator,
ratelimiter=federation_server.ratelimiter,
authenticator=federation_server.authenticator,
servlet_groups=federation_server.servlet_groups,
)
# the key server endpoints are separate again
RemoteKey(hs).register(enumerator)
return enumerator.registrations
def get_registered_paths_for_default(
worker_app: Optional[str], base_config: HomeServerConfig
) -> Dict[Tuple[str, str], EndpointDescription]:
"""
Given the name of a worker application and a base homeserver configuration,
returns:
Dict from (method, path) to EndpointDescription
TODO Don't require passing in a config
"""
hs = MockHomeserver(base_config, worker_app)
# TODO We only do this to avoid an error, but don't need the database etc
hs.setup()
return get_registered_paths_for_hs(hs)
def elide_http_methods_if_unconflicting(
registrations: Dict[Tuple[str, str], EndpointDescription],
all_possible_registrations: Dict[Tuple[str, str], EndpointDescription],
) -> Dict[Tuple[str, str], EndpointDescription]:
"""
Elides HTTP methods (by replacing them with `*`) if all possible registered methods
can be handled by the worker whose registration map is `registrations`.
i.e. the only endpoints left with methods (other than `*`) should be the ones where
the worker can't handle all possible methods for that path.
"""
def paths_to_methods_dict(
methods_and_paths: Iterable[Tuple[str, str]]
) -> Dict[str, Set[str]]:
"""
Given (method, path) pairs, produces a dict from path to set of methods
available at that path.
"""
result: Dict[str, Set[str]] = {}
for method, path in methods_and_paths:
result.setdefault(path, set()).add(method)
return result
all_possible_reg_methods = paths_to_methods_dict(all_possible_registrations)
reg_methods = paths_to_methods_dict(registrations)
output = {}
for path, handleable_methods in reg_methods.items():
if handleable_methods == all_possible_reg_methods[path]:
any_method = next(iter(handleable_methods))
# TODO This assumes that all methods have the same servlet.
# I suppose that's possibly dubious?
output[("*", path)] = registrations[(any_method, path)]
else:
for method in handleable_methods:
output[(method, path)] = registrations[(method, path)]
return output
def simplify_path_regexes(
registrations: Dict[Tuple[str, str], EndpointDescription]
) -> Dict[Tuple[str, str], EndpointDescription]:
"""
Simplify all the path regexes for the dict of endpoint descriptions,
so that we don't use the Python-specific regex extensions
(and also to remove needlessly specific detail).
"""
def simplify_path_regex(path: str) -> str:
"""
Given a regex pattern, replaces all named capturing groups (e.g. `(?P<blah>xyz)`)
with a simpler version available in more common regex dialects (e.g. `.*`).
"""
# TODO it's hard to choose between these two;
# `.*` is a vague simplification
# return GROUP_PATTERN.sub(r"\1", path)
return GROUP_PATTERN.sub(r".*", path)
return {(m, simplify_path_regex(p)): v for (m, p), v in registrations.items()}
def main() -> None:
parser = argparse.ArgumentParser(
description=(
"Updates a synapse database to the latest schema and optionally runs background updates"
" on it."
)
)
parser.add_argument("-v", action="store_true")
parser.add_argument(
"--config-path",
type=argparse.FileType("r"),
required=True,
help="Synapse configuration file",
)
args = parser.parse_args()
# TODO
# logging.basicConfig(**logging_config)
# Load, process and sanity-check the config.
hs_config = yaml.safe_load(args.config_path)
config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "")
master_paths = get_registered_paths_for_default(None, config)
worker_paths = get_registered_paths_for_default(
"synapse.app.generic_worker", config
)
all_paths = {**master_paths, **worker_paths}
elided_worker_paths = elide_http_methods_if_unconflicting(worker_paths, all_paths)
elide_http_methods_if_unconflicting(master_paths, all_paths)
# TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT
categories_to_methods_and_paths: Dict[
Optional[str], Dict[Tuple[str, str], EndpointDescription]
] = defaultdict(dict)
for (method, path), desc in elided_worker_paths.items():
categories_to_methods_and_paths[desc.category][method, path] = desc
for category, contents in categories_to_methods_and_paths.items():
print_category(category, contents)
def print_category(
category_name: Optional[str],
elided_worker_paths: Dict[Tuple[str, str], EndpointDescription],
) -> None:
"""
Prints out a category, in documentation page style.
Example:
```
# Category name
/path/xyz
GET /path/abc
```
"""
if category_name:
print(f"# {category_name}")
else:
print("# (Uncategorised requests)")
for ln in sorted(
p for m, p in simplify_path_regexes(elided_worker_paths) if m == "*"
):
print(ln)
print()
for ln in sorted(
f"{m:6} {p}" for m, p in simplify_path_regexes(elided_worker_paths) if m != "*"
):
print(ln)
print()
if __name__ == "__main__":
main()

View File

@ -108,6 +108,7 @@ class PublicRoomList(BaseFederationServlet):
""" """
PATH = "/publicRooms" PATH = "/publicRooms"
CATEGORY = "Federation requests"
def __init__( def __init__(
self, self,
@ -212,6 +213,7 @@ class OpenIdUserInfo(BaseFederationServlet):
""" """
PATH = "/openid/userinfo" PATH = "/openid/userinfo"
CATEGORY = "Federation requests"
REQUIRE_AUTH = False REQUIRE_AUTH = False

View File

@ -70,6 +70,7 @@ class BaseFederationServerServlet(BaseFederationServlet):
class FederationSendServlet(BaseFederationServerServlet): class FederationSendServlet(BaseFederationServerServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/?" PATH = "/send/(?P<transaction_id>[^/]*)/?"
CATEGORY = "Inbound federation transaction request"
# We ratelimit manually in the handler as we queue up the requests and we # We ratelimit manually in the handler as we queue up the requests and we
# don't want to fill up the ratelimiter with blocked requests. # don't want to fill up the ratelimiter with blocked requests.
@ -138,6 +139,7 @@ class FederationSendServlet(BaseFederationServerServlet):
class FederationEventServlet(BaseFederationServerServlet): class FederationEventServlet(BaseFederationServerServlet):
PATH = "/event/(?P<event_id>[^/]*)/?" PATH = "/event/(?P<event_id>[^/]*)/?"
CATEGORY = "Federation requests"
# This is when someone asks for a data item for a given server data_id pair. # This is when someone asks for a data item for a given server data_id pair.
async def on_GET( async def on_GET(
@ -152,6 +154,7 @@ class FederationEventServlet(BaseFederationServerServlet):
class FederationStateV1Servlet(BaseFederationServerServlet): class FederationStateV1Servlet(BaseFederationServerServlet):
PATH = "/state/(?P<room_id>[^/]*)/?" PATH = "/state/(?P<room_id>[^/]*)/?"
CATEGORY = "Federation requests"
# This is when someone asks for all data for a given room. # This is when someone asks for all data for a given room.
async def on_GET( async def on_GET(
@ -170,6 +173,7 @@ class FederationStateV1Servlet(BaseFederationServerServlet):
class FederationStateIdsServlet(BaseFederationServerServlet): class FederationStateIdsServlet(BaseFederationServerServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?" PATH = "/state_ids/(?P<room_id>[^/]*)/?"
CATEGORY = "Federation requests"
async def on_GET( async def on_GET(
self, self,
@ -187,6 +191,7 @@ class FederationStateIdsServlet(BaseFederationServerServlet):
class FederationBackfillServlet(BaseFederationServerServlet): class FederationBackfillServlet(BaseFederationServerServlet):
PATH = "/backfill/(?P<room_id>[^/]*)/?" PATH = "/backfill/(?P<room_id>[^/]*)/?"
CATEGORY = "Federation requests"
async def on_GET( async def on_GET(
self, self,
@ -225,6 +230,7 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
""" """
PATH = "/timestamp_to_event/(?P<room_id>[^/]*)/?" PATH = "/timestamp_to_event/(?P<room_id>[^/]*)/?"
CATEGORY = "Federation requests"
async def on_GET( async def on_GET(
self, self,
@ -246,6 +252,7 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
class FederationQueryServlet(BaseFederationServerServlet): class FederationQueryServlet(BaseFederationServerServlet):
PATH = "/query/(?P<query_type>[^/]*)" PATH = "/query/(?P<query_type>[^/]*)"
CATEGORY = "Federation requests"
# This is when we receive a server-server Query # This is when we receive a server-server Query
async def on_GET( async def on_GET(
@ -262,6 +269,7 @@ class FederationQueryServlet(BaseFederationServerServlet):
class FederationMakeJoinServlet(BaseFederationServerServlet): class FederationMakeJoinServlet(BaseFederationServerServlet):
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET( async def on_GET(
self, self,
@ -297,6 +305,7 @@ class FederationMakeJoinServlet(BaseFederationServerServlet):
class FederationMakeLeaveServlet(BaseFederationServerServlet): class FederationMakeLeaveServlet(BaseFederationServerServlet):
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET( async def on_GET(
self, self,
@ -312,6 +321,7 @@ class FederationMakeLeaveServlet(BaseFederationServerServlet):
class FederationV1SendLeaveServlet(BaseFederationServerServlet): class FederationV1SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT( async def on_PUT(
self, self,
@ -327,6 +337,7 @@ class FederationV1SendLeaveServlet(BaseFederationServerServlet):
class FederationV2SendLeaveServlet(BaseFederationServerServlet): class FederationV2SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
PREFIX = FEDERATION_V2_PREFIX PREFIX = FEDERATION_V2_PREFIX
@ -344,6 +355,7 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet):
class FederationMakeKnockServlet(BaseFederationServerServlet): class FederationMakeKnockServlet(BaseFederationServerServlet):
PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET( async def on_GET(
self, self,
@ -366,6 +378,7 @@ class FederationMakeKnockServlet(BaseFederationServerServlet):
class FederationV1SendKnockServlet(BaseFederationServerServlet): class FederationV1SendKnockServlet(BaseFederationServerServlet):
PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT( async def on_PUT(
self, self,
@ -381,6 +394,7 @@ class FederationV1SendKnockServlet(BaseFederationServerServlet):
class FederationEventAuthServlet(BaseFederationServerServlet): class FederationEventAuthServlet(BaseFederationServerServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET( async def on_GET(
self, self,
@ -395,6 +409,7 @@ class FederationEventAuthServlet(BaseFederationServerServlet):
class FederationV1SendJoinServlet(BaseFederationServerServlet): class FederationV1SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT( async def on_PUT(
self, self,
@ -412,6 +427,7 @@ class FederationV1SendJoinServlet(BaseFederationServerServlet):
class FederationV2SendJoinServlet(BaseFederationServerServlet): class FederationV2SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
PREFIX = FEDERATION_V2_PREFIX PREFIX = FEDERATION_V2_PREFIX
@ -455,6 +471,7 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
class FederationV1InviteServlet(BaseFederationServerServlet): class FederationV1InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT( async def on_PUT(
self, self,
@ -479,6 +496,7 @@ class FederationV1InviteServlet(BaseFederationServerServlet):
class FederationV2InviteServlet(BaseFederationServerServlet): class FederationV2InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
PREFIX = FEDERATION_V2_PREFIX PREFIX = FEDERATION_V2_PREFIX
@ -515,6 +533,7 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT( async def on_PUT(
self, self,
@ -529,6 +548,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
class FederationClientKeysQueryServlet(BaseFederationServerServlet): class FederationClientKeysQueryServlet(BaseFederationServerServlet):
PATH = "/user/keys/query" PATH = "/user/keys/query"
CATEGORY = "Federation requests"
async def on_POST( async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
@ -538,6 +558,7 @@ class FederationClientKeysQueryServlet(BaseFederationServerServlet):
class FederationUserDevicesQueryServlet(BaseFederationServerServlet): class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)" PATH = "/user/devices/(?P<user_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET( async def on_GET(
self, self,
@ -551,6 +572,7 @@ class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
class FederationClientKeysClaimServlet(BaseFederationServerServlet): class FederationClientKeysClaimServlet(BaseFederationServerServlet):
PATH = "/user/keys/claim" PATH = "/user/keys/claim"
CATEGORY = "Federation requests"
async def on_POST( async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
@ -561,6 +583,7 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
class FederationGetMissingEventsServlet(BaseFederationServerServlet): class FederationGetMissingEventsServlet(BaseFederationServerServlet):
PATH = "/get_missing_events/(?P<room_id>[^/]*)" PATH = "/get_missing_events/(?P<room_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_POST( async def on_POST(
self, self,
@ -586,6 +609,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
class On3pidBindServlet(BaseFederationServerServlet): class On3pidBindServlet(BaseFederationServerServlet):
PATH = "/3pid/onbind" PATH = "/3pid/onbind"
CATEGORY = "Federation requests"
REQUIRE_AUTH = False REQUIRE_AUTH = False
@ -618,6 +642,7 @@ class On3pidBindServlet(BaseFederationServerServlet):
class FederationVersionServlet(BaseFederationServlet): class FederationVersionServlet(BaseFederationServlet):
PATH = "/version" PATH = "/version"
CATEGORY = "Federation requests"
REQUIRE_AUTH = False REQUIRE_AUTH = False
@ -640,6 +665,7 @@ class FederationVersionServlet(BaseFederationServlet):
class FederationRoomHierarchyServlet(BaseFederationServlet): class FederationRoomHierarchyServlet(BaseFederationServlet):
PATH = "/hierarchy/(?P<room_id>[^/]*)" PATH = "/hierarchy/(?P<room_id>[^/]*)"
CATEGORY = "Federation requests"
def __init__( def __init__(
self, self,
@ -672,6 +698,7 @@ class RoomComplexityServlet(BaseFederationServlet):
PATH = "/rooms/(?P<room_id>[^/]*)/complexity" PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX PREFIX = FEDERATION_UNSTABLE_PREFIX
CATEGORY = "Federation requests (unstable)"
def __init__( def __init__(
self, self,

View File

@ -43,19 +43,22 @@ def client_patterns(
Returns: Returns:
An iterable of patterns. An iterable of patterns.
""" """
patterns = [] versions = []
if unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
if v1: if v1:
v1_prefix = CLIENT_API_PREFIX + "/api/v1" versions.append("api/v1")
patterns.append(re.compile("^" + v1_prefix + path_regex)) versions.extend(releases)
for release in releases: if unstable:
new_prefix = CLIENT_API_PREFIX + f"/{release}" versions.append("unstable")
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns if len(versions) == 1:
versions_str = versions[0]
elif len(versions) > 1:
versions_str = "(" + "|".join(versions) + ")"
else:
raise RuntimeError("Must have at least one version for a URL")
return [re.compile("^" + CLIENT_API_PREFIX + "/" + versions_str + path_regex)]
def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None: def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None:

View File

@ -576,6 +576,9 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
class ThreepidRestServlet(RestServlet): class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$") PATTERNS = client_patterns("/account/3pid$")
# This is used as a proxy for all the 3pid endpoints.
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -834,6 +837,7 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
class WhoamiRestServlet(RestServlet): class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$") PATTERNS = client_patterns("/account/whoami$")
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -38,6 +38,7 @@ class AccountDataServlet(RestServlet):
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
) )
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -136,6 +137,7 @@ class RoomAccountDataServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)" "/account_data/(?P<account_data_type>[^/]*)"
) )
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -40,6 +40,7 @@ logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet): class DevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/devices$") PATTERNS = client_patterns("/devices$")
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -123,6 +124,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet): class DeviceRestServlet(RestServlet):
PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$") PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
class EventStreamRestServlet(RestServlet): class EventStreamRestServlet(RestServlet):
PATTERNS = client_patterns("/events$", v1=True) PATTERNS = client_patterns("/events$", v1=True)
CATEGORY = "Sync requests"
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
@ -76,6 +77,7 @@ class EventStreamRestServlet(RestServlet):
class EventRestServlet(RestServlet): class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True) PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet): class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -69,6 +70,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -28,6 +28,7 @@ if TYPE_CHECKING:
# TODO: Needs unit testing # TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet): class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True) PATTERNS = client_patterns("/initialSync$", v1=True)
CATEGORY = "Sync requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -89,6 +89,7 @@ class KeyUploadServlet(RestServlet):
""" """
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -182,6 +183,7 @@ class KeyQueryServlet(RestServlet):
""" """
PATTERNS = client_patterns("/keys/query$") PATTERNS = client_patterns("/keys/query$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -225,6 +227,7 @@ class KeyChangesServlet(RestServlet):
""" """
PATTERNS = client_patterns("/keys/changes$") PATTERNS = client_patterns("/keys/changes$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -274,6 +277,7 @@ class OneTimeKeyServlet(RestServlet):
""" """
PATTERNS = client_patterns("/keys/claim$") PATTERNS = client_patterns("/keys/claim$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -40,6 +40,7 @@ class KnockRoomAliasServlet(RestServlet):
""" """
PATTERNS = client_patterns("/knock/(?P<room_identifier>[^/]*)") PATTERNS = client_patterns("/knock/(?P<room_identifier>[^/]*)")
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -72,6 +72,8 @@ class LoginResponse(TypedDict, total=False):
class LoginRestServlet(RestServlet): class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True) PATTERNS = client_patterns("/login$", v1=True)
CATEGORY = "Registration/login requests"
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso" SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token" TOKEN_TYPE = "m.login.token"
@ -537,6 +539,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
class RefreshTokenServlet(RestServlet): class RefreshTokenServlet(RestServlet):
PATTERNS = client_patterns("/refresh$") PATTERNS = client_patterns("/refresh$")
CATEGORY = "Registration/login requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@ -590,6 +593,7 @@ class SsoRedirectServlet(RestServlet):
+ "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$" + "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
) )
] ]
CATEGORY = "SSO requests needed for all SSO providers"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they # make sure that the relevant handlers are instantiated, so that they

View File

@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet): class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True) PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
CATEGORY = "Presence requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -29,6 +29,7 @@ if TYPE_CHECKING:
class ProfileDisplaynameRestServlet(RestServlet): class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True) PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -86,6 +87,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
class ProfileAvatarURLRestServlet(RestServlet): class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -142,6 +144,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
class ProfileRestServlet(RestServlet): class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -44,6 +44,9 @@ class PushRuleRestServlet(RestServlet):
"Unrecognised request: You probably wanted a trailing slash" "Unrecognised request: You probably wanted a trailing slash"
) )
WORKERS_DENIED_METHODS = ["PUT", "DELETE"]
CATEGORY = "Push rule requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()

View File

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet): class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
CATEGORY = "Receipts requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -36,6 +36,7 @@ class ReceiptRestServlet(RestServlet):
"/receipt/(?P<receipt_type>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$" "/(?P<event_id>[^/]*)$"
) )
CATEGORY = "Receipts requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -367,6 +367,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
f"/register/{LoginType.REGISTRATION_TOKEN}/validity", f"/register/{LoginType.REGISTRATION_TOKEN}/validity",
releases=("v1",), releases=("v1",),
) )
CATEGORY = "Registration/login requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$") PATTERNS = client_patterns("/register$")
CATEGORY = "Registration/login requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -42,6 +42,7 @@ class RelationPaginationServlet(RestServlet):
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=("v1",), releases=("v1",),
) )
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -84,6 +85,7 @@ class RelationPaginationServlet(RestServlet):
class ThreadsServlet(RestServlet): class ThreadsServlet(RestServlet):
PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),) PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -140,7 +140,7 @@ class TransactionRestServlet(RestServlet):
class RoomCreateRestServlet(TransactionRestServlet): class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
@ -180,6 +180,8 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(RestServlet): class RoomStateEventRestServlet(RestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -323,6 +325,8 @@ class RoomStateEventRestServlet(RestServlet):
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet): class RoomSendEventRestServlet(TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -398,6 +402,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
@ -460,6 +466,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class PublicRoomListRestServlet(RestServlet): class PublicRoomListRestServlet(RestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True) PATTERNS = client_patterns("/publicRooms$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -578,6 +585,7 @@ class PublicRoomListRestServlet(RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMemberListRestServlet(RestServlet): class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -633,6 +641,7 @@ class RoomMemberListRestServlet(RestServlet):
# except it does custom AS logic and has a simpler return format # except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(RestServlet): class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -654,6 +663,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
# TODO: Needs better unit testing # TODO: Needs better unit testing
class RoomMessageListRestServlet(RestServlet): class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
# TODO The routing information should be exposed programatically.
# I want to do this but for now I felt bad about leaving this without
# at least a visible warning on it.
CATEGORY = "Client API requests (ALL FOR SAME ROOM MUST GO TO SAME WORKER)"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -720,6 +733,7 @@ class RoomMessageListRestServlet(RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomStateRestServlet(RestServlet): class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -742,6 +756,7 @@ class RoomStateRestServlet(RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomInitialSyncRestServlet(RestServlet): class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
CATEGORY = "Sync requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -766,6 +781,7 @@ class RoomEventServlet(RestServlet):
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
) )
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -858,6 +874,7 @@ class RoomEventContextServlet(RestServlet):
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
) )
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -958,6 +975,8 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet): class RoomMembershipRestServlet(TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
@ -1071,6 +1090,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -1164,6 +1185,7 @@ class RoomTypingRestServlet(RestServlet):
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
) )
CATEGORY = "The typing stream"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -1195,7 +1217,7 @@ class RoomTypingRestServlet(RestServlet):
# Limit timeout to stop people from setting silly typing timeouts. # Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000) timeout = min(content.get("timeout", 30000), 120000)
# Defer getting the typing handler since it will raise on workers. # Defer getting the typing handler since it will raise on WORKER_PATTERNS.
typing_handler = self.hs.get_typing_writer_handler() typing_handler = self.hs.get_typing_writer_handler()
try: try:
@ -1224,6 +1246,7 @@ class RoomAliasListServlet(RestServlet):
r"/rooms/(?P<room_id>[^/]*)/aliases" r"/rooms/(?P<room_id>[^/]*)/aliases"
), ),
] + list(client_patterns("/rooms/(?P<room_id>[^/]*)/aliases$", unstable=False)) ] + list(client_patterns("/rooms/(?P<room_id>[^/]*)/aliases$", unstable=False))
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -1244,6 +1267,7 @@ class RoomAliasListServlet(RestServlet):
class SearchRestServlet(RestServlet): class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True) PATTERNS = client_patterns("/search$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -1263,6 +1287,7 @@ class SearchRestServlet(RestServlet):
class JoinedRoomsRestServlet(RestServlet): class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True) PATTERNS = client_patterns("/joined_rooms$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -1334,6 +1359,7 @@ class TimestampLookupRestServlet(RestServlet):
PATTERNS = ( PATTERNS = (
re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/timestamp_to_event$"), re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/timestamp_to_event$"),
) )
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -1365,6 +1391,8 @@ class TimestampLookupRestServlet(RestServlet):
class RoomHierarchyRestServlet(RestServlet): class RoomHierarchyRestServlet(RestServlet):
PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/hierarchy$"),) PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/hierarchy$"),)
WORKERS = PATTERNS
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -1405,6 +1433,7 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
"/rooms/(?P<room_identifier>[^/]*)/summary$" "/rooms/(?P<room_identifier>[^/]*)/summary$"
), ),
) )
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)

View File

@ -69,6 +69,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/batch_send$" "/rooms/(?P<room_id>[^/]*)/batch_send$"
), ),
) )
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -37,6 +37,7 @@ class RoomKeysServlet(RestServlet):
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
) )
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -253,6 +254,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version$") PATTERNS = client_patterns("/room_keys/version$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -328,6 +330,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version/(?P<version>[^/]+)$") PATTERNS = client_patterns("/room_keys/version/(?P<version>[^/]+)$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -35,6 +35,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$" "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
) )
CATEGORY = "The to_device stream"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -87,6 +87,7 @@ class SyncRestServlet(RestServlet):
PATTERNS = client_patterns("/sync$") PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = {"online", "offline", "unavailable"} ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
CATEGORY = "Sync requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -37,6 +37,7 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags$" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags$"
) )
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -64,6 +65,7 @@ class TagServlet(RestServlet):
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
) )
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet): class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_patterns("/user_directory/search$") PATTERNS = client_patterns("/user_directory/search$")
CATEGORY = "User directory search requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -34,6 +34,7 @@ logger = logging.getLogger(__name__)
class VersionsRestServlet(RestServlet): class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")] PATTERNS = [re.compile("^/_matrix/client/versions$")]
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -29,6 +29,7 @@ if TYPE_CHECKING:
class VoipRestServlet(RestServlet): class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True) PATTERNS = client_patterns("/voip/turnServer$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()

View File

@ -93,6 +93,8 @@ class RemoteKey(RestServlet):
} }
""" """
CATEGORY = "Federation requests"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.fetcher = ServerKeyFetcher(hs) self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main self.store = hs.get_datastores().main