mirror of
https://github.com/matrix-org/synapse.git
synced 2025-01-26 18:16:01 +00:00
Convert simple_select_one_txn and simple_select_one to return tuples. (#16612)
This commit is contained in:
parent
ff716b483b
commit
ab3f1b3b53
1
changelog.d/16612.misc
Normal file
1
changelog.d/16612.misc
Normal file
@ -0,0 +1 @@
|
||||
Improve type hints.
|
@ -348,8 +348,7 @@ class Porter:
|
||||
backward_chunk = 0
|
||||
already_ported = 0
|
||||
else:
|
||||
forward_chunk = row["forward_rowid"]
|
||||
backward_chunk = row["backward_rowid"]
|
||||
forward_chunk, backward_chunk = row
|
||||
|
||||
if total_to_port is None:
|
||||
already_ported, total_to_port = await self._get_total_count_to_port(
|
||||
|
@ -269,7 +269,7 @@ class RoomCreationHandler:
|
||||
self,
|
||||
requester: Requester,
|
||||
old_room_id: str,
|
||||
old_room: Dict[str, Any],
|
||||
old_room: Tuple[bool, str, bool],
|
||||
new_room_id: str,
|
||||
new_version: RoomVersion,
|
||||
tombstone_event: EventBase,
|
||||
@ -279,7 +279,7 @@ class RoomCreationHandler:
|
||||
Args:
|
||||
requester: the user requesting the upgrade
|
||||
old_room_id: the id of the room to be replaced
|
||||
old_room: a dict containing room information for the room to be replaced,
|
||||
old_room: a tuple containing room information for the room to be replaced,
|
||||
as returned by `RoomWorkerStore.get_room`.
|
||||
new_room_id: the id of the replacement room
|
||||
new_version: the version to upgrade the room to
|
||||
@ -299,7 +299,7 @@ class RoomCreationHandler:
|
||||
await self.store.store_room(
|
||||
room_id=new_room_id,
|
||||
room_creator_user_id=user_id,
|
||||
is_public=old_room["is_public"],
|
||||
is_public=old_room[0],
|
||||
room_version=new_version,
|
||||
)
|
||||
|
||||
|
@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
# Add new room to the room directory if the old room was there
|
||||
# Remove old room from the room directory
|
||||
old_room = await self.store.get_room(old_room_id)
|
||||
if old_room is not None and old_room["is_public"]:
|
||||
# If the old room exists and is public.
|
||||
if old_room is not None and old_room[0]:
|
||||
await self.store.set_room_is_public(old_room_id, False)
|
||||
await self.store.set_room_is_public(room_id, True)
|
||||
|
||||
|
@ -1860,7 +1860,8 @@ class PublicRoomListManager:
|
||||
if not room:
|
||||
return False
|
||||
|
||||
return room.get("is_public", False)
|
||||
# The first item is whether the room is public.
|
||||
return room[0]
|
||||
|
||||
async def add_room_to_public_room_list(self, room_id: str) -> None:
|
||||
"""Publishes a room to the public room list.
|
||||
|
@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
ret = await self.store.get_room(room_id)
|
||||
if not ret:
|
||||
room = await self.store.get_room(room_id)
|
||||
if not room:
|
||||
raise NotFoundError("Room not found")
|
||||
|
||||
members = await self.store.get_users_in_room(room_id)
|
||||
@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
ret = await self.store.get_room(room_id)
|
||||
if not ret:
|
||||
room = await self.store.get_room(room_id)
|
||||
if not room:
|
||||
raise NotFoundError("Room not found")
|
||||
|
||||
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
||||
|
@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
|
||||
if room is None:
|
||||
raise NotFoundError("Unknown room")
|
||||
|
||||
return 200, {"visibility": "public" if room["is_public"] else "private"}
|
||||
return 200, {"visibility": "public" if room[0] else "private"}
|
||||
|
||||
class PutBody(RequestBodyModel):
|
||||
visibility: Literal["public", "private"] = "public"
|
||||
|
@ -1597,7 +1597,7 @@ class DatabasePool:
|
||||
retcols: Collection[str],
|
||||
allow_none: Literal[False] = False,
|
||||
desc: str = "simple_select_one",
|
||||
) -> Dict[str, Any]:
|
||||
) -> Tuple[Any, ...]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@ -1608,7 +1608,7 @@ class DatabasePool:
|
||||
retcols: Collection[str],
|
||||
allow_none: Literal[True] = True,
|
||||
desc: str = "simple_select_one",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[Tuple[Any, ...]]:
|
||||
...
|
||||
|
||||
async def simple_select_one(
|
||||
@ -1618,7 +1618,7 @@ class DatabasePool:
|
||||
retcols: Collection[str],
|
||||
allow_none: bool = False,
|
||||
desc: str = "simple_select_one",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[Tuple[Any, ...]]:
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning multiple columns from it.
|
||||
|
||||
@ -2127,7 +2127,7 @@ class DatabasePool:
|
||||
keyvalues: Dict[str, Any],
|
||||
retcols: Collection[str],
|
||||
allow_none: bool = False,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[Tuple[Any, ...]]:
|
||||
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
||||
|
||||
if keyvalues:
|
||||
@ -2145,7 +2145,7 @@ class DatabasePool:
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
return dict(zip(retcols, row))
|
||||
return row
|
||||
|
||||
async def simple_delete_one(
|
||||
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
|
||||
|
@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||
A dict containing the device information, or `None` if the device does not
|
||||
exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_device",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
async def get_device_opt(
|
||||
self, user_id: str, device_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a device. Only returns devices that are not marked as
|
||||
hidden.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user which owns the device
|
||||
device_id: The ID of the device to retrieve
|
||||
Returns:
|
||||
A dict containing the device information, or None if the device does not exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_device",
|
||||
allow_none=True,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
|
||||
|
||||
async def get_devices_by_user(
|
||||
self, user_id: str
|
||||
@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||
retcols=["device_id", "device_data"],
|
||||
allow_none=True,
|
||||
)
|
||||
return (
|
||||
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
|
||||
)
|
||||
return (row[0], json_decoder.decode(row[1])) if row else None
|
||||
|
||||
def _store_dehydrated_device_txn(
|
||||
self,
|
||||
@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
`FALSE` have not been converted.
|
||||
"""
|
||||
|
||||
row = await self.db_pool.simple_select_one(
|
||||
return cast(
|
||||
Tuple[int, str],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="device_lists_changes_converted_stream_position",
|
||||
keyvalues={},
|
||||
retcols=["stream_id", "room_id"],
|
||||
desc="get_device_change_last_converted_pos",
|
||||
),
|
||||
)
|
||||
return row["stream_id"], row["room_id"]
|
||||
|
||||
async def set_device_change_last_converted_pos(
|
||||
self,
|
||||
|
@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
|
||||
# it isn't there.
|
||||
raise StoreError(404, "No backup with that version exists")
|
||||
|
||||
result = self.db_pool.simple_select_one_txn(
|
||||
row = cast(
|
||||
Tuple[int, str, str, Optional[int]],
|
||||
self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="e2e_room_keys_versions",
|
||||
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"version": this_version,
|
||||
"deleted": 0,
|
||||
},
|
||||
retcols=("version", "algorithm", "auth_data", "etag"),
|
||||
allow_none=False,
|
||||
),
|
||||
)
|
||||
assert result is not None # see comment on `simple_select_one_txn`
|
||||
result["auth_data"] = db_to_json(result["auth_data"])
|
||||
result["version"] = str(result["version"])
|
||||
if result["etag"] is None:
|
||||
result["etag"] = 0
|
||||
return result
|
||||
return {
|
||||
"auth_data": db_to_json(row[2]),
|
||||
"version": str(row[0]),
|
||||
"algorithm": row[1],
|
||||
"etag": 0 if row[3] is None else row[3],
|
||||
}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
|
||||
|
@ -1266,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
if row is None:
|
||||
continue
|
||||
|
||||
key_id = row["key_id"]
|
||||
key_json = row["key_json"]
|
||||
used = row["used"]
|
||||
key_id, key_json, used = row
|
||||
|
||||
# Mark fallback key as used if not already.
|
||||
if not used and mark_as_used:
|
||||
|
@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||
if room["has_auth_chain_index"]:
|
||||
# If the room has an auth chain index.
|
||||
if room[1]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_ids_chains",
|
||||
@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||
if room["has_auth_chain_index"]:
|
||||
# If the room has an auth chain index.
|
||||
if room[1]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_difference_chains",
|
||||
@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
)
|
||||
|
||||
if event_lookup_result is not None:
|
||||
event_type, depth, stream_ordering = event_lookup_result
|
||||
logger.debug(
|
||||
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
|
||||
room_id,
|
||||
seed_event_id,
|
||||
event_lookup_result["depth"],
|
||||
event_lookup_result["stream_ordering"],
|
||||
event_lookup_result["type"],
|
||||
depth,
|
||||
stream_ordering,
|
||||
event_type,
|
||||
)
|
||||
|
||||
if event_lookup_result["depth"]:
|
||||
queue.put(
|
||||
(
|
||||
-event_lookup_result["depth"],
|
||||
-event_lookup_result["stream_ordering"],
|
||||
seed_event_id,
|
||||
event_lookup_result["type"],
|
||||
)
|
||||
)
|
||||
if depth:
|
||||
queue.put((-depth, -stream_ordering, seed_event_id, event_type))
|
||||
|
||||
while not queue.empty() and len(event_id_results) < limit:
|
||||
try:
|
||||
|
@ -1934,8 +1934,7 @@ class PersistEventsStore:
|
||||
if row is None:
|
||||
return
|
||||
|
||||
redacted_relates_to = row["relates_to_id"]
|
||||
rel_type = row["relation_type"]
|
||||
redacted_relates_to, rel_type = row
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
|
||||
)
|
||||
|
@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
if not res:
|
||||
raise SynapseError(404, "Could not find event %s" % (event_id,))
|
||||
|
||||
return int(res["topological_ordering"]), int(res["stream_ordering"])
|
||||
return int(res[0]), int(res[1])
|
||||
|
||||
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
|
||||
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
|
||||
|
@ -208,7 +208,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return LocalMedia(media_id=media_id, **row)
|
||||
return LocalMedia(
|
||||
media_id=media_id,
|
||||
media_type=row[0],
|
||||
media_length=row[1],
|
||||
upload_name=row[2],
|
||||
created_ts=row[3],
|
||||
quarantined_by=row[4],
|
||||
url_cache=row[5],
|
||||
last_access_ts=row[6],
|
||||
safe_from_quarantine=row[7],
|
||||
)
|
||||
|
||||
async def get_local_media_by_user_paginate(
|
||||
self,
|
||||
@ -541,7 +551,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
)
|
||||
if row is None:
|
||||
return row
|
||||
return RemoteMedia(media_origin=origin, media_id=media_id, **row)
|
||||
return RemoteMedia(
|
||||
media_origin=origin,
|
||||
media_id=media_id,
|
||||
media_type=row[0],
|
||||
media_length=row[1],
|
||||
upload_name=row[2],
|
||||
created_ts=row[3],
|
||||
filesystem_id=row[4],
|
||||
last_access_ts=row[5],
|
||||
quarantined_by=row[6],
|
||||
)
|
||||
|
||||
async def store_cached_remote_media(
|
||||
self,
|
||||
@ -665,11 +685,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
if row is None:
|
||||
return None
|
||||
return ThumbnailInfo(
|
||||
width=row["thumbnail_width"],
|
||||
height=row["thumbnail_height"],
|
||||
method=row["thumbnail_method"],
|
||||
type=row["thumbnail_type"],
|
||||
length=row["thumbnail_length"],
|
||||
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
|
||||
)
|
||||
|
||||
@trace
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
return 50
|
||||
|
||||
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
|
||||
try:
|
||||
profile = await self.db_pool.simple_select_one(
|
||||
table="profiles",
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
desc="get_profileinfo",
|
||||
allow_none=True,
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
if profile is None:
|
||||
# no match
|
||||
return ProfileInfo(None, None)
|
||||
else:
|
||||
raise
|
||||
|
||||
return ProfileInfo(
|
||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||
)
|
||||
return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
|
||||
|
||||
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
|
@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
"before/after rule not found: %s" % (relative_to_rule,)
|
||||
)
|
||||
|
||||
base_priority_class = res["priority_class"]
|
||||
base_rule_priority = res["priority"]
|
||||
base_priority_class, base_rule_priority = res
|
||||
|
||||
if base_priority_class != priority_class:
|
||||
raise InconsistentRuleException(
|
||||
|
@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
stream_ordering = int(res["stream_ordering"]) if res else None
|
||||
rx_ts = res["received_ts"] if res else 0
|
||||
stream_ordering = int(res[0]) if res else None
|
||||
rx_ts = res[1] if res else 0
|
||||
|
||||
# We don't want to clobber receipts for more recent events, so we
|
||||
# have to compare orderings of existing receipts
|
||||
|
@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
account timestamp as milliseconds since the epoch. None if the account
|
||||
has not been renewed using the current token yet.
|
||||
"""
|
||||
ret_dict = await self.db_pool.simple_select_one(
|
||||
return cast(
|
||||
Tuple[str, int, Optional[int]],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="account_validity",
|
||||
keyvalues={"renewal_token": renewal_token},
|
||||
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
|
||||
desc="get_user_from_renewal_token",
|
||||
)
|
||||
|
||||
return (
|
||||
ret_dict["user_id"],
|
||||
ret_dict["expiration_ts_ms"],
|
||||
ret_dict["token_used_ts_ms"],
|
||||
),
|
||||
)
|
||||
|
||||
async def get_renewal_token_for_user(self, user_id: str) -> str:
|
||||
@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
Returns:
|
||||
user id, or None if no user id/threepid mapping exists
|
||||
"""
|
||||
ret = self.db_pool.simple_select_one_txn(
|
||||
return self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
"user_threepids",
|
||||
{"medium": medium, "address": address},
|
||||
["user_id"],
|
||||
"user_id",
|
||||
True,
|
||||
)
|
||||
if ret:
|
||||
return ret["user_id"]
|
||||
return None
|
||||
|
||||
async def user_add_threepid(
|
||||
self,
|
||||
@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
if res is None:
|
||||
return False
|
||||
|
||||
uses_allowed, pending, completed, expiry_time = res
|
||||
|
||||
# Check if the token has expired
|
||||
now = self._clock.time_msec()
|
||||
if res["expiry_time"] and res["expiry_time"] < now:
|
||||
if expiry_time and expiry_time < now:
|
||||
return False
|
||||
|
||||
# Check if the token has been used up
|
||||
if (
|
||||
res["uses_allowed"]
|
||||
and res["pending"] + res["completed"] >= res["uses_allowed"]
|
||||
):
|
||||
if uses_allowed and pending + completed >= uses_allowed:
|
||||
return False
|
||||
|
||||
# Otherwise, the token is valid
|
||||
@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
# Override type because the return type is only optional if
|
||||
# allow_none is True, and we don't want mypy throwing errors
|
||||
# about None not being indexable.
|
||||
res = cast(
|
||||
Dict[str, Any],
|
||||
pending, completed = cast(
|
||||
Tuple[int, int],
|
||||
self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={
|
||||
"completed": res["completed"] + 1,
|
||||
"pending": res["pending"] - 1,
|
||||
"completed": completed + 1,
|
||||
"pending": pending - 1,
|
||||
},
|
||||
)
|
||||
|
||||
@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
Returns:
|
||||
A dict, or None if token doesn't exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
row = await self.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
|
||||
allow_none=True,
|
||||
desc="get_one_registration_token",
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return {
|
||||
"token": row[0],
|
||||
"uses_allowed": row[1],
|
||||
"pending": row[2],
|
||||
"completed": row[3],
|
||||
"expiry_time": row[4],
|
||||
}
|
||||
|
||||
async def generate_registration_token(
|
||||
self, length: int, chars: str
|
||||
@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
return None
|
||||
|
||||
# Get all info about the token so it can be sent in the response
|
||||
return self.db_pool.simple_select_one_txn(
|
||||
result = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return result
|
||||
|
||||
return {
|
||||
"token": result[0],
|
||||
"uses_allowed": result[1],
|
||||
"pending": result[2],
|
||||
"completed": result[3],
|
||||
"expiry_time": result[4],
|
||||
}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"update_registration_token", _update_registration_token_txn
|
||||
)
|
||||
@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
keyvalues={"token": token},
|
||||
updatevalues={"used_ts": ts},
|
||||
)
|
||||
user_id = values["user_id"]
|
||||
expiry_ts = values["expiry_ts"]
|
||||
used_ts = values["used_ts"]
|
||||
auth_provider_id = values["auth_provider_id"]
|
||||
auth_provider_session_id = values["auth_provider_session_id"]
|
||||
(
|
||||
user_id,
|
||||
expiry_ts,
|
||||
used_ts,
|
||||
auth_provider_id,
|
||||
auth_provider_session_id,
|
||||
) = values
|
||||
|
||||
# Token was already used
|
||||
if used_ts is not None:
|
||||
@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||
# reason, the next check is on the client secret, which is NOT NULL,
|
||||
# so we don't have to worry about the client secret matching by
|
||||
# accident.
|
||||
row = {"client_secret": None, "validated_at": None}
|
||||
row = None, None
|
||||
else:
|
||||
raise ThreepidValidationError("Unknown session_id")
|
||||
|
||||
retrieved_client_secret = row["client_secret"]
|
||||
validated_at = row["validated_at"]
|
||||
retrieved_client_secret, validated_at = row
|
||||
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||
raise ThreepidValidationError(
|
||||
"Validation token not found or has expired"
|
||||
)
|
||||
expires = row["expires"]
|
||||
next_link = row["next_link"]
|
||||
expires, next_link = row
|
||||
|
||||
if retrieved_client_secret != client_secret:
|
||||
raise ThreepidValidationError(
|
||||
|
@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||
raise StoreError(500, "Problem creating room.")
|
||||
|
||||
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
|
||||
"""Retrieve a room.
|
||||
|
||||
Args:
|
||||
room_id: The ID of the room to retrieve.
|
||||
Returns:
|
||||
A dict containing the room information, or None if the room is unknown.
|
||||
A tuple containing the room information:
|
||||
* True if the room is public
|
||||
* True if the room has an auth chain index
|
||||
|
||||
or None if the room is unknown.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
row = cast(
|
||||
Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
|
||||
retcols=("is_public", "has_auth_chain_index"),
|
||||
desc="get_room",
|
||||
allow_none=True,
|
||||
),
|
||||
)
|
||||
if row is None:
|
||||
return row
|
||||
return bool(row[0]), bool(row[1])
|
||||
|
||||
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
|
||||
"""Retrieve room with statistics.
|
||||
@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
)
|
||||
|
||||
if row:
|
||||
return RatelimitOverride(
|
||||
messages_per_second=row["messages_per_second"],
|
||||
burst_count=row["burst_count"],
|
||||
)
|
||||
return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
join.
|
||||
"""
|
||||
|
||||
result = await self.db_pool.simple_select_one(
|
||||
return cast(
|
||||
Tuple[str, int],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="partial_state_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("join_event_id", "device_lists_stream_id"),
|
||||
desc="get_join_event_id_for_partial_state",
|
||||
),
|
||||
)
|
||||
return result["join_event_id"], result["device_lists_stream_id"]
|
||||
|
||||
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
|
||||
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
|
||||
|
@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||
"non-local user %s" % (user_id,),
|
||||
)
|
||||
|
||||
results_dict = await self.db_pool.simple_select_one(
|
||||
results = cast(
|
||||
Optional[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_one(
|
||||
"local_current_membership",
|
||||
{"room_id": room_id, "user_id": user_id},
|
||||
("membership", "event_id"),
|
||||
allow_none=True,
|
||||
desc="get_local_current_membership_for_user_in_room",
|
||||
),
|
||||
)
|
||||
if not results_dict:
|
||||
if not results:
|
||||
return None, None
|
||||
|
||||
return results_dict.get("membership"), results_dict.get("event_id")
|
||||
return results
|
||||
|
||||
@cached(max_entries=500000, iterable=True)
|
||||
async def get_rooms_for_user_with_stream_ordering(
|
||||
|
@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
desc="get_position_for_event",
|
||||
)
|
||||
|
||||
return PersistedEventPosition(
|
||||
row["instance_name"] or "master", row["stream_ordering"]
|
||||
)
|
||||
return PersistedEventPosition(row[1] or "master", row[0])
|
||||
|
||||
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
|
||||
"""The stream token for an event
|
||||
@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
retcols=("stream_ordering", "topological_ordering"),
|
||||
desc="get_topological_token_for_event",
|
||||
)
|
||||
return RoomStreamToken(
|
||||
topological=row["topological_ordering"], stream=row["stream_ordering"]
|
||||
)
|
||||
return RoomStreamToken(topological=row[1], stream=row[0])
|
||||
|
||||
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
|
||||
"""Gets the topological token in a room after or at the given stream
|
||||
@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
dict
|
||||
"""
|
||||
|
||||
results = self.db_pool.simple_select_one_txn(
|
||||
stream_ordering, topological_ordering = cast(
|
||||
Tuple[int, int],
|
||||
self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"events",
|
||||
keyvalues={"event_id": event_id, "room_id": room_id},
|
||||
retcols=["stream_ordering", "topological_ordering"],
|
||||
),
|
||||
)
|
||||
|
||||
# This cannot happen as `allow_none=False`.
|
||||
assert results is not None
|
||||
|
||||
# Paginating backwards includes the event at the token, but paginating
|
||||
# forward doesn't.
|
||||
before_token = RoomStreamToken(
|
||||
topological=results["topological_ordering"] - 1,
|
||||
stream=results["stream_ordering"],
|
||||
topological=topological_ordering - 1, stream=stream_ordering
|
||||
)
|
||||
|
||||
after_token = RoomStreamToken(
|
||||
topological=results["topological_ordering"],
|
||||
stream=results["stream_ordering"],
|
||||
topological=topological_ordering, stream=stream_ordering
|
||||
)
|
||||
|
||||
rows, start_token = self._paginate_room_events_txn(
|
||||
|
@ -183,7 +183,9 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
||||
|
||||
Returns: the task if available, `None` otherwise
|
||||
"""
|
||||
row = await self.db_pool.simple_select_one(
|
||||
row = cast(
|
||||
Optional[ScheduledTaskRow],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="scheduled_tasks",
|
||||
keyvalues={"id": id},
|
||||
retcols=(
|
||||
@ -198,24 +200,10 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_scheduled_task",
|
||||
),
|
||||
)
|
||||
|
||||
return (
|
||||
TaskSchedulerWorkerStore._convert_row_to_task(
|
||||
(
|
||||
row["id"],
|
||||
row["action"],
|
||||
row["status"],
|
||||
row["timestamp"],
|
||||
row["resource_id"],
|
||||
row["params"],
|
||||
row["result"],
|
||||
row["error"],
|
||||
)
|
||||
)
|
||||
if row
|
||||
else None
|
||||
)
|
||||
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
|
||||
|
||||
async def delete_scheduled_task(self, id: str) -> None:
|
||||
"""Delete a specific task from its id.
|
||||
|
@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
txn,
|
||||
table="received_transactions",
|
||||
keyvalues={"transaction_id": transaction_id, "origin": origin},
|
||||
retcols=(
|
||||
"transaction_id",
|
||||
"origin",
|
||||
"ts",
|
||||
"response_code",
|
||||
"response_json",
|
||||
"has_been_referenced",
|
||||
),
|
||||
retcols=("response_code", "response_json"),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if result and result["response_code"]:
|
||||
return result["response_code"], db_to_json(result["response_json"])
|
||||
# If the result exists and the response code is non-0.
|
||||
if result and result[0]:
|
||||
return result[0], db_to_json(result[1])
|
||||
|
||||
else:
|
||||
return None
|
||||
@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
|
||||
# check we have a row and retry_last_ts is not null or zero
|
||||
# (retry_last_ts can't be negative)
|
||||
if result and result["retry_last_ts"]:
|
||||
return DestinationRetryTimings(**result)
|
||||
if result and result[1]:
|
||||
return DestinationRetryTimings(
|
||||
failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
desc="get_ui_auth_session",
|
||||
)
|
||||
|
||||
result["clientdict"] = db_to_json(result["clientdict"])
|
||||
|
||||
return UIAuthSessionData(session_id, **result)
|
||||
return UIAuthSessionData(
|
||||
session_id,
|
||||
clientdict=db_to_json(result[0]),
|
||||
uri=result[1],
|
||||
method=result[2],
|
||||
description=result[3],
|
||||
)
|
||||
|
||||
async def mark_ui_auth_stage_complete(
|
||||
self,
|
||||
@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
|
||||
) -> None:
|
||||
# Get the current value.
|
||||
result = cast(
|
||||
Dict[str, Any],
|
||||
self.db_pool.simple_select_one_txn(
|
||||
result = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
),
|
||||
retcol="serverdict",
|
||||
)
|
||||
|
||||
# Update it and add it back to the database.
|
||||
serverdict = db_to_json(result["serverdict"])
|
||||
serverdict = db_to_json(result)
|
||||
serverdict[key] = value
|
||||
|
||||
self.db_pool.simple_update_one_txn(
|
||||
@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
Raises:
|
||||
StoreError if the session cannot be found.
|
||||
"""
|
||||
result = await self.db_pool.simple_select_one(
|
||||
result = await self.db_pool.simple_select_one_onecol(
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
retcol="serverdict",
|
||||
desc="get_ui_auth_session_data",
|
||||
)
|
||||
|
||||
serverdict = db_to_json(result["serverdict"])
|
||||
serverdict = db_to_json(result)
|
||||
|
||||
return serverdict.get(key, default)
|
||||
|
||||
|
@ -20,7 +20,6 @@ from typing import (
|
||||
Collection,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||
)
|
||||
|
||||
async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
async def _get_user_in_directory(
|
||||
self, user_id: str
|
||||
) -> Optional[Tuple[Optional[str], Optional[str]]]:
|
||||
"""
|
||||
Fetch the user information in the user directory.
|
||||
|
||||
Returns:
|
||||
None if the user is unknown, otherwise a tuple of display name and
|
||||
avatar URL (both of which may be None).
|
||||
"""
|
||||
return cast(
|
||||
Optional[Tuple[Optional[str], Optional[str]]],
|
||||
await self.db_pool.simple_select_one(
|
||||
table="user_directory",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("display_name", "avatar_url"),
|
||||
allow_none=True,
|
||||
desc="get_user_in_directory",
|
||||
),
|
||||
)
|
||||
|
||||
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
|
||||
|
@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||
|
||||
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
|
||||
|
||||
return self.get_success(
|
||||
row = self.get_success(
|
||||
self.store.db_pool.simple_select_one(
|
||||
table + "_current",
|
||||
{id_col: stat_id},
|
||||
@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
return None if row is None else dict(zip(cols, row))
|
||||
|
||||
def _perform_background_initial_update(self) -> None:
|
||||
# Do the initial population of the stats via the background update
|
||||
self._add_background_updates()
|
||||
|
@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
|
||||
assert profile is not None
|
||||
self.assertTrue(profile["display_name"] == display_name)
|
||||
self.assertTrue(profile[0] == display_name)
|
||||
|
||||
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
|
||||
# create user
|
||||
@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
# profile is in directory
|
||||
profile = self.get_success(self.store._get_user_in_directory(r_user_id))
|
||||
assert profile is not None
|
||||
self.assertTrue(profile["display_name"] == display_name)
|
||||
self.assertEqual(profile[0], display_name)
|
||||
|
||||
# deactivate user
|
||||
self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
|
||||
|
@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
# is in user directory
|
||||
profile = self.get_success(self.store._get_user_in_directory(self.other_user))
|
||||
assert profile is not None
|
||||
self.assertTrue(profile["display_name"] == "User")
|
||||
self.assertEqual(profile[0], "User")
|
||||
|
||||
# Deactivate user
|
||||
channel = self.make_request(
|
||||
|
@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
#
|
||||
# Note that we don't have the UI Auth session ID, so just pull out the single
|
||||
# row.
|
||||
ui_auth_data = self.get_success(
|
||||
self.store.db_pool.simple_select_one(
|
||||
"ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
|
||||
result = self.get_success(
|
||||
self.store.db_pool.simple_select_one_onecol(
|
||||
"ui_auth_sessions", keyvalues={}, retcol="clientdict"
|
||||
)
|
||||
)
|
||||
client_dict = db_to_json(ui_auth_data["clientdict"])
|
||||
client_dict = db_to_json(result)
|
||||
self.assertNotIn("new_password", client_dict)
|
||||
|
||||
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
||||
|
@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
self.assertLessEqual(det_data.items(), channel.json_body.items())
|
||||
|
||||
# Check the `completed` counter has been incremented and pending is 0
|
||||
res = self.get_success(
|
||||
pending, completed = self.get_success(
|
||||
store.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
)
|
||||
)
|
||||
self.assertEqual(res["completed"], 1)
|
||||
self.assertEqual(res["pending"], 0)
|
||||
self.assertEqual(completed, 1)
|
||||
self.assertEqual(pending, 0)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_invalid(self) -> None:
|
||||
@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
params1["auth"]["type"] = LoginType.DUMMY
|
||||
self.make_request(b"POST", self.url, params1)
|
||||
# Check pending=0 and completed=1
|
||||
res = self.get_success(
|
||||
pending, completed = self.get_success(
|
||||
store.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
)
|
||||
)
|
||||
self.assertEqual(res["pending"], 0)
|
||||
self.assertEqual(res["completed"], 1)
|
||||
self.assertEqual(pending, 0)
|
||||
self.assertEqual(completed, 1)
|
||||
|
||||
# Check auth still fails when using token with session2
|
||||
channel = self.make_request(b"POST", self.url, params2)
|
||||
|
@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret)
|
||||
self.assertEqual((1, 2, 3), ret)
|
||||
self.mock_txn.execute.assert_called_once_with(
|
||||
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
|
||||
)
|
||||
@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(ret)
|
||||
self.assertIsNone(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||
|
@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
def test_get_room(self) -> None:
|
||||
res = self.get_success(self.store.get_room(self.room.to_string()))
|
||||
assert res is not None
|
||||
self.assertLessEqual(
|
||||
{
|
||||
"room_id": self.room.to_string(),
|
||||
"creator": self.u_creator.to_string(),
|
||||
"is_public": True,
|
||||
}.items(),
|
||||
res.items(),
|
||||
)
|
||||
room = self.get_success(self.store.get_room(self.room.to_string()))
|
||||
assert room is not None
|
||||
self.assertTrue(room[0])
|
||||
|
||||
def test_get_room_unknown_room(self) -> None:
|
||||
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
|
||||
|
Loading…
Reference in New Issue
Block a user