Merge remote-tracking branch 'origin/develop' into clokep/psycopg3
This commit is contained in:
commit
55193f38e5
|
@ -0,0 +1 @@
|
|||
Improve type hints.
|
|
@ -0,0 +1 @@
|
|||
Improve type hints.
|
|
@ -0,0 +1 @@
|
|||
Improve the performance of some operations in multi-worker deployments.
|
|
@ -0,0 +1 @@
|
|||
Improve the performance of some operations in multi-worker deployments.
|
|
@ -0,0 +1 @@
|
|||
Use `dbname` instead of the deprecated `database` connection parameter for psycopg2.
|
|
@ -66,7 +66,7 @@ database:
|
|||
args:
|
||||
user: <user>
|
||||
password: <pass>
|
||||
database: <db>
|
||||
dbname: <db>
|
||||
host: <host>
|
||||
cp_min: 5
|
||||
cp_max: 10
|
||||
|
|
|
@ -1447,7 +1447,7 @@ database:
|
|||
args:
|
||||
user: synapse_user
|
||||
password: secretpassword
|
||||
database: synapse
|
||||
dbname: synapse
|
||||
host: localhost
|
||||
port: 5432
|
||||
cp_min: 5
|
||||
|
@ -1526,7 +1526,7 @@ databases:
|
|||
args:
|
||||
user: synapse_user
|
||||
password: secretpassword
|
||||
database: synapse_main
|
||||
dbname: synapse_main
|
||||
host: localhost
|
||||
port: 5432
|
||||
cp_min: 5
|
||||
|
@ -1539,7 +1539,7 @@ databases:
|
|||
args:
|
||||
user: synapse_user
|
||||
password: secretpassword
|
||||
database: synapse_state
|
||||
dbname: synapse_state
|
||||
host: localhost
|
||||
port: 5432
|
||||
cp_min: 5
|
||||
|
|
|
@ -348,8 +348,7 @@ class Porter:
|
|||
backward_chunk = 0
|
||||
already_ported = 0
|
||||
else:
|
||||
forward_chunk = row["forward_rowid"]
|
||||
backward_chunk = row["backward_rowid"]
|
||||
forward_chunk, backward_chunk = row
|
||||
|
||||
if total_to_port is None:
|
||||
already_ported, total_to_port = await self._get_total_count_to_port(
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
|
@ -23,6 +23,7 @@ from synapse.api.errors import (
|
|||
StoreError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
|
||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||
|
@ -306,7 +307,9 @@ class ProfileHandler:
|
|||
server_name = host
|
||||
|
||||
if self._is_mine_server_name(server_name):
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
media_info: Optional[
|
||||
Union[LocalMedia, RemoteMedia]
|
||||
] = await self.store.get_local_media(media_id)
|
||||
else:
|
||||
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
||||
|
||||
|
@ -322,12 +325,12 @@ class ProfileHandler:
|
|||
|
||||
if self.max_avatar_size:
|
||||
# Ensure avatar does not exceed max allowed avatar size
|
||||
if media_info["media_length"] > self.max_avatar_size:
|
||||
if media_info.media_length > self.max_avatar_size:
|
||||
logger.warning(
|
||||
"Forbidding avatar change to %s: %d bytes is above the allowed size "
|
||||
"limit",
|
||||
mxc,
|
||||
media_info["media_length"],
|
||||
media_info.media_length,
|
||||
)
|
||||
return False
|
||||
|
||||
|
@ -335,12 +338,12 @@ class ProfileHandler:
|
|||
# Ensure the avatar's file type is allowed
|
||||
if (
|
||||
self.allowed_avatar_mimetypes
|
||||
and media_info["media_type"] not in self.allowed_avatar_mimetypes
|
||||
and media_info.media_type not in self.allowed_avatar_mimetypes
|
||||
):
|
||||
logger.warning(
|
||||
"Forbidding avatar change to %s: mimetype %s not allowed",
|
||||
mxc,
|
||||
media_info["media_type"],
|
||||
media_info.media_type,
|
||||
)
|
||||
return False
|
||||
|
||||
|
|
|
@ -269,7 +269,7 @@ class RoomCreationHandler:
|
|||
self,
|
||||
requester: Requester,
|
||||
old_room_id: str,
|
||||
old_room: Dict[str, Any],
|
||||
old_room: Tuple[bool, str, bool],
|
||||
new_room_id: str,
|
||||
new_version: RoomVersion,
|
||||
tombstone_event: EventBase,
|
||||
|
@ -279,7 +279,7 @@ class RoomCreationHandler:
|
|||
Args:
|
||||
requester: the user requesting the upgrade
|
||||
old_room_id: the id of the room to be replaced
|
||||
old_room: a dict containing room information for the room to be replaced,
|
||||
old_room: a tuple containing room information for the room to be replaced,
|
||||
as returned by `RoomWorkerStore.get_room`.
|
||||
new_room_id: the id of the replacement room
|
||||
new_version: the version to upgrade the room to
|
||||
|
@ -299,7 +299,7 @@ class RoomCreationHandler:
|
|||
await self.store.store_room(
|
||||
room_id=new_room_id,
|
||||
room_creator_user_id=user_id,
|
||||
is_public=old_room["is_public"],
|
||||
is_public=old_room[0],
|
||||
room_version=new_version,
|
||||
)
|
||||
|
||||
|
|
|
@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
# Add new room to the room directory if the old room was there
|
||||
# Remove old room from the room directory
|
||||
old_room = await self.store.get_room(old_room_id)
|
||||
if old_room is not None and old_room["is_public"]:
|
||||
# If the old room exists and is public.
|
||||
if old_room is not None and old_room[0]:
|
||||
await self.store.set_room_is_public(old_room_id, False)
|
||||
await self.store.set_room_is_public(room_id, True)
|
||||
|
||||
|
|
|
@ -806,7 +806,7 @@ class SsoHandler:
|
|||
media_id = profile["avatar_url"].split("/")[-1]
|
||||
if self._is_mine_server_name(server_name):
|
||||
media = await self._media_repo.store.get_local_media(media_id)
|
||||
if media is not None and upload_name == media["upload_name"]:
|
||||
if media is not None and upload_name == media.upload_name:
|
||||
logger.info("skipping saving the user avatar")
|
||||
return True
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import shutil
|
|||
from io import BytesIO
|
||||
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
from matrix_common.types.mxc_uri import MXCUri
|
||||
|
||||
import twisted.internet.error
|
||||
|
@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
|
|||
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
|
||||
from synapse.media.url_previewer import UrlPreviewer
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.databases.main.media_repository import RemoteMedia
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
@ -245,18 +247,18 @@ class MediaRepository:
|
|||
Resolves once a response has successfully been written to request
|
||||
"""
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
if not media_info or media_info.quarantined_by:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
self.mark_recently_accessed(None, media_id)
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
media_type = media_info.media_type
|
||||
if not media_type:
|
||||
media_type = "application/octet-stream"
|
||||
media_length = media_info["media_length"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
url_cache = media_info["url_cache"]
|
||||
media_length = media_info.media_length
|
||||
upload_name = name if name else media_info.upload_name
|
||||
url_cache = media_info.url_cache
|
||||
|
||||
file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
|
||||
|
||||
|
@ -310,16 +312,20 @@ class MediaRepository:
|
|||
|
||||
# We deliberately stream the file outside the lock
|
||||
if responder:
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
upload_name = name if name else media_info.upload_name
|
||||
await respond_with_responder(
|
||||
request, responder, media_type, media_length, upload_name
|
||||
request,
|
||||
responder,
|
||||
media_info.media_type,
|
||||
media_info.media_length,
|
||||
upload_name,
|
||||
)
|
||||
else:
|
||||
respond_404(request)
|
||||
|
||||
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
|
||||
async def get_remote_media_info(
|
||||
self, server_name: str, media_id: str
|
||||
) -> RemoteMedia:
|
||||
"""Gets the media info associated with the remote file, downloading
|
||||
if necessary.
|
||||
|
||||
|
@ -353,7 +359,7 @@ class MediaRepository:
|
|||
|
||||
async def _get_remote_media_impl(
|
||||
self, server_name: str, media_id: str
|
||||
) -> Tuple[Optional[Responder], dict]:
|
||||
) -> Tuple[Optional[Responder], RemoteMedia]:
|
||||
"""Looks for media in local cache, if not there then attempt to
|
||||
download from remote server.
|
||||
|
||||
|
@ -373,15 +379,17 @@ class MediaRepository:
|
|||
|
||||
# If we have an entry in the DB, try and look for it
|
||||
if media_info:
|
||||
file_id = media_info["filesystem_id"]
|
||||
file_id = media_info.filesystem_id
|
||||
file_info = FileInfo(server_name, file_id)
|
||||
|
||||
if media_info["quarantined_by"]:
|
||||
if media_info.quarantined_by:
|
||||
logger.info("Media is quarantined")
|
||||
raise NotFoundError()
|
||||
|
||||
if not media_info["media_type"]:
|
||||
media_info["media_type"] = "application/octet-stream"
|
||||
if not media_info.media_type:
|
||||
media_info = attr.evolve(
|
||||
media_info, media_type="application/octet-stream"
|
||||
)
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
|
@ -403,9 +411,9 @@ class MediaRepository:
|
|||
if not media_info:
|
||||
raise e
|
||||
|
||||
file_id = media_info["filesystem_id"]
|
||||
if not media_info["media_type"]:
|
||||
media_info["media_type"] = "application/octet-stream"
|
||||
file_id = media_info.filesystem_id
|
||||
if not media_info.media_type:
|
||||
media_info = attr.evolve(media_info, media_type="application/octet-stream")
|
||||
file_info = FileInfo(server_name, file_id)
|
||||
|
||||
# We generate thumbnails even if another process downloaded the media
|
||||
|
@ -415,7 +423,7 @@ class MediaRepository:
|
|||
# otherwise they'll request thumbnails and get a 404 if they're not
|
||||
# ready yet.
|
||||
await self._generate_thumbnails(
|
||||
server_name, media_id, file_id, media_info["media_type"]
|
||||
server_name, media_id, file_id, media_info.media_type
|
||||
)
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
|
@ -425,7 +433,7 @@ class MediaRepository:
|
|||
self,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
) -> dict:
|
||||
) -> RemoteMedia:
|
||||
"""Attempt to download the remote file from the given server name,
|
||||
using the given file_id as the local id.
|
||||
|
||||
|
@ -518,7 +526,7 @@ class MediaRepository:
|
|||
origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
time_now_ms=time_now_ms,
|
||||
upload_name=upload_name,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
|
@ -526,15 +534,17 @@ class MediaRepository:
|
|||
|
||||
logger.info("Stored remote media in file %r", fname)
|
||||
|
||||
media_info = {
|
||||
"media_type": media_type,
|
||||
"media_length": length,
|
||||
"upload_name": upload_name,
|
||||
"created_ts": time_now_ms,
|
||||
"filesystem_id": file_id,
|
||||
}
|
||||
|
||||
return media_info
|
||||
return RemoteMedia(
|
||||
media_origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
media_length=length,
|
||||
upload_name=upload_name,
|
||||
created_ts=time_now_ms,
|
||||
filesystem_id=file_id,
|
||||
last_access_ts=time_now_ms,
|
||||
quarantined_by=None,
|
||||
)
|
||||
|
||||
def _get_thumbnail_requirements(
|
||||
self, media_type: str
|
||||
|
|
|
@ -240,15 +240,14 @@ class UrlPreviewer:
|
|||
cache_result = await self.store.get_url_cache(url, ts)
|
||||
if (
|
||||
cache_result
|
||||
and cache_result["expires_ts"] > ts
|
||||
and cache_result["response_code"] / 100 == 2
|
||||
and cache_result.expires_ts > ts
|
||||
and cache_result.response_code // 100 == 2
|
||||
):
|
||||
# It may be stored as text in the database, not as bytes (such as
|
||||
# PostgreSQL). If so, encode it back before handing it on.
|
||||
og = cache_result["og"]
|
||||
if isinstance(og, str):
|
||||
og = og.encode("utf8")
|
||||
return og
|
||||
if isinstance(cache_result.og, str):
|
||||
return cache_result.og.encode("utf8")
|
||||
return cache_result.og
|
||||
|
||||
# If this URL can be accessed via an allowed oEmbed, use that instead.
|
||||
url_to_download = url
|
||||
|
|
|
@ -1860,7 +1860,8 @@ class PublicRoomListManager:
|
|||
if not room:
|
||||
return False
|
||||
|
||||
return room.get("is_public", False)
|
||||
# The first item is whether the room is public.
|
||||
return room[0]
|
||||
|
||||
async def add_room_to_public_room_list(self, room_id: str) -> None:
|
||||
"""Publishes a room to the public room list.
|
||||
|
|
|
@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
|
|||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
ret = await self.store.get_room(room_id)
|
||||
if not ret:
|
||||
room = await self.store.get_room(room_id)
|
||||
if not room:
|
||||
raise NotFoundError("Room not found")
|
||||
|
||||
members = await self.store.get_users_in_room(room_id)
|
||||
|
@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
|
|||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
ret = await self.store.get_room(room_id)
|
||||
if not ret:
|
||||
room = await self.store.get_room(room_id)
|
||||
if not room:
|
||||
raise NotFoundError("Room not found")
|
||||
|
||||
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
||||
|
|
|
@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
|
|||
if room is None:
|
||||
raise NotFoundError("Unknown room")
|
||||
|
||||
return 200, {"visibility": "public" if room["is_public"] else "private"}
|
||||
return 200, {"visibility": "public" if room[0] else "private"}
|
||||
|
||||
class PutBody(RequestBodyModel):
|
||||
visibility: Literal["public", "private"] = "public"
|
||||
|
|
|
@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet):
|
|||
if not media_info:
|
||||
respond_404(request)
|
||||
return
|
||||
if media_info["quarantined_by"]:
|
||||
if media_info.quarantined_by:
|
||||
logger.info("Media is quarantined")
|
||||
respond_404(request)
|
||||
return
|
||||
|
@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet):
|
|||
thumbnail_infos,
|
||||
media_id,
|
||||
media_id,
|
||||
url_cache=bool(media_info["url_cache"]),
|
||||
url_cache=bool(media_info.url_cache),
|
||||
server_name=None,
|
||||
)
|
||||
|
||||
|
@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet):
|
|||
if not media_info:
|
||||
respond_404(request)
|
||||
return
|
||||
if media_info["quarantined_by"]:
|
||||
if media_info.quarantined_by:
|
||||
logger.info("Media is quarantined")
|
||||
respond_404(request)
|
||||
return
|
||||
|
@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet):
|
|||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
url_cache=media_info["url_cache"],
|
||||
url_cache=bool(media_info.url_cache),
|
||||
thumbnail=info,
|
||||
)
|
||||
|
||||
|
@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet):
|
|||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
url_cache=bool(media_info["url_cache"]),
|
||||
url_cache=bool(media_info.url_cache),
|
||||
)
|
||||
|
||||
if file_path:
|
||||
|
@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet):
|
|||
server_name, media_id
|
||||
)
|
||||
|
||||
file_id = media_info["filesystem_id"]
|
||||
file_id = media_info.filesystem_id
|
||||
|
||||
for info in thumbnail_infos:
|
||||
t_w = info.width == desired_width
|
||||
|
@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet):
|
|||
if t_w and t_h and t_method and t_type:
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=media_info["filesystem_id"],
|
||||
file_id=file_id,
|
||||
thumbnail=info,
|
||||
)
|
||||
|
||||
|
@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet):
|
|||
m_type,
|
||||
thumbnail_infos,
|
||||
media_id,
|
||||
media_info["filesystem_id"],
|
||||
media_info.filesystem_id,
|
||||
url_cache=False,
|
||||
server_name=server_name,
|
||||
)
|
||||
|
|
|
@ -1660,7 +1660,7 @@ class DatabasePool:
|
|||
retcols: Collection[str],
|
||||
allow_none: Literal[False] = False,
|
||||
desc: str = "simple_select_one",
|
||||
) -> Dict[str, Any]:
|
||||
) -> Tuple[Any, ...]:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
@ -1671,7 +1671,7 @@ class DatabasePool:
|
|||
retcols: Collection[str],
|
||||
allow_none: Literal[True] = True,
|
||||
desc: str = "simple_select_one",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[Tuple[Any, ...]]:
|
||||
...
|
||||
|
||||
async def simple_select_one(
|
||||
|
@ -1681,7 +1681,7 @@ class DatabasePool:
|
|||
retcols: Collection[str],
|
||||
allow_none: bool = False,
|
||||
desc: str = "simple_select_one",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[Tuple[Any, ...]]:
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning multiple columns from it.
|
||||
|
||||
|
@ -2190,7 +2190,7 @@ class DatabasePool:
|
|||
keyvalues: Dict[str, Any],
|
||||
retcols: Collection[str],
|
||||
allow_none: bool = False,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[Tuple[Any, ...]]:
|
||||
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
||||
|
||||
if keyvalues:
|
||||
|
@ -2208,7 +2208,7 @@ class DatabasePool:
|
|||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
return dict(zip(retcols, row))
|
||||
return row
|
||||
|
||||
async def simple_delete_one(
|
||||
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
|
||||
|
|
|
@ -747,8 +747,16 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
)
|
||||
|
||||
# Invalidate the cache for any ignored users which were added or removed.
|
||||
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
|
||||
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self.ignored_by,
|
||||
[
|
||||
(ignored_user_id,)
|
||||
for ignored_user_id in (
|
||||
previously_ignored_users ^ currently_ignored_users
|
||||
)
|
||||
],
|
||||
)
|
||||
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
|
||||
|
||||
async def remove_account_data_for_user(
|
||||
|
@ -824,10 +832,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
)
|
||||
|
||||
# Invalidate the cache for ignored users which were removed.
|
||||
for ignored_user_id in previously_ignored_users:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.ignored_by, (ignored_user_id,)
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self.ignored_by,
|
||||
[
|
||||
(ignored_user_id,)
|
||||
for ignored_user_id in previously_ignored_users
|
||||
],
|
||||
)
|
||||
|
||||
# Invalidate for this user the cache tracking ignored users.
|
||||
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
|
||||
|
|
|
@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
txn.call_after(cache_func.invalidate, keys)
|
||||
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
|
||||
|
||||
def _invalidate_cache_and_stream_bulk(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
cache_func: CachedFunction,
|
||||
key_tuples: Collection[Tuple[Any, ...]],
|
||||
) -> None:
|
||||
"""A bulk version of _invalidate_cache_and_stream.
|
||||
|
||||
Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
|
||||
for each key-tuple over replication.
|
||||
|
||||
This implementation is more efficient than a loop which repeatedly calls the
|
||||
non-bulk version.
|
||||
"""
|
||||
if not key_tuples:
|
||||
return
|
||||
|
||||
for keys in key_tuples:
|
||||
txn.call_after(cache_func.invalidate, keys)
|
||||
|
||||
self._send_invalidation_to_replication_bulk(
|
||||
txn, cache_func.__name__, key_tuples
|
||||
)
|
||||
|
||||
def _invalidate_all_cache_and_stream(
|
||||
self, txn: LoggingTransaction, cache_func: CachedFunction
|
||||
) -> None:
|
||||
|
@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
if isinstance(self.database_engine, PostgresEngine):
|
||||
assert self._cache_id_gen is not None
|
||||
|
||||
# get_next() returns a context manager which is designed to wrap
|
||||
# the transaction. However, we want to only get an ID when we want
|
||||
# to use it, here, so we need to call __enter__ manually, and have
|
||||
# __exit__ called after the transaction finishes.
|
||||
stream_id = self._cache_id_gen.get_next_txn(txn)
|
||||
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
||||
|
||||
|
@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
},
|
||||
)
|
||||
|
||||
def _send_invalidation_to_replication_bulk(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
cache_name: str,
|
||||
key_tuples: Collection[Tuple[Any, ...]],
|
||||
) -> None:
|
||||
"""Announce the invalidation of multiple (but not all) cache entries.
|
||||
|
||||
This is more efficient than repeated calls to the non-bulk version. It should
|
||||
NOT be used to invalidating the entire cache: use
|
||||
`_send_invalidation_to_replication` with keys=None.
|
||||
|
||||
Note that this does *not* invalidate the cache locally.
|
||||
|
||||
Args:
|
||||
txn
|
||||
cache_name
|
||||
key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
|
||||
"""
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
assert self._cache_id_gen is not None
|
||||
|
||||
stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
|
||||
ts = self._clock.time_msec()
|
||||
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="cache_invalidation_stream_by_instance",
|
||||
keys=(
|
||||
"stream_id",
|
||||
"instance_name",
|
||||
"cache_func",
|
||||
"keys",
|
||||
"invalidation_ts",
|
||||
),
|
||||
values=[
|
||||
# We convert key_tuples to a list here because psycopg2 serialises
|
||||
# lists as pq arrrays, but serialises tuples as "composite types".
|
||||
# (We need an array because the `keys` column has type `[]text`.)
|
||||
# See:
|
||||
# https://www.psycopg.org/docs/usage.html#adapt-list
|
||||
# https://www.psycopg.org/docs/usage.html#adapt-tuple
|
||||
(stream_id, self._instance_name, cache_name, list(key_tuple), ts)
|
||||
for stream_id, key_tuple in zip(stream_ids, key_tuples)
|
||||
],
|
||||
)
|
||||
|
||||
def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token_for_writer(instance_name)
|
||||
|
|
|
@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
A dict containing the device information, or `None` if the device does not
|
||||
exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_device",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
async def get_device_opt(
|
||||
self, user_id: str, device_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a device. Only returns devices that are not marked as
|
||||
hidden.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user which owns the device
|
||||
device_id: The ID of the device to retrieve
|
||||
Returns:
|
||||
A dict containing the device information, or None if the device does not exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_device",
|
||||
allow_none=True,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
|
||||
|
||||
async def get_devices_by_user(
|
||||
self, user_id: str
|
||||
|
@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
retcols=["device_id", "device_data"],
|
||||
allow_none=True,
|
||||
)
|
||||
return (
|
||||
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
|
||||
)
|
||||
return (row[0], json_decoder.decode(row[1])) if row else None
|
||||
|
||||
def _store_dehydrated_device_txn(
|
||||
self,
|
||||
|
@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
`FALSE` have not been converted.
|
||||
"""
|
||||
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="device_lists_changes_converted_stream_position",
|
||||
keyvalues={},
|
||||
retcols=["stream_id", "room_id"],
|
||||
desc="get_device_change_last_converted_pos",
|
||||
return cast(
|
||||
Tuple[int, str],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="device_lists_changes_converted_stream_position",
|
||||
keyvalues={},
|
||||
retcols=["stream_id", "room_id"],
|
||||
desc="get_device_change_last_converted_pos",
|
||||
),
|
||||
)
|
||||
return row["stream_id"], row["room_id"]
|
||||
|
||||
async def set_device_change_last_converted_pos(
|
||||
self,
|
||||
|
|
|
@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
|
|||
# it isn't there.
|
||||
raise StoreError(404, "No backup with that version exists")
|
||||
|
||||
result = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="e2e_room_keys_versions",
|
||||
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
|
||||
retcols=("version", "algorithm", "auth_data", "etag"),
|
||||
allow_none=False,
|
||||
row = cast(
|
||||
Tuple[int, str, str, Optional[int]],
|
||||
self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="e2e_room_keys_versions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"version": this_version,
|
||||
"deleted": 0,
|
||||
},
|
||||
retcols=("version", "algorithm", "auth_data", "etag"),
|
||||
allow_none=False,
|
||||
),
|
||||
)
|
||||
assert result is not None # see comment on `simple_select_one_txn`
|
||||
result["auth_data"] = db_to_json(result["auth_data"])
|
||||
result["version"] = str(result["version"])
|
||||
if result["etag"] is None:
|
||||
result["etag"] = 0
|
||||
return result
|
||||
return {
|
||||
"auth_data": db_to_json(row[2]),
|
||||
"version": str(row[0]),
|
||||
"algorithm": row[1],
|
||||
"etag": 0 if row[3] is None else row[3],
|
||||
}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
|
||||
|
|
|
@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
|
||||
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
|
||||
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
|
||||
|
||||
if (user_id, device_id) in seen_user_device:
|
||||
continue
|
||||
seen_user_device.add((user_id, device_id))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn, self.get_e2e_unused_fallback_key_types, seen_user_device
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
@ -1268,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
if row is None:
|
||||
continue
|
||||
|
||||
key_id = row["key_id"]
|
||||
key_json = row["key_json"]
|
||||
used = row["used"]
|
||||
key_id, key_json, used = row
|
||||
|
||||
# Mark fallback key as used if not already.
|
||||
if not used and mark_as_used:
|
||||
|
@ -1376,14 +1372,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
|
||||
)
|
||||
|
||||
seen_user_device: Set[Tuple[str, str]] = set()
|
||||
for user_id, device_id, _, _, _ in otk_rows:
|
||||
if (user_id, device_id) in seen_user_device:
|
||||
continue
|
||||
seen_user_device.add((user_id, device_id))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
seen_user_device = {
|
||||
(user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
|
||||
}
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self.count_e2e_one_time_keys,
|
||||
seen_user_device,
|
||||
)
|
||||
|
||||
return otk_rows
|
||||
|
||||
|
|
|
@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||
if room["has_auth_chain_index"]:
|
||||
# If the room has an auth chain index.
|
||||
if room[1]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_ids_chains",
|
||||
|
@ -410,7 +411,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||
if room["has_auth_chain_index"]:
|
||||
# If the room has an auth chain index.
|
||||
if room[1]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_difference_chains",
|
||||
|
@ -1436,24 +1438,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
if event_lookup_result is not None:
|
||||
event_type, depth, stream_ordering = event_lookup_result
|
||||
logger.debug(
|
||||
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
|
||||
room_id,
|
||||
seed_event_id,
|
||||
event_lookup_result["depth"],
|
||||
event_lookup_result["stream_ordering"],
|
||||
event_lookup_result["type"],
|
||||
depth,
|
||||
stream_ordering,
|
||||
event_type,
|
||||
)
|
||||
|
||||
if event_lookup_result["depth"]:
|
||||
queue.put(
|
||||
(
|
||||
-event_lookup_result["depth"],
|
||||
-event_lookup_result["stream_ordering"],
|
||||
seed_event_id,
|
||||
event_lookup_result["type"],
|
||||
)
|
||||
)
|
||||
if depth:
|
||||
queue.put((-depth, -stream_ordering, seed_event_id, event_type))
|
||||
|
||||
while not queue.empty() and len(event_id_results) < limit:
|
||||
try:
|
||||
|
|
|
@ -1934,8 +1934,7 @@ class PersistEventsStore:
|
|||
if row is None:
|
||||
return
|
||||
|
||||
redacted_relates_to = row["relates_to_id"]
|
||||
rel_type = row["relation_type"]
|
||||
redacted_relates_to, rel_type = row
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
|
||||
)
|
||||
|
|
|
@ -1222,14 +1222,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
# Iterate the parent IDs and invalidate caches.
|
||||
for parent_id in {r[1] for r in relations_to_insert}:
|
||||
cache_tuple = (parent_id,)
|
||||
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
|
||||
txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
|
||||
)
|
||||
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
|
||||
txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
|
||||
)
|
||||
cache_tuples = {(r[1],) for r in relations_to_insert}
|
||||
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
|
||||
txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined]
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
|
||||
txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if results:
|
||||
latest_event_id = results[-1][0]
|
||||
|
|
|
@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
if not res:
|
||||
raise SynapseError(404, "Could not find event %s" % (event_id,))
|
||||
|
||||
return int(res["topological_ordering"]), int(res["stream_ordering"])
|
||||
return int(res[0]), int(res[1])
|
||||
|
||||
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
|
||||
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
|
||||
|
|
|
@ -107,13 +107,16 @@ class KeyStore(CacheInvalidationWorkerStore):
|
|||
# invalidate takes a tuple corresponding to the params of
|
||||
# _get_server_keys_json. _get_server_keys_json only takes one
|
||||
# param, which is itself the 2-tuple (server_name, key_id).
|
||||
for key_id in verify_keys:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self._get_server_keys_json, ((server_name, key_id),)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_server_key_json_for_remote, (server_name, key_id)
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self._get_server_keys_json,
|
||||
[((server_name, key_id),) for key_id in verify_keys],
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self.get_server_key_json_for_remote,
|
||||
[(server_name, key_id) for key_id in verify_keys],
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"store_server_keys_response", store_server_keys_response_txn
|
||||
|
|
|
@ -15,9 +15,7 @@
|
|||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
|
@ -54,11 +52,32 @@ class LocalMedia:
|
|||
media_length: int
|
||||
upload_name: str
|
||||
created_ts: int
|
||||
url_cache: Optional[str]
|
||||
last_access_ts: int
|
||||
quarantined_by: Optional[str]
|
||||
safe_from_quarantine: bool
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class RemoteMedia:
|
||||
media_origin: str
|
||||
media_id: str
|
||||
media_type: str
|
||||
media_length: int
|
||||
upload_name: Optional[str]
|
||||
filesystem_id: str
|
||||
created_ts: int
|
||||
last_access_ts: int
|
||||
quarantined_by: Optional[str]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class UrlCache:
|
||||
response_code: int
|
||||
expires_ts: int
|
||||
og: Union[str, bytes]
|
||||
|
||||
|
||||
class MediaSortOrder(Enum):
|
||||
"""
|
||||
Enum to define the sorting method used when returning media with
|
||||
|
@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
super().__init__(database, db_conn, hs)
|
||||
self.server_name: str = hs.hostname
|
||||
|
||||
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
|
||||
"""Get the metadata for a local piece of media
|
||||
|
||||
Returns:
|
||||
None if the media_id doesn't exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
row = await self.db_pool.simple_select_one(
|
||||
"local_media_repository",
|
||||
{"media_id": media_id},
|
||||
(
|
||||
|
@ -181,11 +200,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"created_ts",
|
||||
"quarantined_by",
|
||||
"url_cache",
|
||||
"last_access_ts",
|
||||
"safe_from_quarantine",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_local_media",
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return LocalMedia(
|
||||
media_id=media_id,
|
||||
media_type=row[0],
|
||||
media_length=row[1],
|
||||
upload_name=row[2],
|
||||
created_ts=row[3],
|
||||
quarantined_by=row[4],
|
||||
url_cache=row[5],
|
||||
last_access_ts=row[6],
|
||||
safe_from_quarantine=row[7],
|
||||
)
|
||||
|
||||
async def get_local_media_by_user_paginate(
|
||||
self,
|
||||
|
@ -236,6 +269,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
media_length,
|
||||
upload_name,
|
||||
created_ts,
|
||||
url_cache,
|
||||
last_access_ts,
|
||||
quarantined_by,
|
||||
safe_from_quarantine
|
||||
|
@ -257,9 +291,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
media_length=row[2],
|
||||
upload_name=row[3],
|
||||
created_ts=row[4],
|
||||
last_access_ts=row[5],
|
||||
quarantined_by=row[6],
|
||||
safe_from_quarantine=bool(row[7]),
|
||||
url_cache=row[5],
|
||||
last_access_ts=row[6],
|
||||
quarantined_by=row[7],
|
||||
safe_from_quarantine=bool(row[8]),
|
||||
)
|
||||
for row in txn
|
||||
]
|
||||
|
@ -390,51 +425,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
desc="mark_local_media_as_safe",
|
||||
)
|
||||
|
||||
async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
|
||||
async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
|
||||
"""Get the media_id and ts for a cached URL as of the given timestamp
|
||||
Returns:
|
||||
None if the URL isn't cached.
|
||||
"""
|
||||
|
||||
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
|
||||
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
|
||||
# get the most recently cached result (relative to the given ts)
|
||||
sql = (
|
||||
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
|
||||
" FROM local_media_repository_url_cache"
|
||||
" WHERE url = ? AND download_ts <= ?"
|
||||
" ORDER BY download_ts DESC LIMIT 1"
|
||||
)
|
||||
sql = """
|
||||
SELECT response_code, expires_ts, og
|
||||
FROM local_media_repository_url_cache
|
||||
WHERE url = ? AND download_ts <= ?
|
||||
ORDER BY download_ts DESC LIMIT 1
|
||||
"""
|
||||
txn.execute(sql, (url, ts))
|
||||
row = txn.fetchone()
|
||||
|
||||
if not row:
|
||||
# ...or if we've requested a timestamp older than the oldest
|
||||
# copy in the cache, return the oldest copy (if any)
|
||||
sql = (
|
||||
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
|
||||
" FROM local_media_repository_url_cache"
|
||||
" WHERE url = ? AND download_ts > ?"
|
||||
" ORDER BY download_ts ASC LIMIT 1"
|
||||
)
|
||||
sql = """
|
||||
SELECT response_code, expires_ts, og
|
||||
FROM local_media_repository_url_cache
|
||||
WHERE url = ? AND download_ts > ?
|
||||
ORDER BY download_ts ASC LIMIT 1
|
||||
"""
|
||||
txn.execute(sql, (url, ts))
|
||||
row = txn.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return dict(
|
||||
zip(
|
||||
(
|
||||
"response_code",
|
||||
"etag",
|
||||
"expires_ts",
|
||||
"og",
|
||||
"media_id",
|
||||
"download_ts",
|
||||
),
|
||||
row,
|
||||
)
|
||||
)
|
||||
return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
|
||||
|
||||
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
|
||||
|
||||
|
@ -444,7 +467,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
response_code: int,
|
||||
etag: Optional[str],
|
||||
expires_ts: int,
|
||||
og: Optional[str],
|
||||
og: str,
|
||||
media_id: str,
|
||||
download_ts: int,
|
||||
) -> None:
|
||||
|
@ -510,8 +533,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
async def get_cached_remote_media(
|
||||
self, origin: str, media_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
) -> Optional[RemoteMedia]:
|
||||
row = await self.db_pool.simple_select_one(
|
||||
"remote_media_cache",
|
||||
{"media_origin": origin, "media_id": media_id},
|
||||
(
|
||||
|
@ -520,11 +543,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"upload_name",
|
||||
"created_ts",
|
||||
"filesystem_id",
|
||||
"last_access_ts",
|
||||
"quarantined_by",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_cached_remote_media",
|
||||
)
|
||||
if row is None:
|
||||
return row
|
||||
return RemoteMedia(
|
||||
media_origin=origin,
|
||||
media_id=media_id,
|
||||
media_type=row[0],
|
||||
media_length=row[1],
|
||||
upload_name=row[2],
|
||||
created_ts=row[3],
|
||||
filesystem_id=row[4],
|
||||
last_access_ts=row[5],
|
||||
quarantined_by=row[6],
|
||||
)
|
||||
|
||||
async def store_cached_remote_media(
|
||||
self,
|
||||
|
@ -623,10 +660,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
t_width: int,
|
||||
t_height: int,
|
||||
t_type: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[ThumbnailInfo]:
|
||||
"""Fetch the thumbnail info of given width, height and type."""
|
||||
|
||||
return await self.db_pool.simple_select_one(
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="remote_media_cache_thumbnails",
|
||||
keyvalues={
|
||||
"media_origin": origin,
|
||||
|
@ -641,11 +678,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"thumbnail_method",
|
||||
"thumbnail_type",
|
||||
"thumbnail_length",
|
||||
"filesystem_id",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_remote_media_thumbnail",
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return ThumbnailInfo(
|
||||
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
|
||||
)
|
||||
|
||||
@trace
|
||||
async def store_remote_media_thumbnail(
|
||||
|
|
|
@ -363,10 +363,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
|||
# for their user ID.
|
||||
value_values=[(presence_stream_id,) for _ in user_ids],
|
||||
)
|
||||
for user_id in user_ids:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self._get_full_presence_stream_token_for_user, (user_id,)
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self._get_full_presence_stream_token_for_user,
|
||||
[(user_id,) for user_id in user_ids],
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
|
@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
return 50
|
||||
|
||||
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
|
||||
try:
|
||||
profile = await self.db_pool.simple_select_one(
|
||||
table="profiles",
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
desc="get_profileinfo",
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
# no match
|
||||
return ProfileInfo(None, None)
|
||||
else:
|
||||
raise
|
||||
|
||||
return ProfileInfo(
|
||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||
profile = await self.db_pool.simple_select_one(
|
||||
table="profiles",
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
desc="get_profileinfo",
|
||||
allow_none=True,
|
||||
)
|
||||
if profile is None:
|
||||
# no match
|
||||
return ProfileInfo(None, None)
|
||||
|
||||
return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
|
||||
|
||||
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
|
|
|
@ -295,19 +295,28 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
|||
# so make sure to keep this actually last.
|
||||
txn.execute("DROP TABLE events_to_purge")
|
||||
|
||||
for event_id, should_delete in event_rows:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self._get_state_group_for_event, (event_id,)
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self._get_state_group_for_event,
|
||||
[(event_id,) for event_id, _ in event_rows],
|
||||
)
|
||||
|
||||
# XXX: This is racy, since have_seen_events could be called between the
|
||||
# transaction completing and the invalidation running. On the other hand,
|
||||
# that's no different to calling `have_seen_events` just before the
|
||||
# event is deleted from the database.
|
||||
# XXX: This is racy, since have_seen_events could be called between the
|
||||
# transaction completing and the invalidation running. On the other hand,
|
||||
# that's no different to calling `have_seen_events` just before the
|
||||
# event is deleted from the database.
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self.have_seen_event,
|
||||
[
|
||||
(room_id, event_id)
|
||||
for event_id, should_delete in event_rows
|
||||
if should_delete
|
||||
],
|
||||
)
|
||||
|
||||
for event_id, should_delete in event_rows:
|
||||
if should_delete:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.have_seen_event, (room_id, event_id)
|
||||
)
|
||||
self.invalidate_get_event_cache_after_txn(txn, event_id)
|
||||
|
||||
logger.info("[purge] done")
|
||||
|
|
|
@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
"before/after rule not found: %s" % (relative_to_rule,)
|
||||
)
|
||||
|
||||
base_priority_class = res["priority_class"]
|
||||
base_rule_priority = res["priority"]
|
||||
base_priority_class, base_rule_priority = res
|
||||
|
||||
if base_priority_class != priority_class:
|
||||
raise InconsistentRuleException(
|
||||
|
|
|
@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
allow_none=True,
|
||||
)
|
||||
|
||||
stream_ordering = int(res["stream_ordering"]) if res else None
|
||||
rx_ts = res["received_ts"] if res else 0
|
||||
stream_ordering = int(res[0]) if res else None
|
||||
rx_ts = res[1] if res else 0
|
||||
|
||||
# We don't want to clobber receipts for more recent events, so we
|
||||
# have to compare orderings of existing receipts
|
||||
|
|
|
@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
account timestamp as milliseconds since the epoch. None if the account
|
||||
has not been renewed using the current token yet.
|
||||
"""
|
||||
ret_dict = await self.db_pool.simple_select_one(
|
||||
table="account_validity",
|
||||
keyvalues={"renewal_token": renewal_token},
|
||||
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
|
||||
desc="get_user_from_renewal_token",
|
||||
)
|
||||
|
||||
return (
|
||||
ret_dict["user_id"],
|
||||
ret_dict["expiration_ts_ms"],
|
||||
ret_dict["token_used_ts_ms"],
|
||||
return cast(
|
||||
Tuple[str, int, Optional[int]],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="account_validity",
|
||||
keyvalues={"renewal_token": renewal_token},
|
||||
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
|
||||
desc="get_user_from_renewal_token",
|
||||
),
|
||||
)
|
||||
|
||||
async def get_renewal_token_for_user(self, user_id: str) -> str:
|
||||
|
@ -564,16 +561,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
updatevalues={"shadow_banned": shadow_banned},
|
||||
)
|
||||
# In order for this to apply immediately, clear the cache for this user.
|
||||
tokens = self.db_pool.simple_select_onecol_txn(
|
||||
tokens = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="access_tokens",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="token",
|
||||
retcols=("token",),
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn, self.get_user_by_access_token, tokens
|
||||
)
|
||||
for token in tokens:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_access_token, (token,)
|
||||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
|
||||
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
|
||||
|
@ -989,16 +985,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
Returns:
|
||||
user id, or None if no user id/threepid mapping exists
|
||||
"""
|
||||
ret = self.db_pool.simple_select_one_txn(
|
||||
return self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
"user_threepids",
|
||||
{"medium": medium, "address": address},
|
||||
["user_id"],
|
||||
"user_id",
|
||||
True,
|
||||
)
|
||||
if ret:
|
||||
return ret["user_id"]
|
||||
return None
|
||||
|
||||
async def user_add_threepid(
|
||||
self,
|
||||
|
@ -1435,16 +1428,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
if res is None:
|
||||
return False
|
||||
|
||||
uses_allowed, pending, completed, expiry_time = res
|
||||
|
||||
# Check if the token has expired
|
||||
now = self._clock.time_msec()
|
||||
if res["expiry_time"] and res["expiry_time"] < now:
|
||||
if expiry_time and expiry_time < now:
|
||||
return False
|
||||
|
||||
# Check if the token has been used up
|
||||
if (
|
||||
res["uses_allowed"]
|
||||
and res["pending"] + res["completed"] >= res["uses_allowed"]
|
||||
):
|
||||
if uses_allowed and pending + completed >= uses_allowed:
|
||||
return False
|
||||
|
||||
# Otherwise, the token is valid
|
||||
|
@ -1490,8 +1482,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
# Override type because the return type is only optional if
|
||||
# allow_none is True, and we don't want mypy throwing errors
|
||||
# about None not being indexable.
|
||||
res = cast(
|
||||
Dict[str, Any],
|
||||
pending, completed = cast(
|
||||
Tuple[int, int],
|
||||
self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
|
@ -1506,8 +1498,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={
|
||||
"completed": res["completed"] + 1,
|
||||
"pending": res["pending"] - 1,
|
||||
"completed": completed + 1,
|
||||
"pending": pending - 1,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -1585,13 +1577,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
Returns:
|
||||
A dict, or None if token doesn't exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
row = await self.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
|
||||
allow_none=True,
|
||||
desc="get_one_registration_token",
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return {
|
||||
"token": row[0],
|
||||
"uses_allowed": row[1],
|
||||
"pending": row[2],
|
||||
"completed": row[3],
|
||||
"expiry_time": row[4],
|
||||
}
|
||||
|
||||
async def generate_registration_token(
|
||||
self, length: int, chars: str
|
||||
|
@ -1714,7 +1715,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
return None
|
||||
|
||||
# Get all info about the token so it can be sent in the response
|
||||
return self.db_pool.simple_select_one_txn(
|
||||
result = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
|
@ -1728,6 +1729,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
allow_none=True,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return result
|
||||
|
||||
return {
|
||||
"token": result[0],
|
||||
"uses_allowed": result[1],
|
||||
"pending": result[2],
|
||||
"completed": result[3],
|
||||
"expiry_time": result[4],
|
||||
}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"update_registration_token", _update_registration_token_txn
|
||||
)
|
||||
|
@ -1939,11 +1951,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
keyvalues={"token": token},
|
||||
updatevalues={"used_ts": ts},
|
||||
)
|
||||
user_id = values["user_id"]
|
||||
expiry_ts = values["expiry_ts"]
|
||||
used_ts = values["used_ts"]
|
||||
auth_provider_id = values["auth_provider_id"]
|
||||
auth_provider_session_id = values["auth_provider_session_id"]
|
||||
(
|
||||
user_id,
|
||||
expiry_ts,
|
||||
used_ts,
|
||||
auth_provider_id,
|
||||
auth_provider_session_id,
|
||||
) = values
|
||||
|
||||
# Token was already used
|
||||
if used_ts is not None:
|
||||
|
@ -2668,10 +2682,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
)
|
||||
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
|
||||
|
||||
for token, _, _ in tokens_and_devices:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_access_token, (token,)
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self.get_user_by_access_token,
|
||||
[(token,) for token, _, _ in tokens_and_devices],
|
||||
)
|
||||
|
||||
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
|
||||
|
||||
|
@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
# reason, the next check is on the client secret, which is NOT NULL,
|
||||
# so we don't have to worry about the client secret matching by
|
||||
# accident.
|
||||
row = {"client_secret": None, "validated_at": None}
|
||||
row = None, None
|
||||
else:
|
||||
raise ThreepidValidationError("Unknown session_id")
|
||||
|
||||
retrieved_client_secret = row["client_secret"]
|
||||
validated_at = row["validated_at"]
|
||||
retrieved_client_secret, validated_at = row
|
||||
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
|
@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
raise ThreepidValidationError(
|
||||
"Validation token not found or has expired"
|
||||
)
|
||||
expires = row["expires"]
|
||||
next_link = row["next_link"]
|
||||
expires, next_link = row
|
||||
|
||||
if retrieved_client_secret != client_secret:
|
||||
raise ThreepidValidationError(
|
||||
|
|
|
@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||
raise StoreError(500, "Problem creating room.")
|
||||
|
||||
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
|
||||
"""Retrieve a room.
|
||||
|
||||
Args:
|
||||
room_id: The ID of the room to retrieve.
|
||||
Returns:
|
||||
A dict containing the room information, or None if the room is unknown.
|
||||
A tuple containing the room information:
|
||||
* True if the room is public
|
||||
* True if the room has an auth chain index
|
||||
|
||||
or None if the room is unknown.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
|
||||
desc="get_room",
|
||||
allow_none=True,
|
||||
row = cast(
|
||||
Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("is_public", "has_auth_chain_index"),
|
||||
desc="get_room",
|
||||
allow_none=True,
|
||||
),
|
||||
)
|
||||
if row is None:
|
||||
return row
|
||||
return bool(row[0]), bool(row[1])
|
||||
|
||||
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
|
||||
"""Retrieve room with statistics.
|
||||
|
@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
|
||||
if row:
|
||||
return RatelimitOverride(
|
||||
messages_per_second=row["messages_per_second"],
|
||||
burst_count=row["burst_count"],
|
||||
)
|
||||
return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
join.
|
||||
"""
|
||||
|
||||
result = await self.db_pool.simple_select_one(
|
||||
table="partial_state_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("join_event_id", "device_lists_stream_id"),
|
||||
desc="get_join_event_id_for_partial_state",
|
||||
return cast(
|
||||
Tuple[str, int],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="partial_state_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("join_event_id", "device_lists_stream_id"),
|
||||
desc="get_join_event_id_for_partial_state",
|
||||
),
|
||||
)
|
||||
return result["join_event_id"], result["device_lists_stream_id"]
|
||||
|
||||
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
|
||||
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
|
||||
|
|
|
@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
"non-local user %s" % (user_id,),
|
||||
)
|
||||
|
||||
results_dict = await self.db_pool.simple_select_one(
|
||||
"local_current_membership",
|
||||
{"room_id": room_id, "user_id": user_id},
|
||||
("membership", "event_id"),
|
||||
allow_none=True,
|
||||
desc="get_local_current_membership_for_user_in_room",
|
||||
results = cast(
|
||||
Optional[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_one(
|
||||
"local_current_membership",
|
||||
{"room_id": room_id, "user_id": user_id},
|
||||
("membership", "event_id"),
|
||||
allow_none=True,
|
||||
desc="get_local_current_membership_for_user_in_room",
|
||||
),
|
||||
)
|
||||
if not results_dict:
|
||||
if not results:
|
||||
return None, None
|
||||
|
||||
return results_dict.get("membership"), results_dict.get("event_id")
|
||||
return results
|
||||
|
||||
@cached(max_entries=500000, iterable=True)
|
||||
async def get_rooms_for_user_with_stream_ordering(
|
||||
|
|
|
@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
desc="get_position_for_event",
|
||||
)
|
||||
|
||||
return PersistedEventPosition(
|
||||
row["instance_name"] or "master", row["stream_ordering"]
|
||||
)
|
||||
return PersistedEventPosition(row[1] or "master", row[0])
|
||||
|
||||
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
|
||||
"""The stream token for an event
|
||||
|
@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
retcols=("stream_ordering", "topological_ordering"),
|
||||
desc="get_topological_token_for_event",
|
||||
)
|
||||
return RoomStreamToken(
|
||||
topological=row["topological_ordering"], stream=row["stream_ordering"]
|
||||
)
|
||||
return RoomStreamToken(topological=row[1], stream=row[0])
|
||||
|
||||
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
|
||||
"""Gets the topological token in a room after or at the given stream
|
||||
|
@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
dict
|
||||
"""
|
||||
|
||||
results = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"events",
|
||||
keyvalues={"event_id": event_id, "room_id": room_id},
|
||||
retcols=["stream_ordering", "topological_ordering"],
|
||||
stream_ordering, topological_ordering = cast(
|
||||
Tuple[int, int],
|
||||
self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"events",
|
||||
keyvalues={"event_id": event_id, "room_id": room_id},
|
||||
retcols=["stream_ordering", "topological_ordering"],
|
||||
),
|
||||
)
|
||||
|
||||
# This cannot happen as `allow_none=False`.
|
||||
assert results is not None
|
||||
|
||||
# Paginating backwards includes the event at the token, but paginating
|
||||
# forward doesn't.
|
||||
before_token = RoomStreamToken(
|
||||
topological=results["topological_ordering"] - 1,
|
||||
stream=results["stream_ordering"],
|
||||
topological=topological_ordering - 1, stream=stream_ordering
|
||||
)
|
||||
|
||||
after_token = RoomStreamToken(
|
||||
topological=results["topological_ordering"],
|
||||
stream=results["stream_ordering"],
|
||||
topological=topological_ordering, stream=stream_ordering
|
||||
)
|
||||
|
||||
rows, start_token = self._paginate_room_events_txn(
|
||||
|
|
|
@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
|||
|
||||
Returns: the task if available, `None` otherwise
|
||||
"""
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="scheduled_tasks",
|
||||
keyvalues={"id": id},
|
||||
retcols=(
|
||||
"id",
|
||||
"action",
|
||||
"status",
|
||||
"timestamp",
|
||||
"resource_id",
|
||||
"params",
|
||||
"result",
|
||||
"error",
|
||||
row = cast(
|
||||
Optional[ScheduledTaskRow],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="scheduled_tasks",
|
||||
keyvalues={"id": id},
|
||||
retcols=(
|
||||
"id",
|
||||
"action",
|
||||
"status",
|
||||
"timestamp",
|
||||
"resource_id",
|
||||
"params",
|
||||
"result",
|
||||
"error",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_scheduled_task",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_scheduled_task",
|
||||
)
|
||||
|
||||
return (
|
||||
TaskSchedulerWorkerStore._convert_row_to_task(
|
||||
(
|
||||
row["id"],
|
||||
row["action"],
|
||||
row["status"],
|
||||
row["timestamp"],
|
||||
row["resource_id"],
|
||||
row["params"],
|
||||
row["result"],
|
||||
row["error"],
|
||||
)
|
||||
)
|
||||
if row
|
||||
else None
|
||||
)
|
||||
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
|
||||
|
||||
async def delete_scheduled_task(self, id: str) -> None:
|
||||
"""Delete a specific task from its id.
|
||||
|
|
|
@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
|||
txn,
|
||||
table="received_transactions",
|
||||
keyvalues={"transaction_id": transaction_id, "origin": origin},
|
||||
retcols=(
|
||||
"transaction_id",
|
||||
"origin",
|
||||
"ts",
|
||||
"response_code",
|
||||
"response_json",
|
||||
"has_been_referenced",
|
||||
),
|
||||
retcols=("response_code", "response_json"),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if result and result["response_code"]:
|
||||
return result["response_code"], db_to_json(result["response_json"])
|
||||
# If the result exists and the response code is non-0.
|
||||
if result and result[0]:
|
||||
return result[0], db_to_json(result[1])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
# check we have a row and retry_last_ts is not null or zero
|
||||
# (retry_last_ts can't be negative)
|
||||
if result and result["retry_last_ts"]:
|
||||
return DestinationRetryTimings(**result)
|
||||
if result and result[1]:
|
||||
return DestinationRetryTimings(
|
||||
failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
|
@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
desc="get_ui_auth_session",
|
||||
)
|
||||
|
||||
result["clientdict"] = db_to_json(result["clientdict"])
|
||||
|
||||
return UIAuthSessionData(session_id, **result)
|
||||
return UIAuthSessionData(
|
||||
session_id,
|
||||
clientdict=db_to_json(result[0]),
|
||||
uri=result[1],
|
||||
method=result[2],
|
||||
description=result[3],
|
||||
)
|
||||
|
||||
async def mark_ui_auth_stage_complete(
|
||||
self,
|
||||
|
@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
|
||||
) -> None:
|
||||
# Get the current value.
|
||||
result = cast(
|
||||
Dict[str, Any],
|
||||
self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
),
|
||||
result = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcol="serverdict",
|
||||
)
|
||||
|
||||
# Update it and add it back to the database.
|
||||
serverdict = db_to_json(result["serverdict"])
|
||||
serverdict = db_to_json(result)
|
||||
serverdict[key] = value
|
||||
|
||||
self.db_pool.simple_update_one_txn(
|
||||
|
@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
Raises:
|
||||
StoreError if the session cannot be found.
|
||||
"""
|
||||
result = await self.db_pool.simple_select_one(
|
||||
result = await self.db_pool.simple_select_one_onecol(
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
retcol="serverdict",
|
||||
desc="get_ui_auth_session_data",
|
||||
)
|
||||
|
||||
serverdict = db_to_json(result["serverdict"])
|
||||
serverdict = db_to_json(result)
|
||||
|
||||
return serverdict.get(key, default)
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ from typing import (
|
|||
Collection,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
|
@ -868,13 +867,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||
)
|
||||
|
||||
async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="user_directory",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("display_name", "avatar_url"),
|
||||
allow_none=True,
|
||||
desc="get_user_in_directory",
|
||||
async def _get_user_in_directory(
|
||||
self, user_id: str
|
||||
) -> Optional[Tuple[Optional[str], Optional[str]]]:
|
||||
"""
|
||||
Fetch the user information in the user directory.
|
||||
|
||||
Returns:
|
||||
None if the user is unknown, otherwise a tuple of display name and
|
||||
avatar URL (both of which may be None).
|
||||
"""
|
||||
return cast(
|
||||
Optional[Tuple[Optional[str], Optional[str]]],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="user_directory",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("display_name", "avatar_url"),
|
||||
allow_none=True,
|
||||
desc="get_user_in_directory",
|
||||
),
|
||||
)
|
||||
|
||||
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
|
||||
|
|
|
@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
|
||||
next_id = self._load_next_id_txn(txn)
|
||||
|
||||
txn.call_after(self._mark_id_as_finished, next_id)
|
||||
txn.call_on_exception(self._mark_id_as_finished, next_id)
|
||||
txn.call_after(self._mark_ids_as_finished, [next_id])
|
||||
txn.call_on_exception(self._mark_ids_as_finished, [next_id])
|
||||
txn.call_after(self._notifier.notify_replication)
|
||||
|
||||
# Update the `stream_positions` table with newly updated stream
|
||||
|
@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
|
||||
return self._return_factor * next_id
|
||||
|
||||
def _mark_id_as_finished(self, next_id: int) -> None:
|
||||
"""The ID has finished being processed so we should advance the
|
||||
def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
|
||||
"""
|
||||
Usage:
|
||||
|
||||
stream_id = stream_id_gen.get_next_txn(txn)
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
# If we have a list of instances that are allowed to write to this
|
||||
# stream, make sure we're in it.
|
||||
if self._writers and self._instance_name not in self._writers:
|
||||
raise Exception("Tried to allocate stream ID on non-writer")
|
||||
|
||||
next_ids = self._load_next_mult_id_txn(txn, n)
|
||||
|
||||
txn.call_after(self._mark_ids_as_finished, next_ids)
|
||||
txn.call_on_exception(self._mark_ids_as_finished, next_ids)
|
||||
txn.call_after(self._notifier.notify_replication)
|
||||
|
||||
# Update the `stream_positions` table with newly updated stream
|
||||
# ID (unless self._writers is not set in which case we don't
|
||||
# bother, as nothing will read it).
|
||||
#
|
||||
# We only do this on the success path so that the persisted current
|
||||
# position points to a persisted row with the correct instance name.
|
||||
if self._writers:
|
||||
txn.call_after(
|
||||
run_as_background_process,
|
||||
"MultiWriterIdGenerator._update_table",
|
||||
self._db.runInteraction,
|
||||
"MultiWriterIdGenerator._update_table",
|
||||
self._update_stream_positions_table_txn,
|
||||
)
|
||||
|
||||
return [self._return_factor * next_id for next_id in next_ids]
|
||||
|
||||
def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
|
||||
"""These IDs have finished being processed so we should advance the
|
||||
current position if possible.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.discard(next_id)
|
||||
self._finished_ids.add(next_id)
|
||||
self._unfinished_ids.difference_update(next_ids)
|
||||
self._finished_ids.update(next_ids)
|
||||
|
||||
new_cur: Optional[int] = None
|
||||
|
||||
|
@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
curr, new_cur, self._max_position_of_local_instance
|
||||
)
|
||||
|
||||
self._add_persisted_position(next_id)
|
||||
# TODO Can we call this for just the last position or somehow batch
|
||||
# _add_persisted_position.
|
||||
for next_id in next_ids:
|
||||
self._add_persisted_position(next_id)
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
return self.get_persisted_upto_position()
|
||||
|
@ -933,8 +972,7 @@ class _MultiWriterCtxManager:
|
|||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
for i in self.stream_ids:
|
||||
self.id_gen._mark_id_as_finished(i)
|
||||
self.id_gen._mark_ids_as_finished(self.stream_ids)
|
||||
|
||||
self.notifier.notify_replication()
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
|
||||
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
|
||||
|
||||
return self.get_success(
|
||||
row = self.get_success(
|
||||
self.store.db_pool.simple_select_one(
|
||||
table + "_current",
|
||||
{id_col: stat_id},
|
||||
|
@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
return None if row is None else dict(zip(cols, row))
|
||||
|
||||
def _perform_background_initial_update(self) -> None:
|
||||
# Do the initial population of the stats via the background update
|
||||
self._add_background_updates()
|
||||
|
|
|
@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
|
||||
assert profile is not None
|
||||
self.assertTrue(profile["display_name"] == display_name)
|
||||
self.assertTrue(profile[0] == display_name)
|
||||
|
||||
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
|
||||
# create user
|
||||
|
@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||
# profile is in directory
|
||||
profile = self.get_success(self.store._get_user_in_directory(r_user_id))
|
||||
assert profile is not None
|
||||
self.assertTrue(profile["display_name"] == display_name)
|
||||
self.assertEqual(profile[0], display_name)
|
||||
|
||||
# deactivate user
|
||||
self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
|
||||
|
|
|
@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
origin, media_id = self.media_id.split("/")
|
||||
info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
|
||||
assert info is not None
|
||||
file_id = info["filesystem_id"]
|
||||
file_id = info.filesystem_id
|
||||
|
||||
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
|
||||
origin, file_id
|
||||
|
|
|
@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
|
|||
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||
assert media_info is not None
|
||||
self.assertFalse(media_info["quarantined_by"])
|
||||
self.assertFalse(media_info.quarantined_by)
|
||||
|
||||
# quarantining
|
||||
channel = self.make_request(
|
||||
|
@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
|
|||
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||
assert media_info is not None
|
||||
self.assertTrue(media_info["quarantined_by"])
|
||||
self.assertTrue(media_info.quarantined_by)
|
||||
|
||||
# remove from quarantine
|
||||
channel = self.make_request(
|
||||
|
@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
|
|||
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||
assert media_info is not None
|
||||
self.assertFalse(media_info["quarantined_by"])
|
||||
self.assertFalse(media_info.quarantined_by)
|
||||
|
||||
def test_quarantine_protected_media(self) -> None:
|
||||
"""
|
||||
|
@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
|
|||
# verify protection
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||
assert media_info is not None
|
||||
self.assertTrue(media_info["safe_from_quarantine"])
|
||||
self.assertTrue(media_info.safe_from_quarantine)
|
||||
|
||||
# quarantining
|
||||
channel = self.make_request(
|
||||
|
@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
|
|||
# verify that is not in quarantine
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||
assert media_info is not None
|
||||
self.assertFalse(media_info["quarantined_by"])
|
||||
self.assertFalse(media_info.quarantined_by)
|
||||
|
||||
|
||||
class ProtectMediaByIDTestCase(_AdminMediaTests):
|
||||
|
@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
|
|||
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||
assert media_info is not None
|
||||
self.assertFalse(media_info["safe_from_quarantine"])
|
||||
self.assertFalse(media_info.safe_from_quarantine)
|
||||
|
||||
# protect
|
||||
channel = self.make_request(
|
||||
|
@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
|
|||
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||
assert media_info is not None
|
||||
self.assertTrue(media_info["safe_from_quarantine"])
|
||||
self.assertTrue(media_info.safe_from_quarantine)
|
||||
|
||||
# unprotect
|
||||
channel = self.make_request(
|
||||
|
@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
|
|||
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||
assert media_info is not None
|
||||
self.assertFalse(media_info["safe_from_quarantine"])
|
||||
self.assertFalse(media_info.safe_from_quarantine)
|
||||
|
||||
|
||||
class PurgeMediaCacheTestCase(_AdminMediaTests):
|
||||
|
|
|
@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
# is in user directory
|
||||
profile = self.get_success(self.store._get_user_in_directory(self.other_user))
|
||||
assert profile is not None
|
||||
self.assertTrue(profile["display_name"] == "User")
|
||||
self.assertEqual(profile[0], "User")
|
||||
|
||||
# Deactivate user
|
||||
channel = self.make_request(
|
||||
|
|
|
@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||
#
|
||||
# Note that we don't have the UI Auth session ID, so just pull out the single
|
||||
# row.
|
||||
ui_auth_data = self.get_success(
|
||||
self.store.db_pool.simple_select_one(
|
||||
"ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
|
||||
result = self.get_success(
|
||||
self.store.db_pool.simple_select_one_onecol(
|
||||
"ui_auth_sessions", keyvalues={}, retcol="clientdict"
|
||||
)
|
||||
)
|
||||
client_dict = db_to_json(ui_auth_data["clientdict"])
|
||||
client_dict = db_to_json(result)
|
||||
self.assertNotIn("new_password", client_dict)
|
||||
|
||||
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
||||
|
|
|
@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
self.assertLessEqual(det_data.items(), channel.json_body.items())
|
||||
|
||||
# Check the `completed` counter has been incremented and pending is 0
|
||||
res = self.get_success(
|
||||
pending, completed = self.get_success(
|
||||
store.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
)
|
||||
)
|
||||
self.assertEqual(res["completed"], 1)
|
||||
self.assertEqual(res["pending"], 0)
|
||||
self.assertEqual(completed, 1)
|
||||
self.assertEqual(pending, 0)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_invalid(self) -> None:
|
||||
|
@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
params1["auth"]["type"] = LoginType.DUMMY
|
||||
self.make_request(b"POST", self.url, params1)
|
||||
# Check pending=0 and completed=1
|
||||
res = self.get_success(
|
||||
pending, completed = self.get_success(
|
||||
store.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
)
|
||||
)
|
||||
self.assertEqual(res["pending"], 0)
|
||||
self.assertEqual(res["completed"], 1)
|
||||
self.assertEqual(pending, 0)
|
||||
self.assertEqual(completed, 1)
|
||||
|
||||
# Check auth still fails when using token with session2
|
||||
channel = self.make_request(b"POST", self.url, params2)
|
||||
|
|
|
@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
|
|||
def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
|
||||
"""Given an MXC URI, assert whether it has been purged or not."""
|
||||
if mxc_uri.server_name == self.hs.config.server.server_name:
|
||||
found_media_dict = self.get_success(
|
||||
self.store.get_local_media(mxc_uri.media_id)
|
||||
found_media = bool(
|
||||
self.get_success(self.store.get_local_media(mxc_uri.media_id))
|
||||
)
|
||||
else:
|
||||
found_media_dict = self.get_success(
|
||||
self.store.get_cached_remote_media(
|
||||
mxc_uri.server_name, mxc_uri.media_id
|
||||
found_media = bool(
|
||||
self.get_success(
|
||||
self.store.get_cached_remote_media(
|
||||
mxc_uri.server_name, mxc_uri.media_id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if expect_purged:
|
||||
self.assertIsNone(
|
||||
found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
|
||||
)
|
||||
self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
|
||||
else:
|
||||
self.assertIsNotNone(
|
||||
found_media_dict,
|
||||
self.assertTrue(
|
||||
found_media,
|
||||
msg=f"{mxc_uri} unexpectedly purged",
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest.mock import Mock, call
|
||||
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class CacheInvalidationTestCase(HomeserverTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
def test_bulk_invalidation(self) -> None:
|
||||
master_invalidate = Mock()
|
||||
|
||||
self.store._get_cached_user_device.invalidate = master_invalidate
|
||||
|
||||
keys_to_invalidate = [
|
||||
("a", "b"),
|
||||
("c", "d"),
|
||||
("e", "f"),
|
||||
("g", "h"),
|
||||
]
|
||||
|
||||
def test_txn(txn: LoggingTransaction) -> None:
|
||||
self.store._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
# This is an arbitrarily chosen cached store function. It was chosen
|
||||
# because it takes more than one argument. We'll use this later to
|
||||
# check that the invalidation was actioned over replication.
|
||||
cache_func=self.store._get_cached_user_device,
|
||||
key_tuples=keys_to_invalidate,
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
self.store.db_pool.runInteraction(
|
||||
"test_invalidate_cache_and_stream_bulk", test_txn
|
||||
)
|
||||
)
|
||||
|
||||
master_invalidate.assert_has_calls(
|
||||
[call(key_list) for key_list in keys_to_invalidate],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
|
||||
class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
def test_bulk_invalidation_replicates(self) -> None:
|
||||
"""Like test_bulk_invalidation, but also checks the invalidations replicate."""
|
||||
master_invalidate = Mock()
|
||||
worker_invalidate = Mock()
|
||||
|
||||
self.store._get_cached_user_device.invalidate = master_invalidate
|
||||
worker = self.make_worker_hs("synapse.app.generic_worker")
|
||||
worker_ds = worker.get_datastores().main
|
||||
worker_ds._get_cached_user_device.invalidate = worker_invalidate
|
||||
|
||||
keys_to_invalidate = [
|
||||
("a", "b"),
|
||||
("c", "d"),
|
||||
("e", "f"),
|
||||
("g", "h"),
|
||||
]
|
||||
|
||||
def test_txn(txn: LoggingTransaction) -> None:
|
||||
self.store._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
# This is an arbitrarily chosen cached store function. It was chosen
|
||||
# because it takes more than one argument. We'll use this later to
|
||||
# check that the invalidation was actioned over replication.
|
||||
cache_func=self.store._get_cached_user_device,
|
||||
key_tuples=keys_to_invalidate,
|
||||
)
|
||||
|
||||
assert self.store._cache_id_gen is not None
|
||||
initial_token = self.store._cache_id_gen.get_current_token()
|
||||
self.get_success(
|
||||
self.database_pool.runInteraction(
|
||||
"test_invalidate_cache_and_stream_bulk", test_txn
|
||||
)
|
||||
)
|
||||
second_token = self.store._cache_id_gen.get_current_token()
|
||||
|
||||
self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))
|
||||
|
||||
self.get_success(
|
||||
worker.get_replication_data_handler().wait_for_stream_position(
|
||||
"master", "caches", second_token
|
||||
)
|
||||
)
|
||||
|
||||
master_invalidate.assert_has_calls(
|
||||
[call(key_list) for key_list in keys_to_invalidate],
|
||||
any_order=True,
|
||||
)
|
||||
worker_invalidate.assert_has_calls(
|
||||
[call(key_list) for key_list in keys_to_invalidate],
|
||||
any_order=True,
|
||||
)
|
|
@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret)
|
||||
self.assertEqual((1, 2, 3), ret)
|
||||
self.mock_txn.execute.assert_called_once_with(
|
||||
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
|
||||
)
|
||||
|
@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
self.assertFalse(ret)
|
||||
self.assertIsNone(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||
|
|
|
@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
def test_get_room(self) -> None:
|
||||
res = self.get_success(self.store.get_room(self.room.to_string()))
|
||||
assert res is not None
|
||||
self.assertLessEqual(
|
||||
{
|
||||
"room_id": self.room.to_string(),
|
||||
"creator": self.u_creator.to_string(),
|
||||
"is_public": True,
|
||||
}.items(),
|
||||
res.items(),
|
||||
)
|
||||
room = self.get_success(self.store.get_room(self.room.to_string()))
|
||||
assert room is not None
|
||||
self.assertTrue(room[0])
|
||||
|
||||
def test_get_room_unknown_room(self) -> None:
|
||||
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
|
||||
|
|
|
@ -83,11 +83,11 @@ def setupdb() -> None:
|
|||
|
||||
# Set up in the db
|
||||
db_conn = db_engine.module.connect(
|
||||
dbname=POSTGRES_BASE_DB,
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
password=POSTGRES_PASSWORD,
|
||||
dbname=POSTGRES_BASE_DB,
|
||||
)
|
||||
logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
|
||||
prepare_database(logging_conn, db_engine, None)
|
||||
|
|
Loading…
Reference in New Issue