Convert simple_select_one_txn and simple_select_one to return tuples. (#16612)

This commit is contained in:
Patrick Cloke 2023-11-09 11:13:31 -05:00 committed by GitHub
parent ff716b483b
commit ab3f1b3b53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 283 additions and 279 deletions

1
changelog.d/16612.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -348,8 +348,7 @@ class Porter:
backward_chunk = 0 backward_chunk = 0
already_ported = 0 already_ported = 0
else: else:
forward_chunk = row["forward_rowid"] forward_chunk, backward_chunk = row
backward_chunk = row["backward_rowid"]
if total_to_port is None: if total_to_port is None:
already_ported, total_to_port = await self._get_total_count_to_port( already_ported, total_to_port = await self._get_total_count_to_port(

View File

@ -269,7 +269,7 @@ class RoomCreationHandler:
self, self,
requester: Requester, requester: Requester,
old_room_id: str, old_room_id: str,
old_room: Dict[str, Any], old_room: Tuple[bool, str, bool],
new_room_id: str, new_room_id: str,
new_version: RoomVersion, new_version: RoomVersion,
tombstone_event: EventBase, tombstone_event: EventBase,
@ -279,7 +279,7 @@ class RoomCreationHandler:
Args: Args:
requester: the user requesting the upgrade requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced old_room_id: the id of the room to be replaced
old_room: a dict containing room information for the room to be replaced, old_room: a tuple containing room information for the room to be replaced,
as returned by `RoomWorkerStore.get_room`. as returned by `RoomWorkerStore.get_room`.
new_room_id: the id of the replacement room new_room_id: the id of the replacement room
new_version: the version to upgrade the room to new_version: the version to upgrade the room to
@ -299,7 +299,7 @@ class RoomCreationHandler:
await self.store.store_room( await self.store.store_room(
room_id=new_room_id, room_id=new_room_id,
room_creator_user_id=user_id, room_creator_user_id=user_id,
is_public=old_room["is_public"], is_public=old_room[0],
room_version=new_version, room_version=new_version,
) )

View File

@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Add new room to the room directory if the old room was there # Add new room to the room directory if the old room was there
# Remove old room from the room directory # Remove old room from the room directory
old_room = await self.store.get_room(old_room_id) old_room = await self.store.get_room(old_room_id)
if old_room is not None and old_room["is_public"]: # If the old room exists and is public.
if old_room is not None and old_room[0]:
await self.store.set_room_is_public(old_room_id, False) await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True) await self.store.set_room_is_public(room_id, True)

View File

@ -1860,7 +1860,8 @@ class PublicRoomListManager:
if not room: if not room:
return False return False
return room.get("is_public", False) # The first item is whether the room is public.
return room[0]
async def add_room_to_public_room_list(self, room_id: str) -> None: async def add_room_to_public_room_list(self, room_id: str) -> None:
"""Publishes a room to the public room list. """Publishes a room to the public room list.

View File

@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id) room = await self.store.get_room(room_id)
if not ret: if not room:
raise NotFoundError("Room not found") raise NotFoundError("Room not found")
members = await self.store.get_users_in_room(room_id) members = await self.store.get_users_in_room(room_id)
@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id) room = await self.store.get_room(room_id)
if not ret: if not room:
raise NotFoundError("Room not found") raise NotFoundError("Room not found")
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)

View File

@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
if room is None: if room is None:
raise NotFoundError("Unknown room") raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"} return 200, {"visibility": "public" if room[0] else "private"}
class PutBody(RequestBodyModel): class PutBody(RequestBodyModel):
visibility: Literal["public", "private"] = "public" visibility: Literal["public", "private"] = "public"

View File

@ -1597,7 +1597,7 @@ class DatabasePool:
retcols: Collection[str], retcols: Collection[str],
allow_none: Literal[False] = False, allow_none: Literal[False] = False,
desc: str = "simple_select_one", desc: str = "simple_select_one",
) -> Dict[str, Any]: ) -> Tuple[Any, ...]:
... ...
@overload @overload
@ -1608,7 +1608,7 @@ class DatabasePool:
retcols: Collection[str], retcols: Collection[str],
allow_none: Literal[True] = True, allow_none: Literal[True] = True,
desc: str = "simple_select_one", desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]: ) -> Optional[Tuple[Any, ...]]:
... ...
async def simple_select_one( async def simple_select_one(
@ -1618,7 +1618,7 @@ class DatabasePool:
retcols: Collection[str], retcols: Collection[str],
allow_none: bool = False, allow_none: bool = False,
desc: str = "simple_select_one", desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]: ) -> Optional[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it. return a single row, returning multiple columns from it.
@ -2127,7 +2127,7 @@ class DatabasePool:
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
retcols: Collection[str], retcols: Collection[str],
allow_none: bool = False, allow_none: bool = False,
) -> Optional[Dict[str, Any]]: ) -> Optional[Tuple[Any, ...]]:
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
if keyvalues: if keyvalues:
@ -2145,7 +2145,7 @@ class DatabasePool:
if txn.rowcount > 1: if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,)) raise StoreError(500, "More than one row matched (%s)" % (table,))
return dict(zip(retcols, row)) return row
async def simple_delete_one( async def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"

View File

@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
A dict containing the device information, or `None` if the device does not A dict containing the device information, or `None` if the device does not
exist. exist.
""" """
return await self.db_pool.simple_select_one( row = await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
async def get_device_opt(
self, user_id: str, device_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
A dict containing the device information, or None if the device does not exist.
"""
return await self.db_pool.simple_select_one(
table="devices", table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"), retcols=("user_id", "device_id", "display_name"),
desc="get_device", desc="get_device",
allow_none=True, allow_none=True,
) )
if row is None:
return None
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
async def get_devices_by_user( async def get_devices_by_user(
self, user_id: str self, user_id: str
@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
retcols=["device_id", "device_data"], retcols=["device_id", "device_data"],
allow_none=True, allow_none=True,
) )
return ( return (row[0], json_decoder.decode(row[1])) if row else None
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
)
def _store_dehydrated_device_txn( def _store_dehydrated_device_txn(
self, self,
@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
`FALSE` have not been converted. `FALSE` have not been converted.
""" """
row = await self.db_pool.simple_select_one( return cast(
Tuple[int, str],
await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position", table="device_lists_changes_converted_stream_position",
keyvalues={}, keyvalues={},
retcols=["stream_id", "room_id"], retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos", desc="get_device_change_last_converted_pos",
),
) )
return row["stream_id"], row["room_id"]
async def set_device_change_last_converted_pos( async def set_device_change_last_converted_pos(
self, self,

View File

@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
# it isn't there. # it isn't there.
raise StoreError(404, "No backup with that version exists") raise StoreError(404, "No backup with that version exists")
result = self.db_pool.simple_select_one_txn( row = cast(
Tuple[int, str, str, Optional[int]],
self.db_pool.simple_select_one_txn(
txn, txn,
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, keyvalues={
"user_id": user_id,
"version": this_version,
"deleted": 0,
},
retcols=("version", "algorithm", "auth_data", "etag"), retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False, allow_none=False,
),
) )
assert result is not None # see comment on `simple_select_one_txn` return {
result["auth_data"] = db_to_json(result["auth_data"]) "auth_data": db_to_json(row[2]),
result["version"] = str(result["version"]) "version": str(row[0]),
if result["etag"] is None: "algorithm": row[1],
result["etag"] = 0 "etag": 0 if row[3] is None else row[3],
return result }
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn

View File

@ -1266,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if row is None: if row is None:
continue continue
key_id = row["key_id"] key_id, key_json, used = row
key_json = row["key_json"]
used = row["used"]
# Mark fallback key as used if not already. # Mark fallback key as used if not already.
if not used and mark_as_used: if not used and mark_as_used:

View File

@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover # Check if we have indexed the room so we can use the chain cover
# algorithm. # algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined] room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]: # If the room has an auth chain index.
if room[1]:
try: try:
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains", "get_auth_chain_ids_chains",
@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover # Check if we have indexed the room so we can use the chain cover
# algorithm. # algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined] room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]: # If the room has an auth chain index.
if room[1]:
try: try:
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains", "get_auth_chain_difference_chains",
@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) )
if event_lookup_result is not None: if event_lookup_result is not None:
event_type, depth, stream_ordering = event_lookup_result
logger.debug( logger.debug(
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s", "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
room_id, room_id,
seed_event_id, seed_event_id,
event_lookup_result["depth"], depth,
event_lookup_result["stream_ordering"], stream_ordering,
event_lookup_result["type"], event_type,
) )
if event_lookup_result["depth"]: if depth:
queue.put( queue.put((-depth, -stream_ordering, seed_event_id, event_type))
(
-event_lookup_result["depth"],
-event_lookup_result["stream_ordering"],
seed_event_id,
event_lookup_result["type"],
)
)
while not queue.empty() and len(event_id_results) < limit: while not queue.empty() and len(event_id_results) < limit:
try: try:

View File

@ -1934,8 +1934,7 @@ class PersistEventsStore:
if row is None: if row is None:
return return
redacted_relates_to = row["relates_to_id"] redacted_relates_to, rel_type = row
rel_type = row["relation_type"]
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id} txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
) )

View File

@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore):
if not res: if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,)) raise SynapseError(404, "Could not find event %s" % (event_id,))
return int(res["topological_ordering"]), int(res["stream_ordering"]) return int(res[0]), int(res[1])
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry """Retrieve the entry with the lowest expiry timestamp in the event_expiry

View File

@ -208,7 +208,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
) )
if row is None: if row is None:
return None return None
return LocalMedia(media_id=media_id, **row) return LocalMedia(
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
quarantined_by=row[4],
url_cache=row[5],
last_access_ts=row[6],
safe_from_quarantine=row[7],
)
async def get_local_media_by_user_paginate( async def get_local_media_by_user_paginate(
self, self,
@ -541,7 +551,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
) )
if row is None: if row is None:
return row return row
return RemoteMedia(media_origin=origin, media_id=media_id, **row) return RemoteMedia(
media_origin=origin,
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
filesystem_id=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
)
async def store_cached_remote_media( async def store_cached_remote_media(
self, self,
@ -665,11 +685,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
if row is None: if row is None:
return None return None
return ThumbnailInfo( return ThumbnailInfo(
width=row["thumbnail_width"], width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
) )
@trace @trace

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
return 50 return 50
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one( profile = await self.db_pool.simple_select_one(
table="profiles", table="profiles",
keyvalues={"full_user_id": user_id.to_string()}, keyvalues={"full_user_id": user_id.to_string()},
retcols=("displayname", "avatar_url"), retcols=("displayname", "avatar_url"),
desc="get_profileinfo", desc="get_profileinfo",
allow_none=True,
) )
except StoreError as e: if profile is None:
if e.code == 404:
# no match # no match
return ProfileInfo(None, None) return ProfileInfo(None, None)
else:
raise
return ProfileInfo( return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(

View File

@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore):
"before/after rule not found: %s" % (relative_to_rule,) "before/after rule not found: %s" % (relative_to_rule,)
) )
base_priority_class = res["priority_class"] base_priority_class, base_rule_priority = res
base_rule_priority = res["priority"]
if base_priority_class != priority_class: if base_priority_class != priority_class:
raise InconsistentRuleException( raise InconsistentRuleException(

View File

@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
stream_ordering = int(res["stream_ordering"]) if res else None stream_ordering = int(res[0]) if res else None
rx_ts = res["received_ts"] if res else 0 rx_ts = res[1] if res else 0
# We don't want to clobber receipts for more recent events, so we # We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts # have to compare orderings of existing receipts

View File

@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
account timestamp as milliseconds since the epoch. None if the account account timestamp as milliseconds since the epoch. None if the account
has not been renewed using the current token yet. has not been renewed using the current token yet.
""" """
ret_dict = await self.db_pool.simple_select_one( return cast(
Tuple[str, int, Optional[int]],
await self.db_pool.simple_select_one(
table="account_validity", table="account_validity",
keyvalues={"renewal_token": renewal_token}, keyvalues={"renewal_token": renewal_token},
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
desc="get_user_from_renewal_token", desc="get_user_from_renewal_token",
) ),
return (
ret_dict["user_id"],
ret_dict["expiration_ts_ms"],
ret_dict["token_used_ts_ms"],
) )
async def get_renewal_token_for_user(self, user_id: str) -> str: async def get_renewal_token_for_user(self, user_id: str) -> str:
@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns: Returns:
user id, or None if no user id/threepid mapping exists user id, or None if no user id/threepid mapping exists
""" """
ret = self.db_pool.simple_select_one_txn( return self.db_pool.simple_select_one_onecol_txn(
txn, txn,
"user_threepids", "user_threepids",
{"medium": medium, "address": address}, {"medium": medium, "address": address},
["user_id"], "user_id",
True, True,
) )
if ret:
return ret["user_id"]
return None
async def user_add_threepid( async def user_add_threepid(
self, self,
@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
if res is None: if res is None:
return False return False
uses_allowed, pending, completed, expiry_time = res
# Check if the token has expired # Check if the token has expired
now = self._clock.time_msec() now = self._clock.time_msec()
if res["expiry_time"] and res["expiry_time"] < now: if expiry_time and expiry_time < now:
return False return False
# Check if the token has been used up # Check if the token has been used up
if ( if uses_allowed and pending + completed >= uses_allowed:
res["uses_allowed"]
and res["pending"] + res["completed"] >= res["uses_allowed"]
):
return False return False
# Otherwise, the token is valid # Otherwise, the token is valid
@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Override type because the return type is only optional if # Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors # allow_none is True, and we don't want mypy throwing errors
# about None not being indexable. # about None not being indexable.
res = cast( pending, completed = cast(
Dict[str, Any], Tuple[int, int],
self.db_pool.simple_select_one_txn( self.db_pool.simple_select_one_txn(
txn, txn,
"registration_tokens", "registration_tokens",
@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"registration_tokens", "registration_tokens",
keyvalues={"token": token}, keyvalues={"token": token},
updatevalues={ updatevalues={
"completed": res["completed"] + 1, "completed": completed + 1,
"pending": res["pending"] - 1, "pending": pending - 1,
}, },
) )
@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns: Returns:
A dict, or None if token doesn't exist. A dict, or None if token doesn't exist.
""" """
return await self.db_pool.simple_select_one( row = await self.db_pool.simple_select_one(
"registration_tokens", "registration_tokens",
keyvalues={"token": token}, keyvalues={"token": token},
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True, allow_none=True,
desc="get_one_registration_token", desc="get_one_registration_token",
) )
if row is None:
return None
return {
"token": row[0],
"uses_allowed": row[1],
"pending": row[2],
"completed": row[3],
"expiry_time": row[4],
}
async def generate_registration_token( async def generate_registration_token(
self, length: int, chars: str self, length: int, chars: str
@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return None return None
# Get all info about the token so it can be sent in the response # Get all info about the token so it can be sent in the response
return self.db_pool.simple_select_one_txn( result = self.db_pool.simple_select_one_txn(
txn, txn,
"registration_tokens", "registration_tokens",
keyvalues={"token": token}, keyvalues={"token": token},
@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
allow_none=True, allow_none=True,
) )
if result is None:
return result
return {
"token": result[0],
"uses_allowed": result[1],
"pending": result[2],
"completed": result[3],
"expiry_time": result[4],
}
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"update_registration_token", _update_registration_token_txn "update_registration_token", _update_registration_token_txn
) )
@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"token": token}, keyvalues={"token": token},
updatevalues={"used_ts": ts}, updatevalues={"used_ts": ts},
) )
user_id = values["user_id"] (
expiry_ts = values["expiry_ts"] user_id,
used_ts = values["used_ts"] expiry_ts,
auth_provider_id = values["auth_provider_id"] used_ts,
auth_provider_session_id = values["auth_provider_session_id"] auth_provider_id,
auth_provider_session_id,
) = values
# Token was already used # Token was already used
if used_ts is not None: if used_ts is not None:
@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# reason, the next check is on the client secret, which is NOT NULL, # reason, the next check is on the client secret, which is NOT NULL,
# so we don't have to worry about the client secret matching by # so we don't have to worry about the client secret matching by
# accident. # accident.
row = {"client_secret": None, "validated_at": None} row = None, None
else: else:
raise ThreepidValidationError("Unknown session_id") raise ThreepidValidationError("Unknown session_id")
retrieved_client_secret = row["client_secret"] retrieved_client_secret, validated_at = row
validated_at = row["validated_at"]
row = self.db_pool.simple_select_one_txn( row = self.db_pool.simple_select_one_txn(
txn, txn,
@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
raise ThreepidValidationError( raise ThreepidValidationError(
"Validation token not found or has expired" "Validation token not found or has expired"
) )
expires = row["expires"] expires, next_link = row
next_link = row["next_link"]
if retrieved_client_secret != client_secret: if retrieved_client_secret != client_secret:
raise ThreepidValidationError( raise ThreepidValidationError(

View File

@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e) logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.") raise StoreError(500, "Problem creating room.")
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
"""Retrieve a room. """Retrieve a room.
Args: Args:
room_id: The ID of the room to retrieve. room_id: The ID of the room to retrieve.
Returns: Returns:
A dict containing the room information, or None if the room is unknown. A tuple containing the room information:
* True if the room is public
* True if the room has an auth chain index
or None if the room is unknown.
""" """
return await self.db_pool.simple_select_one( row = cast(
Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
await self.db_pool.simple_select_one(
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), retcols=("is_public", "has_auth_chain_index"),
desc="get_room", desc="get_room",
allow_none=True, allow_none=True,
),
) )
if row is None:
return row
return bool(row[0]), bool(row[1])
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics. """Retrieve room with statistics.
@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
) )
if row: if row:
return RatelimitOverride( return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
)
else: else:
return None return None
@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
join. join.
""" """
result = await self.db_pool.simple_select_one( return cast(
Tuple[str, int],
await self.db_pool.simple_select_one(
table="partial_state_rooms", table="partial_state_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"), retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state", desc="get_join_event_id_for_partial_state",
),
) )
return result["join_event_id"], result["device_lists_stream_id"]
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int: def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer( return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(

View File

@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"non-local user %s" % (user_id,), "non-local user %s" % (user_id,),
) )
results_dict = await self.db_pool.simple_select_one( results = cast(
Optional[Tuple[str, str]],
await self.db_pool.simple_select_one(
"local_current_membership", "local_current_membership",
{"room_id": room_id, "user_id": user_id}, {"room_id": room_id, "user_id": user_id},
("membership", "event_id"), ("membership", "event_id"),
allow_none=True, allow_none=True,
desc="get_local_current_membership_for_user_in_room", desc="get_local_current_membership_for_user_in_room",
),
) )
if not results_dict: if not results:
return None, None return None, None
return results_dict.get("membership"), results_dict.get("event_id") return results
@cached(max_entries=500000, iterable=True) @cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering( async def get_rooms_for_user_with_stream_ordering(

View File

@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_position_for_event", desc="get_position_for_event",
) )
return PersistedEventPosition( return PersistedEventPosition(row[1] or "master", row[0])
row["instance_name"] or "master", row["stream_ordering"]
)
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event """The stream token for an event
@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"), retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event", desc="get_topological_token_for_event",
) )
return RoomStreamToken( return RoomStreamToken(topological=row[1], stream=row[0])
topological=row["topological_ordering"], stream=row["stream_ordering"]
)
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream """Gets the topological token in a room after or at the given stream
@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict dict
""" """
results = self.db_pool.simple_select_one_txn( stream_ordering, topological_ordering = cast(
Tuple[int, int],
self.db_pool.simple_select_one_txn(
txn, txn,
"events", "events",
keyvalues={"event_id": event_id, "room_id": room_id}, keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"], retcols=["stream_ordering", "topological_ordering"],
),
) )
# This cannot happen as `allow_none=False`.
assert results is not None
# Paginating backwards includes the event at the token, but paginating # Paginating backwards includes the event at the token, but paginating
# forward doesn't. # forward doesn't.
before_token = RoomStreamToken( before_token = RoomStreamToken(
topological=results["topological_ordering"] - 1, topological=topological_ordering - 1, stream=stream_ordering
stream=results["stream_ordering"],
) )
after_token = RoomStreamToken( after_token = RoomStreamToken(
topological=results["topological_ordering"], topological=topological_ordering, stream=stream_ordering
stream=results["stream_ordering"],
) )
rows, start_token = self._paginate_room_events_txn( rows, start_token = self._paginate_room_events_txn(

View File

@ -183,7 +183,9 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: the task if available, `None` otherwise Returns: the task if available, `None` otherwise
""" """
row = await self.db_pool.simple_select_one( row = cast(
Optional[ScheduledTaskRow],
await self.db_pool.simple_select_one(
table="scheduled_tasks", table="scheduled_tasks",
keyvalues={"id": id}, keyvalues={"id": id},
retcols=( retcols=(
@ -198,24 +200,10 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
), ),
allow_none=True, allow_none=True,
desc="get_scheduled_task", desc="get_scheduled_task",
),
) )
return ( return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
TaskSchedulerWorkerStore._convert_row_to_task(
(
row["id"],
row["action"],
row["status"],
row["timestamp"],
row["resource_id"],
row["params"],
row["result"],
row["error"],
)
)
if row
else None
)
async def delete_scheduled_task(self, id: str) -> None: async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id. """Delete a specific task from its id.

View File

@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn, txn,
table="received_transactions", table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin}, keyvalues={"transaction_id": transaction_id, "origin": origin},
retcols=( retcols=("response_code", "response_json"),
"transaction_id",
"origin",
"ts",
"response_code",
"response_json",
"has_been_referenced",
),
allow_none=True, allow_none=True,
) )
if result and result["response_code"]: # If the result exists and the response code is non-0.
return result["response_code"], db_to_json(result["response_json"]) if result and result[0]:
return result[0], db_to_json(result[1])
else: else:
return None return None
@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
# check we have a row and retry_last_ts is not null or zero # check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative) # (retry_last_ts can't be negative)
if result and result["retry_last_ts"]: if result and result[1]:
return DestinationRetryTimings(**result) return DestinationRetryTimings(
failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
)
else: else:
return None return None

View File

@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session", desc="get_ui_auth_session",
) )
result["clientdict"] = db_to_json(result["clientdict"]) return UIAuthSessionData(
session_id,
return UIAuthSessionData(session_id, **result) clientdict=db_to_json(result[0]),
uri=result[1],
method=result[2],
description=result[3],
)
async def mark_ui_auth_stage_complete( async def mark_ui_auth_stage_complete(
self, self,
@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
self, txn: LoggingTransaction, session_id: str, key: str, value: Any self, txn: LoggingTransaction, session_id: str, key: str, value: Any
) -> None: ) -> None:
# Get the current value. # Get the current value.
result = cast( result = self.db_pool.simple_select_one_onecol_txn(
Dict[str, Any],
self.db_pool.simple_select_one_txn(
txn, txn,
table="ui_auth_sessions", table="ui_auth_sessions",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
retcols=("serverdict",), retcol="serverdict",
),
) )
# Update it and add it back to the database. # Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"]) serverdict = db_to_json(result)
serverdict[key] = value serverdict[key] = value
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises: Raises:
StoreError if the session cannot be found. StoreError if the session cannot be found.
""" """
result = await self.db_pool.simple_select_one( result = await self.db_pool.simple_select_one_onecol(
table="ui_auth_sessions", table="ui_auth_sessions",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
retcols=("serverdict",), retcol="serverdict",
desc="get_ui_auth_session_data", desc="get_ui_auth_session_data",
) )
serverdict = db_to_json(result["serverdict"]) serverdict = db_to_json(result)
return serverdict.get(key, default) return serverdict.get(key, default)

View File

@ -20,7 +20,6 @@ from typing import (
Collection, Collection,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"delete_all_from_user_dir", _delete_all_from_user_dir_txn "delete_all_from_user_dir", _delete_all_from_user_dir_txn
) )
async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: async def _get_user_in_directory(
return await self.db_pool.simple_select_one( self, user_id: str
) -> Optional[Tuple[Optional[str], Optional[str]]]:
"""
Fetch the user information in the user directory.
Returns:
None if the user is unknown, otherwise a tuple of display name and
avatar URL (both of which may be None).
"""
return cast(
Optional[Tuple[Optional[str], Optional[str]]],
await self.db_pool.simple_select_one(
table="user_directory", table="user_directory",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"), retcols=("display_name", "avatar_url"),
allow_none=True, allow_none=True,
desc="get_user_in_directory", desc="get_user_in_directory",
),
) )
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None: async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:

View File

@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
return self.get_success( row = self.get_success(
self.store.db_pool.simple_select_one( self.store.db_pool.simple_select_one(
table + "_current", table + "_current",
{id_col: stat_id}, {id_col: stat_id},
@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
) )
) )
return None if row is None else dict(zip(cols, row))
def _perform_background_initial_update(self) -> None: def _perform_background_initial_update(self) -> None:
# Do the initial population of the stats via the background update # Do the initial population of the stats via the background update
self._add_background_updates() self._add_background_updates()

View File

@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
profile = self.get_success(self.store._get_user_in_directory(regular_user_id)) profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
assert profile is not None assert profile is not None
self.assertTrue(profile["display_name"] == display_name) self.assertTrue(profile[0] == display_name)
def test_handle_local_profile_change_with_deactivated_user(self) -> None: def test_handle_local_profile_change_with_deactivated_user(self) -> None:
# create user # create user
@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is in directory # profile is in directory
profile = self.get_success(self.store._get_user_in_directory(r_user_id)) profile = self.get_success(self.store._get_user_in_directory(r_user_id))
assert profile is not None assert profile is not None
self.assertTrue(profile["display_name"] == display_name) self.assertEqual(profile[0], display_name)
# deactivate user # deactivate user
self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) self.get_success(self.store.set_user_deactivated_status(r_user_id, True))

View File

@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# is in user directory # is in user directory
profile = self.get_success(self.store._get_user_in_directory(self.other_user)) profile = self.get_success(self.store._get_user_in_directory(self.other_user))
assert profile is not None assert profile is not None
self.assertTrue(profile["display_name"] == "User") self.assertEqual(profile[0], "User")
# Deactivate user # Deactivate user
channel = self.make_request( channel = self.make_request(

View File

@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# #
# Note that we don't have the UI Auth session ID, so just pull out the single # Note that we don't have the UI Auth session ID, so just pull out the single
# row. # row.
ui_auth_data = self.get_success( result = self.get_success(
self.store.db_pool.simple_select_one( self.store.db_pool.simple_select_one_onecol(
"ui_auth_sessions", keyvalues={}, retcols=("clientdict",) "ui_auth_sessions", keyvalues={}, retcol="clientdict"
) )
) )
client_dict = db_to_json(ui_auth_data["clientdict"]) client_dict = db_to_json(result)
self.assertNotIn("new_password", client_dict) self.assertNotIn("new_password", client_dict)
@override_config({"rc_3pid_validation": {"burst_count": 3}}) @override_config({"rc_3pid_validation": {"burst_count": 3}})

View File

@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertLessEqual(det_data.items(), channel.json_body.items()) self.assertLessEqual(det_data.items(), channel.json_body.items())
# Check the `completed` counter has been incremented and pending is 0 # Check the `completed` counter has been incremented and pending is 0
res = self.get_success( pending, completed = self.get_success(
store.db_pool.simple_select_one( store.db_pool.simple_select_one(
"registration_tokens", "registration_tokens",
keyvalues={"token": token}, keyvalues={"token": token},
retcols=["pending", "completed"], retcols=["pending", "completed"],
) )
) )
self.assertEqual(res["completed"], 1) self.assertEqual(completed, 1)
self.assertEqual(res["pending"], 0) self.assertEqual(pending, 0)
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_token_invalid(self) -> None: def test_POST_registration_token_invalid(self) -> None:
@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
params1["auth"]["type"] = LoginType.DUMMY params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, params1) self.make_request(b"POST", self.url, params1)
# Check pending=0 and completed=1 # Check pending=0 and completed=1
res = self.get_success( pending, completed = self.get_success(
store.db_pool.simple_select_one( store.db_pool.simple_select_one(
"registration_tokens", "registration_tokens",
keyvalues={"token": token}, keyvalues={"token": token},
retcols=["pending", "completed"], retcols=["pending", "completed"],
) )
) )
self.assertEqual(res["pending"], 0) self.assertEqual(pending, 0)
self.assertEqual(res["completed"], 1) self.assertEqual(completed, 1)
# Check auth still fails when using token with session2 # Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2) channel = self.make_request(b"POST", self.url, params2)

View File

@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
) )
self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) self.assertEqual((1, 2, 3), ret)
self.mock_txn.execute.assert_called_once_with( self.mock_txn.execute.assert_called_once_with(
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
) )
@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
) )
self.assertFalse(ret) self.assertIsNone(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:

View File

@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase):
) )
def test_get_room(self) -> None: def test_get_room(self) -> None:
res = self.get_success(self.store.get_room(self.room.to_string())) room = self.get_success(self.store.get_room(self.room.to_string()))
assert res is not None assert room is not None
self.assertLessEqual( self.assertTrue(room[0])
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True,
}.items(),
res.items(),
)
def test_get_room_unknown_room(self) -> None: def test_get_room_unknown_room(self) -> None:
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))