Convert simple_select_one_txn and simple_select_one to return tuples. (#16612)
This commit is contained in:
parent
ff716b483b
commit
ab3f1b3b53
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints.
|
|
@ -348,8 +348,7 @@ class Porter:
|
||||||
backward_chunk = 0
|
backward_chunk = 0
|
||||||
already_ported = 0
|
already_ported = 0
|
||||||
else:
|
else:
|
||||||
forward_chunk = row["forward_rowid"]
|
forward_chunk, backward_chunk = row
|
||||||
backward_chunk = row["backward_rowid"]
|
|
||||||
|
|
||||||
if total_to_port is None:
|
if total_to_port is None:
|
||||||
already_ported, total_to_port = await self._get_total_count_to_port(
|
already_ported, total_to_port = await self._get_total_count_to_port(
|
||||||
|
|
|
@ -269,7 +269,7 @@ class RoomCreationHandler:
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
old_room_id: str,
|
old_room_id: str,
|
||||||
old_room: Dict[str, Any],
|
old_room: Tuple[bool, str, bool],
|
||||||
new_room_id: str,
|
new_room_id: str,
|
||||||
new_version: RoomVersion,
|
new_version: RoomVersion,
|
||||||
tombstone_event: EventBase,
|
tombstone_event: EventBase,
|
||||||
|
@ -279,7 +279,7 @@ class RoomCreationHandler:
|
||||||
Args:
|
Args:
|
||||||
requester: the user requesting the upgrade
|
requester: the user requesting the upgrade
|
||||||
old_room_id: the id of the room to be replaced
|
old_room_id: the id of the room to be replaced
|
||||||
old_room: a dict containing room information for the room to be replaced,
|
old_room: a tuple containing room information for the room to be replaced,
|
||||||
as returned by `RoomWorkerStore.get_room`.
|
as returned by `RoomWorkerStore.get_room`.
|
||||||
new_room_id: the id of the replacement room
|
new_room_id: the id of the replacement room
|
||||||
new_version: the version to upgrade the room to
|
new_version: the version to upgrade the room to
|
||||||
|
@ -299,7 +299,7 @@ class RoomCreationHandler:
|
||||||
await self.store.store_room(
|
await self.store.store_room(
|
||||||
room_id=new_room_id,
|
room_id=new_room_id,
|
||||||
room_creator_user_id=user_id,
|
room_creator_user_id=user_id,
|
||||||
is_public=old_room["is_public"],
|
is_public=old_room[0],
|
||||||
room_version=new_version,
|
room_version=new_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
# Add new room to the room directory if the old room was there
|
# Add new room to the room directory if the old room was there
|
||||||
# Remove old room from the room directory
|
# Remove old room from the room directory
|
||||||
old_room = await self.store.get_room(old_room_id)
|
old_room = await self.store.get_room(old_room_id)
|
||||||
if old_room is not None and old_room["is_public"]:
|
# If the old room exists and is public.
|
||||||
|
if old_room is not None and old_room[0]:
|
||||||
await self.store.set_room_is_public(old_room_id, False)
|
await self.store.set_room_is_public(old_room_id, False)
|
||||||
await self.store.set_room_is_public(room_id, True)
|
await self.store.set_room_is_public(room_id, True)
|
||||||
|
|
||||||
|
|
|
@ -1860,7 +1860,8 @@ class PublicRoomListManager:
|
||||||
if not room:
|
if not room:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return room.get("is_public", False)
|
# The first item is whether the room is public.
|
||||||
|
return room[0]
|
||||||
|
|
||||||
async def add_room_to_public_room_list(self, room_id: str) -> None:
|
async def add_room_to_public_room_list(self, room_id: str) -> None:
|
||||||
"""Publishes a room to the public room list.
|
"""Publishes a room to the public room list.
|
||||||
|
|
|
@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
ret = await self.store.get_room(room_id)
|
room = await self.store.get_room(room_id)
|
||||||
if not ret:
|
if not room:
|
||||||
raise NotFoundError("Room not found")
|
raise NotFoundError("Room not found")
|
||||||
|
|
||||||
members = await self.store.get_users_in_room(room_id)
|
members = await self.store.get_users_in_room(room_id)
|
||||||
|
@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
ret = await self.store.get_room(room_id)
|
room = await self.store.get_room(room_id)
|
||||||
if not ret:
|
if not room:
|
||||||
raise NotFoundError("Room not found")
|
raise NotFoundError("Room not found")
|
||||||
|
|
||||||
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
||||||
|
|
|
@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
|
||||||
if room is None:
|
if room is None:
|
||||||
raise NotFoundError("Unknown room")
|
raise NotFoundError("Unknown room")
|
||||||
|
|
||||||
return 200, {"visibility": "public" if room["is_public"] else "private"}
|
return 200, {"visibility": "public" if room[0] else "private"}
|
||||||
|
|
||||||
class PutBody(RequestBodyModel):
|
class PutBody(RequestBodyModel):
|
||||||
visibility: Literal["public", "private"] = "public"
|
visibility: Literal["public", "private"] = "public"
|
||||||
|
|
|
@ -1597,7 +1597,7 @@ class DatabasePool:
|
||||||
retcols: Collection[str],
|
retcols: Collection[str],
|
||||||
allow_none: Literal[False] = False,
|
allow_none: Literal[False] = False,
|
||||||
desc: str = "simple_select_one",
|
desc: str = "simple_select_one",
|
||||||
) -> Dict[str, Any]:
|
) -> Tuple[Any, ...]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
@ -1608,7 +1608,7 @@ class DatabasePool:
|
||||||
retcols: Collection[str],
|
retcols: Collection[str],
|
||||||
allow_none: Literal[True] = True,
|
allow_none: Literal[True] = True,
|
||||||
desc: str = "simple_select_one",
|
desc: str = "simple_select_one",
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Tuple[Any, ...]]:
|
||||||
...
|
...
|
||||||
|
|
||||||
async def simple_select_one(
|
async def simple_select_one(
|
||||||
|
@ -1618,7 +1618,7 @@ class DatabasePool:
|
||||||
retcols: Collection[str],
|
retcols: Collection[str],
|
||||||
allow_none: bool = False,
|
allow_none: bool = False,
|
||||||
desc: str = "simple_select_one",
|
desc: str = "simple_select_one",
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Tuple[Any, ...]]:
|
||||||
"""Executes a SELECT query on the named table, which is expected to
|
"""Executes a SELECT query on the named table, which is expected to
|
||||||
return a single row, returning multiple columns from it.
|
return a single row, returning multiple columns from it.
|
||||||
|
|
||||||
|
@ -2127,7 +2127,7 @@ class DatabasePool:
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
retcols: Collection[str],
|
retcols: Collection[str],
|
||||||
allow_none: bool = False,
|
allow_none: bool = False,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Tuple[Any, ...]]:
|
||||||
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
||||||
|
|
||||||
if keyvalues:
|
if keyvalues:
|
||||||
|
@ -2145,7 +2145,7 @@ class DatabasePool:
|
||||||
if txn.rowcount > 1:
|
if txn.rowcount > 1:
|
||||||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||||
|
|
||||||
return dict(zip(retcols, row))
|
return row
|
||||||
|
|
||||||
async def simple_delete_one(
|
async def simple_delete_one(
|
||||||
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
|
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
|
||||||
|
|
|
@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
A dict containing the device information, or `None` if the device does not
|
A dict containing the device information, or `None` if the device does not
|
||||||
exist.
|
exist.
|
||||||
"""
|
"""
|
||||||
return await self.db_pool.simple_select_one(
|
row = await self.db_pool.simple_select_one(
|
||||||
table="devices",
|
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
|
||||||
retcols=("user_id", "device_id", "display_name"),
|
|
||||||
desc="get_device",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_device_opt(
|
|
||||||
self, user_id: str, device_id: str
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
"""Retrieve a device. Only returns devices that are not marked as
|
|
||||||
hidden.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The ID of the user which owns the device
|
|
||||||
device_id: The ID of the device to retrieve
|
|
||||||
Returns:
|
|
||||||
A dict containing the device information, or None if the device does not exist.
|
|
||||||
"""
|
|
||||||
return await self.db_pool.simple_select_one(
|
|
||||||
table="devices",
|
table="devices",
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||||
retcols=("user_id", "device_id", "display_name"),
|
retcols=("user_id", "device_id", "display_name"),
|
||||||
desc="get_device",
|
desc="get_device",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
|
||||||
|
|
||||||
async def get_devices_by_user(
|
async def get_devices_by_user(
|
||||||
self, user_id: str
|
self, user_id: str
|
||||||
|
@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
retcols=["device_id", "device_data"],
|
retcols=["device_id", "device_data"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
return (
|
return (row[0], json_decoder.decode(row[1])) if row else None
|
||||||
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _store_dehydrated_device_txn(
|
def _store_dehydrated_device_txn(
|
||||||
self,
|
self,
|
||||||
|
@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
`FALSE` have not been converted.
|
`FALSE` have not been converted.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
row = await self.db_pool.simple_select_one(
|
return cast(
|
||||||
|
Tuple[int, str],
|
||||||
|
await self.db_pool.simple_select_one(
|
||||||
table="device_lists_changes_converted_stream_position",
|
table="device_lists_changes_converted_stream_position",
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=["stream_id", "room_id"],
|
retcols=["stream_id", "room_id"],
|
||||||
desc="get_device_change_last_converted_pos",
|
desc="get_device_change_last_converted_pos",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return row["stream_id"], row["room_id"]
|
|
||||||
|
|
||||||
async def set_device_change_last_converted_pos(
|
async def set_device_change_last_converted_pos(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
|
||||||
# it isn't there.
|
# it isn't there.
|
||||||
raise StoreError(404, "No backup with that version exists")
|
raise StoreError(404, "No backup with that version exists")
|
||||||
|
|
||||||
result = self.db_pool.simple_select_one_txn(
|
row = cast(
|
||||||
|
Tuple[int, str, str, Optional[int]],
|
||||||
|
self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="e2e_room_keys_versions",
|
table="e2e_room_keys_versions",
|
||||||
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"version": this_version,
|
||||||
|
"deleted": 0,
|
||||||
|
},
|
||||||
retcols=("version", "algorithm", "auth_data", "etag"),
|
retcols=("version", "algorithm", "auth_data", "etag"),
|
||||||
allow_none=False,
|
allow_none=False,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
assert result is not None # see comment on `simple_select_one_txn`
|
return {
|
||||||
result["auth_data"] = db_to_json(result["auth_data"])
|
"auth_data": db_to_json(row[2]),
|
||||||
result["version"] = str(result["version"])
|
"version": str(row[0]),
|
||||||
if result["etag"] is None:
|
"algorithm": row[1],
|
||||||
result["etag"] = 0
|
"etag": 0 if row[3] is None else row[3],
|
||||||
return result
|
}
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
|
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
|
||||||
|
|
|
@ -1266,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
if row is None:
|
if row is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
key_id = row["key_id"]
|
key_id, key_json, used = row
|
||||||
key_json = row["key_json"]
|
|
||||||
used = row["used"]
|
|
||||||
|
|
||||||
# Mark fallback key as used if not already.
|
# Mark fallback key as used if not already.
|
||||||
if not used and mark_as_used:
|
if not used and mark_as_used:
|
||||||
|
|
|
@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
# Check if we have indexed the room so we can use the chain cover
|
# Check if we have indexed the room so we can use the chain cover
|
||||||
# algorithm.
|
# algorithm.
|
||||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||||
if room["has_auth_chain_index"]:
|
# If the room has an auth chain index.
|
||||||
|
if room[1]:
|
||||||
try:
|
try:
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_auth_chain_ids_chains",
|
"get_auth_chain_ids_chains",
|
||||||
|
@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
# Check if we have indexed the room so we can use the chain cover
|
# Check if we have indexed the room so we can use the chain cover
|
||||||
# algorithm.
|
# algorithm.
|
||||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||||
if room["has_auth_chain_index"]:
|
# If the room has an auth chain index.
|
||||||
|
if room[1]:
|
||||||
try:
|
try:
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_auth_chain_difference_chains",
|
"get_auth_chain_difference_chains",
|
||||||
|
@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
|
|
||||||
if event_lookup_result is not None:
|
if event_lookup_result is not None:
|
||||||
|
event_type, depth, stream_ordering = event_lookup_result
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
|
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
|
||||||
room_id,
|
room_id,
|
||||||
seed_event_id,
|
seed_event_id,
|
||||||
event_lookup_result["depth"],
|
depth,
|
||||||
event_lookup_result["stream_ordering"],
|
stream_ordering,
|
||||||
event_lookup_result["type"],
|
event_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
if event_lookup_result["depth"]:
|
if depth:
|
||||||
queue.put(
|
queue.put((-depth, -stream_ordering, seed_event_id, event_type))
|
||||||
(
|
|
||||||
-event_lookup_result["depth"],
|
|
||||||
-event_lookup_result["stream_ordering"],
|
|
||||||
seed_event_id,
|
|
||||||
event_lookup_result["type"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
while not queue.empty() and len(event_id_results) < limit:
|
while not queue.empty() and len(event_id_results) < limit:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1934,8 +1934,7 @@ class PersistEventsStore:
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
redacted_relates_to = row["relates_to_id"]
|
redacted_relates_to, rel_type = row
|
||||||
rel_type = row["relation_type"]
|
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
|
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
|
||||||
)
|
)
|
||||||
|
|
|
@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
if not res:
|
if not res:
|
||||||
raise SynapseError(404, "Could not find event %s" % (event_id,))
|
raise SynapseError(404, "Could not find event %s" % (event_id,))
|
||||||
|
|
||||||
return int(res["topological_ordering"]), int(res["stream_ordering"])
|
return int(res[0]), int(res[1])
|
||||||
|
|
||||||
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
|
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
|
||||||
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
|
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
|
||||||
|
|
|
@ -208,7 +208,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
return LocalMedia(media_id=media_id, **row)
|
return LocalMedia(
|
||||||
|
media_id=media_id,
|
||||||
|
media_type=row[0],
|
||||||
|
media_length=row[1],
|
||||||
|
upload_name=row[2],
|
||||||
|
created_ts=row[3],
|
||||||
|
quarantined_by=row[4],
|
||||||
|
url_cache=row[5],
|
||||||
|
last_access_ts=row[6],
|
||||||
|
safe_from_quarantine=row[7],
|
||||||
|
)
|
||||||
|
|
||||||
async def get_local_media_by_user_paginate(
|
async def get_local_media_by_user_paginate(
|
||||||
self,
|
self,
|
||||||
|
@ -541,7 +551,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
if row is None:
|
if row is None:
|
||||||
return row
|
return row
|
||||||
return RemoteMedia(media_origin=origin, media_id=media_id, **row)
|
return RemoteMedia(
|
||||||
|
media_origin=origin,
|
||||||
|
media_id=media_id,
|
||||||
|
media_type=row[0],
|
||||||
|
media_length=row[1],
|
||||||
|
upload_name=row[2],
|
||||||
|
created_ts=row[3],
|
||||||
|
filesystem_id=row[4],
|
||||||
|
last_access_ts=row[5],
|
||||||
|
quarantined_by=row[6],
|
||||||
|
)
|
||||||
|
|
||||||
async def store_cached_remote_media(
|
async def store_cached_remote_media(
|
||||||
self,
|
self,
|
||||||
|
@ -665,11 +685,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
return ThumbnailInfo(
|
return ThumbnailInfo(
|
||||||
width=row["thumbnail_width"],
|
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
|
||||||
height=row["thumbnail_height"],
|
|
||||||
method=row["thumbnail_method"],
|
|
||||||
type=row["thumbnail_type"],
|
|
||||||
length=row["thumbnail_length"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import (
|
from synapse.storage.database import (
|
||||||
DatabasePool,
|
DatabasePool,
|
||||||
|
@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||||
return 50
|
return 50
|
||||||
|
|
||||||
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
|
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
|
||||||
try:
|
|
||||||
profile = await self.db_pool.simple_select_one(
|
profile = await self.db_pool.simple_select_one(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"full_user_id": user_id.to_string()},
|
keyvalues={"full_user_id": user_id.to_string()},
|
||||||
retcols=("displayname", "avatar_url"),
|
retcols=("displayname", "avatar_url"),
|
||||||
desc="get_profileinfo",
|
desc="get_profileinfo",
|
||||||
|
allow_none=True,
|
||||||
)
|
)
|
||||||
except StoreError as e:
|
if profile is None:
|
||||||
if e.code == 404:
|
|
||||||
# no match
|
# no match
|
||||||
return ProfileInfo(None, None)
|
return ProfileInfo(None, None)
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
return ProfileInfo(
|
return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
|
||||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
|
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
|
||||||
return await self.db_pool.simple_select_one_onecol(
|
return await self.db_pool.simple_select_one_onecol(
|
||||||
|
|
|
@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
"before/after rule not found: %s" % (relative_to_rule,)
|
"before/after rule not found: %s" % (relative_to_rule,)
|
||||||
)
|
)
|
||||||
|
|
||||||
base_priority_class = res["priority_class"]
|
base_priority_class, base_rule_priority = res
|
||||||
base_rule_priority = res["priority"]
|
|
||||||
|
|
||||||
if base_priority_class != priority_class:
|
if base_priority_class != priority_class:
|
||||||
raise InconsistentRuleException(
|
raise InconsistentRuleException(
|
||||||
|
|
|
@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
stream_ordering = int(res["stream_ordering"]) if res else None
|
stream_ordering = int(res[0]) if res else None
|
||||||
rx_ts = res["received_ts"] if res else 0
|
rx_ts = res[1] if res else 0
|
||||||
|
|
||||||
# We don't want to clobber receipts for more recent events, so we
|
# We don't want to clobber receipts for more recent events, so we
|
||||||
# have to compare orderings of existing receipts
|
# have to compare orderings of existing receipts
|
||||||
|
|
|
@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
account timestamp as milliseconds since the epoch. None if the account
|
account timestamp as milliseconds since the epoch. None if the account
|
||||||
has not been renewed using the current token yet.
|
has not been renewed using the current token yet.
|
||||||
"""
|
"""
|
||||||
ret_dict = await self.db_pool.simple_select_one(
|
return cast(
|
||||||
|
Tuple[str, int, Optional[int]],
|
||||||
|
await self.db_pool.simple_select_one(
|
||||||
table="account_validity",
|
table="account_validity",
|
||||||
keyvalues={"renewal_token": renewal_token},
|
keyvalues={"renewal_token": renewal_token},
|
||||||
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
|
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
|
||||||
desc="get_user_from_renewal_token",
|
desc="get_user_from_renewal_token",
|
||||||
)
|
),
|
||||||
|
|
||||||
return (
|
|
||||||
ret_dict["user_id"],
|
|
||||||
ret_dict["expiration_ts_ms"],
|
|
||||||
ret_dict["token_used_ts_ms"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_renewal_token_for_user(self, user_id: str) -> str:
|
async def get_renewal_token_for_user(self, user_id: str) -> str:
|
||||||
|
@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
Returns:
|
Returns:
|
||||||
user id, or None if no user id/threepid mapping exists
|
user id, or None if no user id/threepid mapping exists
|
||||||
"""
|
"""
|
||||||
ret = self.db_pool.simple_select_one_txn(
|
return self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
"user_threepids",
|
"user_threepids",
|
||||||
{"medium": medium, "address": address},
|
{"medium": medium, "address": address},
|
||||||
["user_id"],
|
"user_id",
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
if ret:
|
|
||||||
return ret["user_id"]
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def user_add_threepid(
|
async def user_add_threepid(
|
||||||
self,
|
self,
|
||||||
|
@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
if res is None:
|
if res is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
uses_allowed, pending, completed, expiry_time = res
|
||||||
|
|
||||||
# Check if the token has expired
|
# Check if the token has expired
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
if res["expiry_time"] and res["expiry_time"] < now:
|
if expiry_time and expiry_time < now:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if the token has been used up
|
# Check if the token has been used up
|
||||||
if (
|
if uses_allowed and pending + completed >= uses_allowed:
|
||||||
res["uses_allowed"]
|
|
||||||
and res["pending"] + res["completed"] >= res["uses_allowed"]
|
|
||||||
):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Otherwise, the token is valid
|
# Otherwise, the token is valid
|
||||||
|
@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
# Override type because the return type is only optional if
|
# Override type because the return type is only optional if
|
||||||
# allow_none is True, and we don't want mypy throwing errors
|
# allow_none is True, and we don't want mypy throwing errors
|
||||||
# about None not being indexable.
|
# about None not being indexable.
|
||||||
res = cast(
|
pending, completed = cast(
|
||||||
Dict[str, Any],
|
Tuple[int, int],
|
||||||
self.db_pool.simple_select_one_txn(
|
self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
"registration_tokens",
|
"registration_tokens",
|
||||||
|
@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
"registration_tokens",
|
"registration_tokens",
|
||||||
keyvalues={"token": token},
|
keyvalues={"token": token},
|
||||||
updatevalues={
|
updatevalues={
|
||||||
"completed": res["completed"] + 1,
|
"completed": completed + 1,
|
||||||
"pending": res["pending"] - 1,
|
"pending": pending - 1,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
Returns:
|
Returns:
|
||||||
A dict, or None if token doesn't exist.
|
A dict, or None if token doesn't exist.
|
||||||
"""
|
"""
|
||||||
return await self.db_pool.simple_select_one(
|
row = await self.db_pool.simple_select_one(
|
||||||
"registration_tokens",
|
"registration_tokens",
|
||||||
keyvalues={"token": token},
|
keyvalues={"token": token},
|
||||||
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
|
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="get_one_registration_token",
|
desc="get_one_registration_token",
|
||||||
)
|
)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"token": row[0],
|
||||||
|
"uses_allowed": row[1],
|
||||||
|
"pending": row[2],
|
||||||
|
"completed": row[3],
|
||||||
|
"expiry_time": row[4],
|
||||||
|
}
|
||||||
|
|
||||||
async def generate_registration_token(
|
async def generate_registration_token(
|
||||||
self, length: int, chars: str
|
self, length: int, chars: str
|
||||||
|
@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get all info about the token so it can be sent in the response
|
# Get all info about the token so it can be sent in the response
|
||||||
return self.db_pool.simple_select_one_txn(
|
result = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
"registration_tokens",
|
"registration_tokens",
|
||||||
keyvalues={"token": token},
|
keyvalues={"token": token},
|
||||||
|
@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
return {
|
||||||
|
"token": result[0],
|
||||||
|
"uses_allowed": result[1],
|
||||||
|
"pending": result[2],
|
||||||
|
"completed": result[3],
|
||||||
|
"expiry_time": result[4],
|
||||||
|
}
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"update_registration_token", _update_registration_token_txn
|
"update_registration_token", _update_registration_token_txn
|
||||||
)
|
)
|
||||||
|
@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
keyvalues={"token": token},
|
keyvalues={"token": token},
|
||||||
updatevalues={"used_ts": ts},
|
updatevalues={"used_ts": ts},
|
||||||
)
|
)
|
||||||
user_id = values["user_id"]
|
(
|
||||||
expiry_ts = values["expiry_ts"]
|
user_id,
|
||||||
used_ts = values["used_ts"]
|
expiry_ts,
|
||||||
auth_provider_id = values["auth_provider_id"]
|
used_ts,
|
||||||
auth_provider_session_id = values["auth_provider_session_id"]
|
auth_provider_id,
|
||||||
|
auth_provider_session_id,
|
||||||
|
) = values
|
||||||
|
|
||||||
# Token was already used
|
# Token was already used
|
||||||
if used_ts is not None:
|
if used_ts is not None:
|
||||||
|
@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
# reason, the next check is on the client secret, which is NOT NULL,
|
# reason, the next check is on the client secret, which is NOT NULL,
|
||||||
# so we don't have to worry about the client secret matching by
|
# so we don't have to worry about the client secret matching by
|
||||||
# accident.
|
# accident.
|
||||||
row = {"client_secret": None, "validated_at": None}
|
row = None, None
|
||||||
else:
|
else:
|
||||||
raise ThreepidValidationError("Unknown session_id")
|
raise ThreepidValidationError("Unknown session_id")
|
||||||
|
|
||||||
retrieved_client_secret = row["client_secret"]
|
retrieved_client_secret, validated_at = row
|
||||||
validated_at = row["validated_at"]
|
|
||||||
|
|
||||||
row = self.db_pool.simple_select_one_txn(
|
row = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
raise ThreepidValidationError(
|
raise ThreepidValidationError(
|
||||||
"Validation token not found or has expired"
|
"Validation token not found or has expired"
|
||||||
)
|
)
|
||||||
expires = row["expires"]
|
expires, next_link = row
|
||||||
next_link = row["next_link"]
|
|
||||||
|
|
||||||
if retrieved_client_secret != client_secret:
|
if retrieved_client_secret != client_secret:
|
||||||
raise ThreepidValidationError(
|
raise ThreepidValidationError(
|
||||||
|
|
|
@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||||
raise StoreError(500, "Problem creating room.")
|
raise StoreError(500, "Problem creating room.")
|
||||||
|
|
||||||
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
|
async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
|
||||||
"""Retrieve a room.
|
"""Retrieve a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id: The ID of the room to retrieve.
|
room_id: The ID of the room to retrieve.
|
||||||
Returns:
|
Returns:
|
||||||
A dict containing the room information, or None if the room is unknown.
|
A tuple containing the room information:
|
||||||
|
* True if the room is public
|
||||||
|
* True if the room has an auth chain index
|
||||||
|
|
||||||
|
or None if the room is unknown.
|
||||||
"""
|
"""
|
||||||
return await self.db_pool.simple_select_one(
|
row = cast(
|
||||||
|
Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
|
||||||
|
await self.db_pool.simple_select_one(
|
||||||
table="rooms",
|
table="rooms",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={"room_id": room_id},
|
||||||
retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
|
retcols=("is_public", "has_auth_chain_index"),
|
||||||
desc="get_room",
|
desc="get_room",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
if row is None:
|
||||||
|
return row
|
||||||
|
return bool(row[0]), bool(row[1])
|
||||||
|
|
||||||
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
|
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
|
||||||
"""Retrieve room with statistics.
|
"""Retrieve room with statistics.
|
||||||
|
@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
return RatelimitOverride(
|
return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
|
||||||
messages_per_second=row["messages_per_second"],
|
|
||||||
burst_count=row["burst_count"],
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
join.
|
join.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await self.db_pool.simple_select_one(
|
return cast(
|
||||||
|
Tuple[str, int],
|
||||||
|
await self.db_pool.simple_select_one(
|
||||||
table="partial_state_rooms",
|
table="partial_state_rooms",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={"room_id": room_id},
|
||||||
retcols=("join_event_id", "device_lists_stream_id"),
|
retcols=("join_event_id", "device_lists_stream_id"),
|
||||||
desc="get_join_event_id_for_partial_state",
|
desc="get_join_event_id_for_partial_state",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return result["join_event_id"], result["device_lists_stream_id"]
|
|
||||||
|
|
||||||
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
|
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
|
||||||
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
|
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
|
||||||
|
|
|
@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
"non-local user %s" % (user_id,),
|
"non-local user %s" % (user_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
results_dict = await self.db_pool.simple_select_one(
|
results = cast(
|
||||||
|
Optional[Tuple[str, str]],
|
||||||
|
await self.db_pool.simple_select_one(
|
||||||
"local_current_membership",
|
"local_current_membership",
|
||||||
{"room_id": room_id, "user_id": user_id},
|
{"room_id": room_id, "user_id": user_id},
|
||||||
("membership", "event_id"),
|
("membership", "event_id"),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="get_local_current_membership_for_user_in_room",
|
desc="get_local_current_membership_for_user_in_room",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if not results_dict:
|
if not results:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
return results_dict.get("membership"), results_dict.get("event_id")
|
return results
|
||||||
|
|
||||||
@cached(max_entries=500000, iterable=True)
|
@cached(max_entries=500000, iterable=True)
|
||||||
async def get_rooms_for_user_with_stream_ordering(
|
async def get_rooms_for_user_with_stream_ordering(
|
||||||
|
|
|
@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
desc="get_position_for_event",
|
desc="get_position_for_event",
|
||||||
)
|
)
|
||||||
|
|
||||||
return PersistedEventPosition(
|
return PersistedEventPosition(row[1] or "master", row[0])
|
||||||
row["instance_name"] or "master", row["stream_ordering"]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
|
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
|
||||||
"""The stream token for an event
|
"""The stream token for an event
|
||||||
|
@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
retcols=("stream_ordering", "topological_ordering"),
|
retcols=("stream_ordering", "topological_ordering"),
|
||||||
desc="get_topological_token_for_event",
|
desc="get_topological_token_for_event",
|
||||||
)
|
)
|
||||||
return RoomStreamToken(
|
return RoomStreamToken(topological=row[1], stream=row[0])
|
||||||
topological=row["topological_ordering"], stream=row["stream_ordering"]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
|
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
|
||||||
"""Gets the topological token in a room after or at the given stream
|
"""Gets the topological token in a room after or at the given stream
|
||||||
|
@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
dict
|
dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = self.db_pool.simple_select_one_txn(
|
stream_ordering, topological_ordering = cast(
|
||||||
|
Tuple[int, int],
|
||||||
|
self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
"events",
|
"events",
|
||||||
keyvalues={"event_id": event_id, "room_id": room_id},
|
keyvalues={"event_id": event_id, "room_id": room_id},
|
||||||
retcols=["stream_ordering", "topological_ordering"],
|
retcols=["stream_ordering", "topological_ordering"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# This cannot happen as `allow_none=False`.
|
|
||||||
assert results is not None
|
|
||||||
|
|
||||||
# Paginating backwards includes the event at the token, but paginating
|
# Paginating backwards includes the event at the token, but paginating
|
||||||
# forward doesn't.
|
# forward doesn't.
|
||||||
before_token = RoomStreamToken(
|
before_token = RoomStreamToken(
|
||||||
topological=results["topological_ordering"] - 1,
|
topological=topological_ordering - 1, stream=stream_ordering
|
||||||
stream=results["stream_ordering"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
after_token = RoomStreamToken(
|
after_token = RoomStreamToken(
|
||||||
topological=results["topological_ordering"],
|
topological=topological_ordering, stream=stream_ordering
|
||||||
stream=results["stream_ordering"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
rows, start_token = self._paginate_room_events_txn(
|
rows, start_token = self._paginate_room_events_txn(
|
||||||
|
|
|
@ -183,7 +183,9 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
Returns: the task if available, `None` otherwise
|
Returns: the task if available, `None` otherwise
|
||||||
"""
|
"""
|
||||||
row = await self.db_pool.simple_select_one(
|
row = cast(
|
||||||
|
Optional[ScheduledTaskRow],
|
||||||
|
await self.db_pool.simple_select_one(
|
||||||
table="scheduled_tasks",
|
table="scheduled_tasks",
|
||||||
keyvalues={"id": id},
|
keyvalues={"id": id},
|
||||||
retcols=(
|
retcols=(
|
||||||
|
@ -198,24 +200,10 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
||||||
),
|
),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="get_scheduled_task",
|
desc="get_scheduled_task",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
|
||||||
TaskSchedulerWorkerStore._convert_row_to_task(
|
|
||||||
(
|
|
||||||
row["id"],
|
|
||||||
row["action"],
|
|
||||||
row["status"],
|
|
||||||
row["timestamp"],
|
|
||||||
row["resource_id"],
|
|
||||||
row["params"],
|
|
||||||
row["result"],
|
|
||||||
row["error"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if row
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
async def delete_scheduled_task(self, id: str) -> None:
|
async def delete_scheduled_task(self, id: str) -> None:
|
||||||
"""Delete a specific task from its id.
|
"""Delete a specific task from its id.
|
||||||
|
|
|
@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
txn,
|
txn,
|
||||||
table="received_transactions",
|
table="received_transactions",
|
||||||
keyvalues={"transaction_id": transaction_id, "origin": origin},
|
keyvalues={"transaction_id": transaction_id, "origin": origin},
|
||||||
retcols=(
|
retcols=("response_code", "response_json"),
|
||||||
"transaction_id",
|
|
||||||
"origin",
|
|
||||||
"ts",
|
|
||||||
"response_code",
|
|
||||||
"response_json",
|
|
||||||
"has_been_referenced",
|
|
||||||
),
|
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result and result["response_code"]:
|
# If the result exists and the response code is non-0.
|
||||||
return result["response_code"], db_to_json(result["response_json"])
|
if result and result[0]:
|
||||||
|
return result[0], db_to_json(result[1])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
# check we have a row and retry_last_ts is not null or zero
|
# check we have a row and retry_last_ts is not null or zero
|
||||||
# (retry_last_ts can't be negative)
|
# (retry_last_ts can't be negative)
|
||||||
if result and result["retry_last_ts"]:
|
if result and result[1]:
|
||||||
return DestinationRetryTimings(**result)
|
return DestinationRetryTimings(
|
||||||
|
failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||||
desc="get_ui_auth_session",
|
desc="get_ui_auth_session",
|
||||||
)
|
)
|
||||||
|
|
||||||
result["clientdict"] = db_to_json(result["clientdict"])
|
return UIAuthSessionData(
|
||||||
|
session_id,
|
||||||
return UIAuthSessionData(session_id, **result)
|
clientdict=db_to_json(result[0]),
|
||||||
|
uri=result[1],
|
||||||
|
method=result[2],
|
||||||
|
description=result[3],
|
||||||
|
)
|
||||||
|
|
||||||
async def mark_ui_auth_stage_complete(
|
async def mark_ui_auth_stage_complete(
|
||||||
self,
|
self,
|
||||||
|
@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||||
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
|
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
# Get the current value.
|
# Get the current value.
|
||||||
result = cast(
|
result = self.db_pool.simple_select_one_onecol_txn(
|
||||||
Dict[str, Any],
|
|
||||||
self.db_pool.simple_select_one_txn(
|
|
||||||
txn,
|
txn,
|
||||||
table="ui_auth_sessions",
|
table="ui_auth_sessions",
|
||||||
keyvalues={"session_id": session_id},
|
keyvalues={"session_id": session_id},
|
||||||
retcols=("serverdict",),
|
retcol="serverdict",
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update it and add it back to the database.
|
# Update it and add it back to the database.
|
||||||
serverdict = db_to_json(result["serverdict"])
|
serverdict = db_to_json(result)
|
||||||
serverdict[key] = value
|
serverdict[key] = value
|
||||||
|
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
|
@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if the session cannot be found.
|
StoreError if the session cannot be found.
|
||||||
"""
|
"""
|
||||||
result = await self.db_pool.simple_select_one(
|
result = await self.db_pool.simple_select_one_onecol(
|
||||||
table="ui_auth_sessions",
|
table="ui_auth_sessions",
|
||||||
keyvalues={"session_id": session_id},
|
keyvalues={"session_id": session_id},
|
||||||
retcols=("serverdict",),
|
retcol="serverdict",
|
||||||
desc="get_ui_auth_session_data",
|
desc="get_ui_auth_session_data",
|
||||||
)
|
)
|
||||||
|
|
||||||
serverdict = db_to_json(result["serverdict"])
|
serverdict = db_to_json(result)
|
||||||
|
|
||||||
return serverdict.get(key, default)
|
return serverdict.get(key, default)
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,6 @@ from typing import (
|
||||||
Collection,
|
Collection,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
|
@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
|
async def _get_user_in_directory(
|
||||||
return await self.db_pool.simple_select_one(
|
self, user_id: str
|
||||||
|
) -> Optional[Tuple[Optional[str], Optional[str]]]:
|
||||||
|
"""
|
||||||
|
Fetch the user information in the user directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None if the user is unknown, otherwise a tuple of display name and
|
||||||
|
avatar URL (both of which may be None).
|
||||||
|
"""
|
||||||
|
return cast(
|
||||||
|
Optional[Tuple[Optional[str], Optional[str]]],
|
||||||
|
await self.db_pool.simple_select_one(
|
||||||
table="user_directory",
|
table="user_directory",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
retcols=("display_name", "avatar_url"),
|
retcols=("display_name", "avatar_url"),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="get_user_in_directory",
|
desc="get_user_in_directory",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
|
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
|
||||||
|
|
|
@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
|
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
|
||||||
|
|
||||||
return self.get_success(
|
row = self.get_success(
|
||||||
self.store.db_pool.simple_select_one(
|
self.store.db_pool.simple_select_one(
|
||||||
table + "_current",
|
table + "_current",
|
||||||
{id_col: stat_id},
|
{id_col: stat_id},
|
||||||
|
@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return None if row is None else dict(zip(cols, row))
|
||||||
|
|
||||||
def _perform_background_initial_update(self) -> None:
|
def _perform_background_initial_update(self) -> None:
|
||||||
# Do the initial population of the stats via the background update
|
# Do the initial population of the stats via the background update
|
||||||
self._add_background_updates()
|
self._add_background_updates()
|
||||||
|
|
|
@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
|
profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
|
||||||
assert profile is not None
|
assert profile is not None
|
||||||
self.assertTrue(profile["display_name"] == display_name)
|
self.assertTrue(profile[0] == display_name)
|
||||||
|
|
||||||
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
|
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
|
||||||
# create user
|
# create user
|
||||||
|
@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
# profile is in directory
|
# profile is in directory
|
||||||
profile = self.get_success(self.store._get_user_in_directory(r_user_id))
|
profile = self.get_success(self.store._get_user_in_directory(r_user_id))
|
||||||
assert profile is not None
|
assert profile is not None
|
||||||
self.assertTrue(profile["display_name"] == display_name)
|
self.assertEqual(profile[0], display_name)
|
||||||
|
|
||||||
# deactivate user
|
# deactivate user
|
||||||
self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
|
self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
|
||||||
|
|
|
@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
# is in user directory
|
# is in user directory
|
||||||
profile = self.get_success(self.store._get_user_in_directory(self.other_user))
|
profile = self.get_success(self.store._get_user_in_directory(self.other_user))
|
||||||
assert profile is not None
|
assert profile is not None
|
||||||
self.assertTrue(profile["display_name"] == "User")
|
self.assertEqual(profile[0], "User")
|
||||||
|
|
||||||
# Deactivate user
|
# Deactivate user
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
|
|
@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||||
#
|
#
|
||||||
# Note that we don't have the UI Auth session ID, so just pull out the single
|
# Note that we don't have the UI Auth session ID, so just pull out the single
|
||||||
# row.
|
# row.
|
||||||
ui_auth_data = self.get_success(
|
result = self.get_success(
|
||||||
self.store.db_pool.simple_select_one(
|
self.store.db_pool.simple_select_one_onecol(
|
||||||
"ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
|
"ui_auth_sessions", keyvalues={}, retcol="clientdict"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
client_dict = db_to_json(ui_auth_data["clientdict"])
|
client_dict = db_to_json(result)
|
||||||
self.assertNotIn("new_password", client_dict)
|
self.assertNotIn("new_password", client_dict)
|
||||||
|
|
||||||
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
||||||
|
|
|
@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertLessEqual(det_data.items(), channel.json_body.items())
|
self.assertLessEqual(det_data.items(), channel.json_body.items())
|
||||||
|
|
||||||
# Check the `completed` counter has been incremented and pending is 0
|
# Check the `completed` counter has been incremented and pending is 0
|
||||||
res = self.get_success(
|
pending, completed = self.get_success(
|
||||||
store.db_pool.simple_select_one(
|
store.db_pool.simple_select_one(
|
||||||
"registration_tokens",
|
"registration_tokens",
|
||||||
keyvalues={"token": token},
|
keyvalues={"token": token},
|
||||||
retcols=["pending", "completed"],
|
retcols=["pending", "completed"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(res["completed"], 1)
|
self.assertEqual(completed, 1)
|
||||||
self.assertEqual(res["pending"], 0)
|
self.assertEqual(pending, 0)
|
||||||
|
|
||||||
@override_config({"registration_requires_token": True})
|
@override_config({"registration_requires_token": True})
|
||||||
def test_POST_registration_token_invalid(self) -> None:
|
def test_POST_registration_token_invalid(self) -> None:
|
||||||
|
@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
params1["auth"]["type"] = LoginType.DUMMY
|
params1["auth"]["type"] = LoginType.DUMMY
|
||||||
self.make_request(b"POST", self.url, params1)
|
self.make_request(b"POST", self.url, params1)
|
||||||
# Check pending=0 and completed=1
|
# Check pending=0 and completed=1
|
||||||
res = self.get_success(
|
pending, completed = self.get_success(
|
||||||
store.db_pool.simple_select_one(
|
store.db_pool.simple_select_one(
|
||||||
"registration_tokens",
|
"registration_tokens",
|
||||||
keyvalues={"token": token},
|
keyvalues={"token": token},
|
||||||
retcols=["pending", "completed"],
|
retcols=["pending", "completed"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(res["pending"], 0)
|
self.assertEqual(pending, 0)
|
||||||
self.assertEqual(res["completed"], 1)
|
self.assertEqual(completed, 1)
|
||||||
|
|
||||||
# Check auth still fails when using token with session2
|
# Check auth still fails when using token with session2
|
||||||
channel = self.make_request(b"POST", self.url, params2)
|
channel = self.make_request(b"POST", self.url, params2)
|
||||||
|
|
|
@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret)
|
self.assertEqual((1, 2, 3), ret)
|
||||||
self.mock_txn.execute.assert_called_once_with(
|
self.mock_txn.execute.assert_called_once_with(
|
||||||
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
|
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
|
||||||
)
|
)
|
||||||
|
@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertFalse(ret)
|
self.assertIsNone(ret)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
|
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
|
|
|
@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_room(self) -> None:
|
def test_get_room(self) -> None:
|
||||||
res = self.get_success(self.store.get_room(self.room.to_string()))
|
room = self.get_success(self.store.get_room(self.room.to_string()))
|
||||||
assert res is not None
|
assert room is not None
|
||||||
self.assertLessEqual(
|
self.assertTrue(room[0])
|
||||||
{
|
|
||||||
"room_id": self.room.to_string(),
|
|
||||||
"creator": self.u_creator.to_string(),
|
|
||||||
"is_public": True,
|
|
||||||
}.items(),
|
|
||||||
res.items(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_get_room_unknown_room(self) -> None:
|
def test_get_room_unknown_room(self) -> None:
|
||||||
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
|
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
|
||||||
|
|
Loading…
Reference in New Issue