Additional type hints for client REST servlets (part 4) (#10728)

This commit is contained in:
Patrick Cloke 2021-09-01 11:59:32 -04:00 committed by GitHub
parent dc75fb7f05
commit d1f1b46c2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 145 additions and 100 deletions

1
changelog.d/10728.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to REST servlets.

View File

@ -16,9 +16,11 @@
import logging import logging
import random import random
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from twisted.web.server import Request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import ( from synapse.api.errors import (
Codes, Codes,
@ -28,15 +30,17 @@ from synapse.api.errors import (
) )
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer from synapse.push.mailer import Mailer
from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed, validate_email from synapse.util.threepids import check_3pid_allowed, validate_email
@ -68,7 +72,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
template_text=self.config.email_password_reset_template_text, template_text=self.config.email_password_reset_template_text,
) )
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -159,7 +163,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
class PasswordRestServlet(RestServlet): class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$") PATTERNS = client_patterns("/account/password$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -169,7 +173,7 @@ class PasswordRestServlet(RestServlet):
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these # we do basic sanity checks here because the auth layer will store these
@ -190,6 +194,7 @@ class PasswordRestServlet(RestServlet):
# #
# In the second case, we require a password to confirm their identity. # In the second case, we require a password to confirm their identity.
requester = None
if self.auth.has_access_token(request): if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
try: try:
@ -206,16 +211,15 @@ class PasswordRestServlet(RestServlet):
# If a password is available now, hash the provided password and # If a password is available now, hash the provided password and
# store it for later. # store it for later.
if new_password: if new_password:
password_hash = await self.auth_handler.hash(new_password) new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data( await self.auth_handler.set_session_data(
e.session_id, e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH, UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash, new_password_hash,
) )
raise raise
user_id = requester.user.to_string() user_id = requester.user.to_string()
else: else:
requester = None
try: try:
result, params, session_id = await self.auth_handler.check_ui_auth( result, params, session_id = await self.auth_handler.check_ui_auth(
[[LoginType.EMAIL_IDENTITY]], [[LoginType.EMAIL_IDENTITY]],
@ -230,11 +234,11 @@ class PasswordRestServlet(RestServlet):
# If a password is available now, hash the provided password and # If a password is available now, hash the provided password and
# store it for later. # store it for later.
if new_password: if new_password:
password_hash = await self.auth_handler.hash(new_password) new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data( await self.auth_handler.set_session_data(
e.session_id, e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH, UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash, new_password_hash,
) )
raise raise
@ -264,7 +268,7 @@ class PasswordRestServlet(RestServlet):
# If we have a password in this request, prefer it. Otherwise, use the # If we have a password in this request, prefer it. Otherwise, use the
# password hash from an earlier request. # password hash from an earlier request.
if new_password: if new_password:
password_hash = await self.auth_handler.hash(new_password) password_hash: Optional[str] = await self.auth_handler.hash(new_password)
elif session_id is not None: elif session_id is not None:
password_hash = await self.auth_handler.get_session_data( password_hash = await self.auth_handler.get_session_data(
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
@ -288,7 +292,7 @@ class PasswordRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$") PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -296,7 +300,7 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
erase = body.get("erase", False) erase = body.get("erase", False)
if not isinstance(erase, bool): if not isinstance(erase, bool):
@ -338,7 +342,7 @@ class DeactivateAccountRestServlet(RestServlet):
class EmailThreepidRequestTokenRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/email/requestToken$") PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.config = hs.config self.config = hs.config
@ -353,7 +357,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
template_text=self.config.email_add_threepid_template_text, template_text=self.config.email_add_threepid_template_text,
) )
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -449,7 +453,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict( assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"] body, ["client_secret", "country", "phone_number", "send_attempt"]
@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
"/add_threepid/email/submit_token$", releases=(), unstable=True "/add_threepid/email/submit_token$", releases=(), unstable=True
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.config = hs.config self.config = hs.config
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -539,7 +539,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.config.email_add_threepid_template_failure_html self.config.email_add_threepid_template_failure_html
) )
async def on_GET(self, request): async def on_GET(self, request: Request) -> None:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
"/add_threepid/msisdn/submit_token$", releases=(), unstable=True "/add_threepid/msisdn/submit_token$", releases=(), unstable=True
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.config = hs.config self.config = hs.config
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request): async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
if not self.config.account_threepid_delegate_msisdn: if not self.config.account_threepid_delegate_msisdn:
raise SynapseError( raise SynapseError(
400, 400,
@ -632,7 +628,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
class ThreepidRestServlet(RestServlet): class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$") PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
@ -640,14 +636,14 @@ class ThreepidRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
threepids = await self.datastore.user_get_threepids(requester.user.to_string()) threepids = await self.datastore.user_get_threepids(requester.user.to_string())
return 200, {"threepids": threepids} return 200, {"threepids": threepids}
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes: if not self.hs.config.enable_3pid_changes:
raise SynapseError( raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@ -688,7 +684,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidAddRestServlet(RestServlet): class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$") PATTERNS = client_patterns("/account/3pid/add$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
@ -696,7 +692,7 @@ class ThreepidAddRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes: if not self.hs.config.enable_3pid_changes:
raise SynapseError( raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@ -738,13 +734,13 @@ class ThreepidAddRestServlet(RestServlet):
class ThreepidBindRestServlet(RestServlet): class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$") PATTERNS = client_patterns("/account/3pid/bind$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
@ -767,14 +763,14 @@ class ThreepidBindRestServlet(RestServlet):
class ThreepidUnbindRestServlet(RestServlet): class ThreepidUnbindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/unbind$") PATTERNS = client_patterns("/account/3pid/unbind$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are """Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound known to have this 3pid bound
""" """
@ -798,13 +794,13 @@ class ThreepidUnbindRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$") PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes: if not self.hs.config.enable_3pid_changes:
raise SynapseError( raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@ -835,7 +831,7 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result} return 200, {"id_server_unbind_result": id_server_unbind_result}
def assert_valid_next_link(hs: "HomeServer", next_link: str): def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
""" """
Raises a SynapseError if a given next_link value is invalid Raises a SynapseError if a given next_link value is invalid
@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str):
class WhoamiRestServlet(RestServlet): class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$") PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
response = {"user_id": requester.user.to_string()} response = {"user_id": requester.user.to_string()}
@ -894,7 +890,7 @@ class WhoamiRestServlet(RestServlet):
return 200, response return 200, response
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailPasswordRequestTokenRestServlet(hs).register(http_server) EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from twisted.web.server import Request from twisted.web.server import Request
@ -96,7 +96,9 @@ class KnockRoomAliasServlet(RestServlet):
return 200, {"room_id": room_id} return 200, {"room_id": room_id}
def on_PUT(self, request: Request, room_identifier: str, txn_id: str): def on_PUT(
self, request: Request, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(

View File

@ -375,11 +375,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
unstable=True, unstable=True,
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -390,7 +386,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
) )
async def on_GET(self, request): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
if not self.hs.config.enable_registration: if not self.hs.config.enable_registration:
@ -730,7 +726,11 @@ class RegisterRestServlet(RestServlet):
return 200, return_dict return 200, return_dict
async def _do_appservice_registration( async def _do_appservice_registration(
self, username, as_token, body, should_issue_refresh_token: bool = False self,
username: str,
as_token: str,
body: JsonDict,
should_issue_refresh_token: bool = False,
) -> JsonDict: ) -> JsonDict:
user_id = await self.registration_handler.appservice_register( user_id = await self.registration_handler.appservice_register(
username, as_token username, as_token

View File

@ -14,26 +14,35 @@
import logging import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet): class ReportEventRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_POST(self, request, room_id, event_id): async def on_POST(
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -64,5 +73,5 @@ class ReportEventRestServlet(RestServlet):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReportEventRestServlet(hs).register(http_server) ReportEventRestServlet(hs).register(http_server)

View File

@ -14,10 +14,14 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, Awaitable, List, Tuple
from twisted.web.server import Request
from synapse.api.constants import EventContentFields, EventTypes from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
@ -25,10 +29,14 @@ from synapse.http.servlet import (
parse_string, parse_string,
parse_strings_from_args, parse_strings_from_args,
) )
from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import Requester, UserID, create_requester from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,7 +74,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
), ),
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -76,7 +84,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
( (
most_recent_prev_event_id, most_recent_prev_event_id,
most_recent_prev_event_depth, most_recent_prev_event_depth,
@ -118,7 +126,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
def _create_insertion_event_dict( def _create_insertion_event_dict(
self, sender: str, room_id: str, origin_server_ts: int self, sender: str, room_id: str, origin_server_ts: int
): ) -> JsonDict:
"""Creates an event dict for an "insertion" event with the proper fields """Creates an event dict for an "insertion" event with the proper fields
and a random chunk ID. and a random chunk ID.
@ -128,7 +136,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
origin_server_ts: Timestamp when the event was sent origin_server_ts: Timestamp when the event was sent
Returns: Returns:
Tuple of event ID and stream ordering position The new event dictionary to insert.
""" """
next_chunk_id = random_string(8) next_chunk_id = random_string(8)
@ -164,7 +172,9 @@ class RoomBatchSendEventRestServlet(RestServlet):
return create_requester(user_id, app_service=app_service) return create_requester(user_id, app_service=app_service)
async def on_POST(self, request, room_id): async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
if not requester.app_service: if not requester.app_service:
@ -176,6 +186,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["state_events_at_start", "events"]) assert_params_in_dict(body, ["state_events_at_start", "events"])
assert request.args is not None
prev_events_from_query = parse_strings_from_args(request.args, "prev_event") prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
chunk_id_from_query = parse_string(request, "chunk_id") chunk_id_from_query = parse_string(request, "chunk_id")
@ -425,16 +436,18 @@ class RoomBatchSendEventRestServlet(RestServlet):
], ],
} }
def on_GET(self, request, room_id): def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
return 501, "Not implemented" return 501, "Not implemented"
def on_PUT(self, request, room_id): def on_PUT(
self, request: SynapseRequest, room_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id request, self.on_POST, request, room_id
) )
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
msc2716_enabled = hs.config.experimental.msc2716_enabled msc2716_enabled = hs.config.experimental.msc2716_enabled
if msc2716_enabled: if msc2716_enabled:

View File

@ -13,16 +13,23 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,16 +38,14 @@ class RoomKeysServlet(RestServlet):
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
async def on_PUT(self, request, room_id, session_id): async def on_PUT(
self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
) -> Tuple[int, JsonDict]:
""" """
Uploads one or more encrypted E2E room keys for backup purposes. Uploads one or more encrypted E2E room keys for backup purposes.
room_id: the ID of the room the keys are for (optional) room_id: the ID of the room the keys are for (optional)
@ -133,7 +138,9 @@ class RoomKeysServlet(RestServlet):
ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
return 200, ret return 200, ret
async def on_GET(self, request, room_id, session_id): async def on_GET(
self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
) -> Tuple[int, JsonDict]:
""" """
Retrieves one or more encrypted E2E room keys for backup purposes. Retrieves one or more encrypted E2E room keys for backup purposes.
Symmetric with the PUT version of the API. Symmetric with the PUT version of the API.
@ -215,7 +222,9 @@ class RoomKeysServlet(RestServlet):
return 200, room_keys return 200, room_keys
async def on_DELETE(self, request, room_id, session_id): async def on_DELETE(
self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
) -> Tuple[int, JsonDict]:
""" """
Deletes one or more encrypted E2E room keys for a user for backup purposes. Deletes one or more encrypted E2E room keys for a user for backup purposes.
@ -242,16 +251,12 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version$") PATTERNS = client_patterns("/room_keys/version$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
""" """
Create a new backup version for this user's room_keys with the given Create a new backup version for this user's room_keys with the given
info. The version is allocated by the server and returned to the user info. The version is allocated by the server and returned to the user
@ -295,16 +300,14 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$") PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
async def on_GET(self, request, version): async def on_GET(
self, request: SynapseRequest, version: Optional[str]
) -> Tuple[int, JsonDict]:
""" """
Retrieve the version information about a given version of the user's Retrieve the version information about a given version of the user's
room_keys backup. If the version part is missing, returns info about the room_keys backup. If the version part is missing, returns info about the
@ -332,7 +335,9 @@ class RoomKeysVersionServlet(RestServlet):
raise SynapseError(404, "No backup found", Codes.NOT_FOUND) raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
return 200, info return 200, info
async def on_DELETE(self, request, version): async def on_DELETE(
self, request: SynapseRequest, version: Optional[str]
) -> Tuple[int, JsonDict]:
""" """
Delete the information about a given version of the user's Delete the information about a given version of the user's
room_keys backup. If the version part is missing, deletes the most room_keys backup. If the version part is missing, deletes the most
@ -351,7 +356,9 @@ class RoomKeysVersionServlet(RestServlet):
await self.e2e_room_keys_handler.delete_version(user_id, version) await self.e2e_room_keys_handler.delete_version(user_id, version)
return 200, {} return 200, {}
async def on_PUT(self, request, version): async def on_PUT(
self, request: SynapseRequest, version: Optional[str]
) -> Tuple[int, JsonDict]:
""" """
Update the information about a given version of the user's room_keys backup. Update the information about a given version of the user's room_keys backup.
@ -385,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomKeysServlet(hs).register(http_server) RoomKeysServlet(hs).register(http_server)
RoomKeysVersionServlet(hs).register(http_server) RoomKeysVersionServlet(hs).register(http_server)
RoomKeysNewVersionServlet(hs).register(http_server) RoomKeysNewVersionServlet(hs).register(http_server)

View File

@ -13,15 +13,21 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Tuple from typing import TYPE_CHECKING, Awaitable, Tuple
from synapse.http import servlet from synapse.http import servlet
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag, trace from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,11 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$" "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -42,14 +44,18 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.device_message_handler = hs.get_device_message_handler() self.device_message_handler = hs.get_device_message_handler()
@trace(opname="sendToDevice") @trace(opname="sendToDevice")
def on_PUT(self, request, message_type, txn_id): def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("message_type", message_type) set_tag("message_type", message_type)
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id request, self._put, request, message_type, txn_id
) )
async def _put(self, request, message_type, txn_id): async def _put(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -59,9 +65,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester, message_type, content["messages"] requester, message_type, content["messages"]
) )
response: Tuple[int, dict] = (200, {}) return 200, {}
return response
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SendToDeviceRestServlet(hs).register(http_server) SendToDeviceRestServlet(hs).register(http_server)

View File

@ -14,12 +14,24 @@
import itertools import itertools
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
from synapse.api.constants import Membership, PresenceState from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.events.utils import ( from synapse.events.utils import (
format_event_for_client_v2_without_room_id, format_event_for_client_v2_without_room_id,
format_event_raw, format_event_raw,
@ -504,7 +516,7 @@ class SyncRestServlet(RestServlet):
The room, encoded in our response format The room, encoded in our response format
""" """
def serialize(events): def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]:
return self._event_serializer.serialize_events( return self._event_serializer.serialize_events(
events, events,
time_now=time_now, time_now=time_now,