Asynchronous Uploads (#15503)

Support asynchronous uploads as defined in MSC2246.
This commit is contained in:
Sumner Evans 2023-11-15 07:19:24 -07:00 committed by GitHub
parent 80922dc46e
commit 999bd77d3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 568 additions and 59 deletions

View File

@ -0,0 +1 @@
Add support for asynchronous uploads as defined by [MSC2246](https://github.com/matrix-org/matrix-spec-proposals/pull/2246). Contributed by @sumnerevans at @beeper.

View File

@ -1753,6 +1753,19 @@ rc_third_party_invite:
burst_count: 10 burst_count: 10
``` ```
--- ---
### `rc_media_create`
This option ratelimits creation of MXC URIs via the `/_matrix/media/v1/create`
endpoint based on the account that's creating the media. Defaults to
`per_second: 10`, `burst_count: 50`.
Example configuration:
```yaml
rc_media_create:
per_second: 10
burst_count: 50
```
---
### `rc_federation` ### `rc_federation`
Defines limits on federation requests. Defines limits on federation requests.
@ -1814,6 +1827,27 @@ Example configuration:
media_store_path: "DATADIR/media_store" media_store_path: "DATADIR/media_store"
``` ```
--- ---
### `max_pending_media_uploads`
How many *pending media uploads* can a given user have? A pending media upload
is a created MXC URI that (a) has not expired (its `unused_expires_at` timestamp
has not passed) and (b) has not yet had media uploaded to it. Defaults to 5.
Example configuration:
```yaml
max_pending_media_uploads: 5
```
---
### `unused_expiration_time`
How long to wait before expiring created media IDs that have not been uploaded
to, given as a duration such as `"1h"` or as a bare number of milliseconds.
Defaults to `"24h"`.
Example configuration:
```yaml
unused_expiration_time: "1h"
```
---
### `media_storage_providers` ### `media_storage_providers`
Media storage providers allow media to be stored in different Media storage providers allow media to be stored in different

View File

@ -83,6 +83,8 @@ class Codes(str, Enum):
USER_DEACTIVATED = "M_USER_DEACTIVATED" USER_DEACTIVATED = "M_USER_DEACTIVATED"
# USER_LOCKED = "M_USER_LOCKED" # USER_LOCKED = "M_USER_LOCKED"
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED" USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"
# Part of MSC3848 # Part of MSC3848
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848 # https://github.com/matrix-org/matrix-spec-proposals/pull/3848

View File

@ -204,3 +204,10 @@ class RatelimitConfig(Config):
"rc_third_party_invite", "rc_third_party_invite",
defaults={"per_second": 0.0025, "burst_count": 5}, defaults={"per_second": 0.0025, "burst_count": 5},
) )
# Ratelimit create media requests:
self.rc_media_create = RatelimitSettings.parse(
config,
"rc_media_create",
defaults={"per_second": 10, "burst_count": 50},
)

View File

@ -141,6 +141,12 @@ class ContentRepositoryConfig(Config):
"prevent_media_downloads_from", [] "prevent_media_downloads_from", []
) )
self.unused_expiration_time = self.parse_duration(
config.get("unused_expiration_time", "24h")
)
self.max_pending_media_uploads = config.get("max_pending_media_uploads", 5)
self.media_store_path = self.ensure_directory( self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store") config.get("media_store_path", "media_store")
) )

View File

@ -83,6 +83,12 @@ INLINE_CONTENT_TYPES = [
"audio/x-flac", "audio/x-flac",
] ]
# Default timeout_ms for download and thumbnail requests
DEFAULT_MAX_TIMEOUT_MS = 20_000
# Maximum allowed timeout_ms for download and thumbnail requests
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000
def respond_404(request: SynapseRequest) -> None: def respond_404(request: SynapseRequest) -> None:
assert request.path is not None assert request.path is not None

View File

@ -27,13 +27,16 @@ import twisted.web.http
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from synapse.api.errors import ( from synapse.api.errors import (
Codes,
FederationDeniedError, FederationDeniedError,
HttpResponseException, HttpResponseException,
NotFoundError, NotFoundError,
RequestSendFailed, RequestSendFailed,
SynapseError, SynapseError,
cs_error,
) )
from synapse.config.repository import ThumbnailRequirement from synapse.config.repository import ThumbnailRequirement
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread from synapse.logging.context import defer_to_thread
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace
@ -51,7 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.media_repository import RemoteMedia from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -80,6 +83,8 @@ class MediaRepository:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.max_upload_size = hs.config.media.max_upload_size self.max_upload_size = hs.config.media.max_upload_size
self.max_image_pixels = hs.config.media.max_image_pixels self.max_image_pixels = hs.config.media.max_image_pixels
self.unused_expiration_time = hs.config.media.unused_expiration_time
self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
Thumbnailer.set_limits(self.max_image_pixels) Thumbnailer.set_limits(self.max_image_pixels)
@ -185,6 +190,117 @@ class MediaRepository:
else: else:
self.recently_accessed_locals.add(media_id) self.recently_accessed_locals.add(media_id)
@trace
async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]:
    """Reserve a fresh media ID for a local user ahead of the actual upload.

    Args:
        auth_user: The user on whose behalf the media ID is created.

    Returns:
        A 2-tuple of the MXC URI for the newly stored media ID and the
        timestamp (creation time plus ``unused_expiration_time``) at which
        that MXC URI expires.
    """
    new_media_id = random_string(24)
    created_ts = self.clock.time_msec()

    # Only the ID, its creation time and its owner are recorded here; no
    # content is passed to the store at this point.
    await self.store.store_local_media_id(
        media_id=new_media_id,
        time_now_ms=created_ts,
        user_id=auth_user,
    )

    content_uri = f"mxc://{self.server_name}/{new_media_id}"
    unused_expires_at = created_ts + self.unused_expiration_time
    return content_uri, unused_expires_at
@trace
async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]:
    """Report whether the user already has as many pending uploads as allowed.

    Args:
        auth_user: The user whose pending media uploads are counted.

    Returns:
        A 2-tuple of:
          * whether the pending count has reached
            ``max_pending_media_uploads``;
          * the first-expiration timestamp that the store reports alongside
            that count.
    """
    pending_count, earliest_expiry_ts = await self.store.count_pending_media(
        user_id=auth_user
    )
    limit_reached = pending_count >= self.max_pending_media_uploads
    return limit_reached, earliest_expiry_ts
@trace
async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
    """Verify that the given user may upload content to the given media ID.

    The checks are performed in this order, and the first failure wins:

      * the media ID exists
      * the user uploading is the same as the one who created the media ID
      * the media ID does not already have content
      * the media ID has not expired

    Args:
        media_id: The media ID to verify
        auth_user: The user_id of the uploader

    Raises:
        SynapseError: with status 404 and ``Codes.NOT_FOUND`` if the media ID
            is unknown; with status 403 and ``Codes.FORBIDDEN`` if it was
            created by a different user; with status 409 and
            ``Codes.CANNOT_OVERWRITE_MEDIA`` if it already has content.
        NotFoundError: if the media ID expired before being uploaded to.
    """
    media = await self.store.get_local_media(media_id)
    if media is None:
        raise SynapseError(404, "Unknown media ID", errcode=Codes.NOT_FOUND)

    # The stored owner is a string, so compare against the string form of
    # the requesting user.
    if media.user_id != auth_user.to_string():
        raise SynapseError(
            403,
            "Only the creator of the media ID can upload to it",
            errcode=Codes.FORBIDDEN,
        )

    # A non-null length means content has already been stored for this ID.
    if media.media_length is not None:
        raise SynapseError(
            409,
            "Media ID already has content",
            errcode=Codes.CANNOT_OVERWRITE_MEDIA,
        )

    # Anything created before this cut-off has outlived its unused lifetime.
    expired_time_ms = self.clock.time_msec() - self.unused_expiration_time
    if media.created_ts < expired_time_ms:
        raise NotFoundError("Media ID has expired")
@trace
async def update_content(
    self,
    media_id: str,
    media_type: str,
    upload_name: Optional[str],
    content: IO,
    content_length: int,
    auth_user: UserID,
) -> None:
    """Store uploaded content for an existing media ID.

    The content is written to media storage, the local media record is then
    updated with the upload's metadata, and finally thumbnail generation is
    attempted.

    Args:
        media_id: The media ID whose content is being supplied.
        media_type: The content type of the file.
        upload_name: The name of the file, if provided.
        content: A file-like object holding the content to store.
        content_length: The length of the content.
        auth_user: The user_id of the uploader.
    """
    stored_path = await self.media_storage.store_file(
        content, FileInfo(server_name=None, file_id=media_id)
    )
    logger.info("Stored local media in file %r", stored_path)

    await self.store.update_local_media(
        media_id=media_id,
        media_type=media_type,
        upload_name=upload_name,
        media_length=content_length,
        user_id=auth_user,
    )

    # Thumbnailing is best-effort: a failure is logged and does not undo the
    # upload that has already been stored and recorded above.
    try:
        await self._generate_thumbnails(None, media_id, media_id, media_type)
    except Exception as thumbnail_error:
        logger.info("Failed to generate thumbnails: %s", thumbnail_error)
@trace @trace
async def create_content( async def create_content(
self, self,
@ -231,8 +347,74 @@ class MediaRepository:
return MXCUri(self.server_name, media_id) return MXCUri(self.server_name, media_id)
def respond_not_yet_uploaded(self, request: SynapseRequest) -> None:
    """Send a 504 JSON error saying the media has not been uploaded yet.

    Args:
        request: The incoming request to respond to.
    """
    error_body = cs_error(
        "Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED
    )
    respond_with_json(request, 504, error_body, send_cors=True)
async def get_local_media_info(
    self, request: SynapseRequest, media_id: str, max_timeout_ms: int
) -> Optional[LocalMedia]:
    """Look up a local media record, waiting for a pending upload if needed.

    If the media ID exists but has no content yet, the store is polled every
    half second until content appears, the media ID expires, or
    ``max_timeout_ms`` milliseconds have passed.

    Args:
        request: The incoming request.
        media_id: The media ID of the content. (This is the same as
            the file_id for local content.)
        max_timeout_ms: the maximum number of milliseconds to wait for the
            media to be uploaded.

    Returns:
        The media record once it has content, or ``None``. When ``None`` is
        returned a response has already been written to ``request`` (a 404,
        or the not-yet-uploaded error), so the caller must not respond again.
    """
    deadline_ms = self.clock.time_msec() + max_timeout_ms

    while True:
        local_media = await self.store.get_local_media(media_id)

        if not local_media:
            logger.info("Media %s is unknown", media_id)
            respond_404(request)
            return None

        if local_media.quarantined_by:
            logger.info("Media %s is quarantined", media_id)
            respond_404(request)
            return None

        # Content is present, so there is nothing left to wait for.
        if local_media.media_length is not None:
            return local_media

        # Still pending: give up with a 404 if the ID has outlived its
        # unused lifetime without ever being uploaded to.
        current_ms = self.clock.time_msec()
        if local_media.created_ts < current_ms - self.unused_expiration_time:
            logger.info("Media %s has expired without being uploaded", media_id)
            respond_404(request)
            return None

        # Out of time: tell the client the upload has not happened yet.
        if current_ms >= deadline_ms:
            logger.info("Media %s has not yet been uploaded", media_id)
            self.respond_not_yet_uploaded(request)
            return None

        await self.clock.sleep(0.5)
async def get_local_media( async def get_local_media(
self, request: SynapseRequest, media_id: str, name: Optional[str] self,
request: SynapseRequest,
media_id: str,
name: Optional[str],
max_timeout_ms: int,
) -> None: ) -> None:
"""Responds to requests for local media, if exists, or returns 404. """Responds to requests for local media, if exists, or returns 404.
@ -242,13 +424,14 @@ class MediaRepository:
the file_id for local content.) the file_id for local content.)
name: Optional name that, if specified, will be used as name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response. the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns: Returns:
Resolves once a response has successfully been written to request Resolves once a response has successfully been written to request
""" """
media_info = await self.store.get_local_media(media_id) media_info = await self.get_local_media_info(request, media_id, max_timeout_ms)
if not media_info or media_info.quarantined_by: if not media_info:
respond_404(request)
return return
self.mark_recently_accessed(None, media_id) self.mark_recently_accessed(None, media_id)
@ -273,6 +456,7 @@ class MediaRepository:
server_name: str, server_name: str,
media_id: str, media_id: str,
name: Optional[str], name: Optional[str],
max_timeout_ms: int,
) -> None: ) -> None:
"""Respond to requests for remote media. """Respond to requests for remote media.
@ -282,6 +466,8 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the remote server). media_id: The media ID of the content (as defined by the remote server).
name: Optional name that, if specified, will be used as name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response. the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns: Returns:
Resolves once a response has successfully been written to request Resolves once a response has successfully been written to request
@ -307,11 +493,11 @@ class MediaRepository:
key = (server_name, media_id) key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key): async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl( responder, media_info = await self._get_remote_media_impl(
server_name, media_id server_name, media_id, max_timeout_ms
) )
# We deliberately stream the file outside the lock # We deliberately stream the file outside the lock
if responder: if responder and media_info:
upload_name = name if name else media_info.upload_name upload_name = name if name else media_info.upload_name
await respond_with_responder( await respond_with_responder(
request, request,
@ -324,7 +510,7 @@ class MediaRepository:
respond_404(request) respond_404(request)
async def get_remote_media_info( async def get_remote_media_info(
self, server_name: str, media_id: str self, server_name: str, media_id: str, max_timeout_ms: int
) -> RemoteMedia: ) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading """Gets the media info associated with the remote file, downloading
if necessary. if necessary.
@ -332,6 +518,8 @@ class MediaRepository:
Args: Args:
server_name: Remote server_name where the media originated. server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server). media_id: The media ID of the content (as defined by the remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns: Returns:
The media info of the file The media info of the file
@ -347,7 +535,7 @@ class MediaRepository:
key = (server_name, media_id) key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key): async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl( responder, media_info = await self._get_remote_media_impl(
server_name, media_id server_name, media_id, max_timeout_ms
) )
# Ensure we actually use the responder so that it releases resources # Ensure we actually use the responder so that it releases resources
@ -358,7 +546,7 @@ class MediaRepository:
return media_info return media_info
async def _get_remote_media_impl( async def _get_remote_media_impl(
self, server_name: str, media_id: str self, server_name: str, media_id: str, max_timeout_ms: int
) -> Tuple[Optional[Responder], RemoteMedia]: ) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to """Looks for media in local cache, if not there then attempt to
download from remote server. download from remote server.
@ -367,6 +555,8 @@ class MediaRepository:
server_name: Remote server_name where the media originated. server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the media_id: The media ID of the content (as defined by the
remote server). remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns: Returns:
A tuple of responder and the media info of the file. A tuple of responder and the media info of the file.
@ -399,8 +589,7 @@ class MediaRepository:
try: try:
media_info = await self._download_remote_file( media_info = await self._download_remote_file(
server_name, server_name, media_id, max_timeout_ms
media_id,
) )
except SynapseError: except SynapseError:
raise raise
@ -433,6 +622,7 @@ class MediaRepository:
self, self,
server_name: str, server_name: str,
media_id: str, media_id: str,
max_timeout_ms: int,
) -> RemoteMedia: ) -> RemoteMedia:
"""Attempt to download the remote file from the given server name, """Attempt to download the remote file from the given server name,
using the given file_id as the local id. using the given file_id as the local id.
@ -442,7 +632,8 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is remote server). This is different than the file_id, which is
locally generated. locally generated.
file_id: Local file ID max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
Returns: Returns:
The media info of the file. The media info of the file.
@ -466,7 +657,8 @@ class MediaRepository:
# tell the remote server to 404 if it doesn't # tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't # recognise the server_name, to make sure we don't
# end up with a routing loop. # end up with a routing loop.
"allow_remote": "false" "allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
}, },
) )
except RequestSendFailed as e: except RequestSendFailed as e:

View File

@ -0,0 +1,83 @@
# Copyright 2023 Beeper Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING
from synapse.api.errors import LimitExceededError
from synapse.api.ratelimiting import Ratelimiter
from synapse.http.server import respond_with_json
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class CreateResource(RestServlet):
    """Servlet handling ``POST /_matrix/media/v1/create``.

    Creates a media ID for the requesting user without any content and
    responds with the resulting ``content_uri`` and ``unused_expires_at``
    timestamp. Requests are ratelimited per requester using the
    ``rc_media_create`` settings, and are rejected while the user has reached
    the pending media upload limit.
    """

    PATTERNS = [re.compile("/_matrix/media/v1/create")]

    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
        super().__init__()

        self.media_repo = media_repo
        self.clock = hs.get_clock()
        self.auth = hs.get_auth()
        self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads

        # A rate limiter for creating new media IDs.
        self._create_media_rate_limiter = Ratelimiter(
            store=hs.get_datastores().main,
            clock=self.clock,
            cfg=hs.config.ratelimiting.rc_media_create,
        )

    async def on_POST(self, request: SynapseRequest) -> None:
        """Create a new media ID for the authenticated requester.

        Args:
            request: The incoming request.

        Raises:
            LimitExceededError: if the requester has reached the pending
                media upload limit. Ratelimiting by
                ``_create_media_rate_limiter`` happens before that check.
        """
        requester = await self.auth.get_user_by_req(request)

        # If the create media requests for the user are over the limit, drop them.
        await self._create_media_rate_limiter.ratelimit(requester)

        (
            reached_pending_limit,
            first_expiration_ts,
        ) = await self.media_repo.reached_pending_media_limit(requester.user)
        if reached_pending_limit:
            # The expiration timestamp was computed before this clock read, so
            # the difference can be negative if that upload has expired in the
            # meantime. Never advertise a negative retry interval.
            retry_after_ms = max(0, first_expiration_ts - self.clock.time_msec())
            raise LimitExceededError(
                limiter_name="max_pending_media_uploads",
                retry_after_ms=retry_after_ms,
            )

        content_uri, unused_expires_at = await self.media_repo.create_media_id(
            requester.user
        )

        logger.info(
            "Created Media URI %r that if unused will expire at %d",
            content_uri,
            unused_expires_at,
        )
        respond_with_json(
            request,
            200,
            {
                "content_uri": content_uri,
                "unused_expires_at": unused_expires_at,
            },
            send_cors=True,
        )

View File

@ -17,9 +17,13 @@ import re
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from synapse.http.server import set_corp_headers, set_cors_headers from synapse.http.server import set_corp_headers, set_cors_headers
from synapse.http.servlet import RestServlet, parse_boolean from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.media._base import respond_404 from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
respond_404,
)
from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING: if TYPE_CHECKING:
@ -65,12 +69,16 @@ class DownloadResource(RestServlet):
) )
# Limited non-standard form of CSP for IE11 # Limited non-standard form of CSP for IE11
request.setHeader(b"X-Content-Security-Policy", b"sandbox;") request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
request.setHeader( request.setHeader(b"Referrer-Policy", b"no-referrer")
b"Referrer-Policy", max_timeout_ms = parse_integer(
b"no-referrer", request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
) )
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
if self._is_mine_server_name(server_name): if self._is_mine_server_name(server_name):
await self.media_repo.get_local_media(request, media_id, file_name) await self.media_repo.get_local_media(
request, media_id, file_name, max_timeout_ms
)
else: else:
allow_remote = parse_boolean(request, "allow_remote", default=True) allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote: if not allow_remote:
@ -83,5 +91,5 @@ class DownloadResource(RestServlet):
return return
await self.media_repo.get_remote_media( await self.media_repo.get_remote_media(
request, server_name, media_id, file_name request, server_name, media_id, file_name, max_timeout_ms
) )

View File

@ -18,10 +18,11 @@ from synapse.config._base import ConfigError
from synapse.http.server import HttpServer, JsonResource from synapse.http.server import HttpServer, JsonResource
from .config_resource import MediaConfigResource from .config_resource import MediaConfigResource
from .create_resource import CreateResource
from .download_resource import DownloadResource from .download_resource import DownloadResource
from .preview_url_resource import PreviewUrlResource from .preview_url_resource import PreviewUrlResource
from .thumbnail_resource import ThumbnailResource from .thumbnail_resource import ThumbnailResource
from .upload_resource import UploadResource from .upload_resource import AsyncUploadServlet, UploadServlet
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -91,8 +92,9 @@ class MediaRepositoryResource(JsonResource):
# Note that many of these should not exist as v1 endpoints, but empirically # Note that many of these should not exist as v1 endpoints, but empirically
# a lot of traffic still goes to them. # a lot of traffic still goes to them.
CreateResource(hs, media_repo).register(http_server)
UploadResource(hs, media_repo).register(http_server) UploadServlet(hs, media_repo).register(http_server)
AsyncUploadServlet(hs, media_repo).register(http_server)
DownloadResource(hs, media_repo).register(http_server) DownloadResource(hs, media_repo).register(http_server)
ThumbnailResource(hs, media_repo, media_repo.media_storage).register( ThumbnailResource(hs, media_repo, media_repo.media_storage).register(
http_server http_server

View File

@ -23,6 +23,8 @@ from synapse.http.server import respond_with_json, set_corp_headers, set_cors_he
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.media._base import ( from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
FileInfo, FileInfo,
ThumbnailInfo, ThumbnailInfo,
respond_404, respond_404,
@ -75,15 +77,19 @@ class ThumbnailResource(RestServlet):
method = parse_string(request, "method", "scale") method = parse_string(request, "method", "scale")
# TODO Parse the Accept header to get an prioritised list of thumbnail types. # TODO Parse the Accept header to get an prioritised list of thumbnail types.
m_type = "image/png" m_type = "image/png"
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
if self._is_mine_server_name(server_name): if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails: if self.dynamic_thumbnails:
await self._select_or_generate_local_thumbnail( await self._select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type request, media_id, width, height, method, m_type, max_timeout_ms
) )
else: else:
await self._respond_local_thumbnail( await self._respond_local_thumbnail(
request, media_id, width, height, method, m_type request, media_id, width, height, method, m_type, max_timeout_ms
) )
self.media_repo.mark_recently_accessed(None, media_id) self.media_repo.mark_recently_accessed(None, media_id)
else: else:
@ -95,14 +101,21 @@ class ThumbnailResource(RestServlet):
respond_404(request) respond_404(request)
return return
if self.dynamic_thumbnails: remote_resp_function = (
await self._select_or_generate_remote_thumbnail( self._select_or_generate_remote_thumbnail
request, server_name, media_id, width, height, method, m_type if self.dynamic_thumbnails
) else self._respond_remote_thumbnail
else: )
await self._respond_remote_thumbnail( await remote_resp_function(
request, server_name, media_id, width, height, method, m_type request,
) server_name,
media_id,
width,
height,
method,
m_type,
max_timeout_ms,
)
self.media_repo.mark_recently_accessed(server_name, media_id) self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail( async def _respond_local_thumbnail(
@ -113,15 +126,12 @@ class ThumbnailResource(RestServlet):
height: int, height: int,
method: str, method: str,
m_type: str, m_type: str,
max_timeout_ms: int,
) -> None: ) -> None:
media_info = await self.store.get_local_media(media_id) media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info: if not media_info:
respond_404(request)
return
if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
@ -146,15 +156,13 @@ class ThumbnailResource(RestServlet):
desired_height: int, desired_height: int,
desired_method: str, desired_method: str,
desired_type: str, desired_type: str,
max_timeout_ms: int,
) -> None: ) -> None:
media_info = await self.store.get_local_media(media_id) media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info: if not media_info:
respond_404(request)
return
if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
@ -206,8 +214,14 @@ class ThumbnailResource(RestServlet):
desired_height: int, desired_height: int,
desired_method: str, desired_method: str,
desired_type: str, desired_type: str,
max_timeout_ms: int,
) -> None: ) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id) media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
respond_404(request)
return
thumbnail_infos = await self.store.get_remote_media_thumbnails( thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id server_name, media_id
@ -263,11 +277,16 @@ class ThumbnailResource(RestServlet):
height: int, height: int,
method: str, method: str,
m_type: str, m_type: str,
max_timeout_ms: int,
) -> None: ) -> None:
# TODO: Don't download the whole remote file # TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of # We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails. # downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info(server_name, media_id) media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_remote_media_thumbnails( thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id server_name, media_id

View File

@ -15,7 +15,7 @@
import logging import logging
import re import re
from typing import IO, TYPE_CHECKING, Dict, List, Optional from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import respond_with_json from synapse.http.server import respond_with_json
@ -29,23 +29,24 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The name of the lock to use when uploading media.
_UPLOAD_MEDIA_LOCK_NAME = "upload_media"
class UploadResource(RestServlet):
PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")]
class BaseUploadServlet(RestServlet):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__() super().__init__()
self.media_repo = media_repo self.media_repo = media_repo
self.filepaths = media_repo.filepaths self.filepaths = media_repo.filepaths
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.clock = hs.get_clock() self.server_name = hs.hostname
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.max_upload_size = hs.config.media.max_upload_size self.max_upload_size = hs.config.media.max_upload_size
self.clock = hs.get_clock()
async def on_POST(self, request: SynapseRequest) -> None: def _get_file_metadata(
requester = await self.auth.get_user_by_req(request) self, request: SynapseRequest
) -> Tuple[int, Optional[str], str]:
raw_content_length = request.getHeader("Content-Length") raw_content_length = request.getHeader("Content-Length")
if raw_content_length is None: if raw_content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400) raise SynapseError(msg="Request must specify a Content-Length", code=400)
@ -88,6 +89,16 @@ class UploadResource(RestServlet):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0] # disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-disposition # TODO(markjh): parse content-disposition
return content_length, upload_name, media_type
class UploadServlet(BaseUploadServlet):
PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")]
async def on_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
content_length, upload_name, media_type = self._get_file_metadata(request)
try: try:
content: IO = request.content # type: ignore content: IO = request.content # type: ignore
content_uri = await self.media_repo.create_content( content_uri = await self.media_repo.create_content(
@ -103,3 +114,53 @@ class UploadResource(RestServlet):
respond_with_json( respond_with_json(
request, 200, {"content_uri": str(content_uri)}, send_cors=True request, 200, {"content_uri": str(content_uri)}, send_cors=True
) )
class AsyncUploadServlet(BaseUploadServlet):
    """Accepts the content for an already-created media ID (MSC2246).

    Serves `PUT /_matrix/media/v3/upload/{server_name}/{media_id}`.
    """

    PATTERNS = [
        re.compile(
            "/_matrix/media/v3/upload/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
        )
    ]

    async def on_PUT(
        self, request: SynapseRequest, server_name: str, media_id: str
    ) -> None:
        """Store the request body as the content of `media_id`.

        Args:
            request: The incoming request; its body is the media content.
            server_name: Server name taken from the URL path.
            media_id: Media ID taken from the URL path.

        Raises:
            SynapseError: 404 if `server_name` differs from this server's
                name; 409 if the per-media upload lock cannot be acquired;
                400 if the media repository raises `SpamMediaException`.
                `verify_can_upload` may raise further errors (not visible
                here).
        """
        requester = await self.auth.get_user_by_req(request)

        # Content can only be uploaded for media IDs belonging to this server.
        if server_name != self.server_name:
            raise SynapseError(
                404,
                "Non-local server name specified",
                errcode=Codes.NOT_FOUND,
            )

        # The lock is keyed on the media ID; if it is unavailable the
        # request is rejected instead of waiting.
        upload_lock = await self.store.try_acquire_lock(
            _UPLOAD_MEDIA_LOCK_NAME, media_id
        )
        if not upload_lock:
            raise SynapseError(
                409,
                "Media ID cannot be overwritten",
                errcode=Codes.CANNOT_OVERWRITE_MEDIA,
            )

        async with upload_lock:
            uploader = requester.user
            await self.media_repo.verify_can_upload(media_id, uploader)

            content_length, upload_name, media_type = self._get_file_metadata(
                request
            )

            try:
                body: IO = request.content  # type: ignore
                await self.media_repo.update_content(
                    media_id,
                    media_type,
                    upload_name,
                    body,
                    content_length,
                    uploader,
                )
            except SpamMediaException:
                # Report rejected uploads as a 400: the default 404 would be
                # confusing for an upload request.
                raise SynapseError(400, "Bad content")

            logger.info("Uploaded content for media ID %r", media_id)
            respond_with_json(request, 200, {}, send_cors=True)

View File

@ -49,13 +49,14 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
class LocalMedia: class LocalMedia:
media_id: str media_id: str
media_type: str media_type: str
media_length: int media_length: Optional[int]
upload_name: str upload_name: str
created_ts: int created_ts: int
url_cache: Optional[str] url_cache: Optional[str]
last_access_ts: int last_access_ts: int
quarantined_by: Optional[str] quarantined_by: Optional[str]
safe_from_quarantine: bool safe_from_quarantine: bool
user_id: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -149,6 +150,13 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
self._drop_media_index_without_method, self._drop_media_index_without_method,
) )
if hs.config.media.can_load_media_repo:
self.unused_expiration_time: Optional[
int
] = hs.config.media.unused_expiration_time
else:
self.unused_expiration_time = None
async def _drop_media_index_without_method( async def _drop_media_index_without_method(
self, progress: JsonDict, batch_size: int self, progress: JsonDict, batch_size: int
) -> int: ) -> int:
@ -202,6 +210,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"url_cache", "url_cache",
"last_access_ts", "last_access_ts",
"safe_from_quarantine", "safe_from_quarantine",
"user_id",
), ),
allow_none=True, allow_none=True,
desc="get_local_media", desc="get_local_media",
@ -218,6 +227,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
url_cache=row[5], url_cache=row[5],
last_access_ts=row[6], last_access_ts=row[6],
safe_from_quarantine=row[7], safe_from_quarantine=row[7],
user_id=row[8],
) )
async def get_local_media_by_user_paginate( async def get_local_media_by_user_paginate(
@ -272,7 +282,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
url_cache, url_cache,
last_access_ts, last_access_ts,
quarantined_by, quarantined_by,
safe_from_quarantine safe_from_quarantine,
user_id
FROM local_media_repository FROM local_media_repository
WHERE user_id = ? WHERE user_id = ?
ORDER BY {order_by_column} {order}, media_id ASC ORDER BY {order_by_column} {order}, media_id ASC
@ -295,6 +306,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
last_access_ts=row[6], last_access_ts=row[6],
quarantined_by=row[7], quarantined_by=row[7],
safe_from_quarantine=bool(row[8]), safe_from_quarantine=bool(row[8]),
user_id=row[9],
) )
for row in txn for row in txn
] ]
@ -391,6 +403,23 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_local_media_ids", _get_local_media_ids_txn "get_local_media_ids", _get_local_media_ids_txn
) )
@trace
async def store_local_media_id(
    self,
    media_id: str,
    time_now_ms: int,
    user_id: UserID,
) -> None:
    """Insert a `local_media_repository` row for a newly created media ID.

    Only the media ID, creation timestamp and owning user are written here;
    no other columns are set by this insert.

    Args:
        media_id: The media ID to record.
        time_now_ms: Value stored as the row's `created_ts`.
        user_id: The user the media ID is recorded against.
    """
    new_row = {
        "media_id": media_id,
        "created_ts": time_now_ms,
        "user_id": user_id.to_string(),
    }
    await self.db_pool.simple_insert(
        "local_media_repository",
        new_row,
        desc="store_local_media_id",
    )
@trace @trace
async def store_local_media( async def store_local_media(
self, self,
@ -416,6 +445,30 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media", desc="store_local_media",
) )
async def update_local_media(
    self,
    media_id: str,
    media_type: str,
    upload_name: Optional[str],
    media_length: int,
    user_id: UserID,
    url_cache: Optional[str] = None,
) -> None:
    """Fill in the content metadata of an existing local media row.

    The row is selected by both `user_id` and `media_id`, so only a row
    recorded against that user is matched.

    Args:
        media_id: The media ID whose row is updated.
        media_type: New value for the `media_type` column.
        upload_name: New value for the `upload_name` column, if any.
        media_length: New value for the `media_length` column.
        user_id: The user the row must be recorded against.
        url_cache: New value for the `url_cache` column.
    """
    row_selector = {
        "user_id": user_id.to_string(),
        "media_id": media_id,
    }
    new_values = {
        "media_type": media_type,
        "upload_name": upload_name,
        "media_length": media_length,
        "url_cache": url_cache,
    }
    await self.db_pool.simple_update_one(
        "local_media_repository",
        keyvalues=row_selector,
        updatevalues=new_values,
        desc="update_local_media",
    )
async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None: async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
"""Mark a local media as safe or unsafe from quarantining.""" """Mark a local media as safe or unsafe from quarantining."""
await self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
@ -425,6 +478,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe", desc="mark_local_media_as_safe",
) )
async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]:
    """Count a user's pending media and find when the oldest one expires.

    A row counts as pending when it has no `media_length` and its
    `created_ts` is newer than the current time minus
    `unused_expiration_time`.

    Args:
        user_id: The user whose pending media are counted.

    Returns:
        A `(count, earliest_expiry_ts)` tuple. The expiry is the smallest
        matching `created_ts` plus `unused_expiration_time`, or 0 when
        there is no such timestamp.
    """

    def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]:
        pending_query = """
            SELECT COUNT(*), MIN(created_ts)
            FROM local_media_repository
            WHERE user_id = ?
                AND created_ts > ?
                AND media_length IS NULL
        """
        # Fails if no expiration time is configured on this store.
        assert self.unused_expiration_time is not None
        expiry_window = self.unused_expiration_time

        txn.execute(
            pending_query,
            (
                user_id.to_string(),
                self._clock.time_msec() - expiry_window,
            ),
        )
        row = txn.fetchone()
        if not row:
            return 0, 0

        pending_count, oldest_created_ts = row[0], row[1]
        if not oldest_created_ts:
            # No matching rows: MIN() yields NULL, so report no expiry.
            return pending_count, 0
        return pending_count, oldest_created_ts + expiry_window

    return await self.db_pool.runInteraction(
        "get_pending_media", get_pending_media_txn
    )
async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]: async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
"""Get the media_id and ts for a cached URL as of the given timestamp """Get the media_id and ts for a cached URL as of the given timestamp
Returns: Returns:

View File

@ -318,7 +318,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
) )
self.assertEqual(self.fetches[0][3], {"allow_remote": "false"}) self.assertEqual(
self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"}
)
headers = { headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))], b"Content-Length": [b"%d" % (len(self.test_image.data))],