claim_local_one_time_keys: pass in result object

Rather than making the caller merge the results together, do it inside
`claim_local_one_time_keys`.
This commit is contained in:
Richard van der Hoff 2023-10-27 15:10:55 +01:00
parent dd45ba4d67
commit 27546ac171
2 changed files with 25 additions and 30 deletions

View File

@ -1000,18 +1000,12 @@ class FederationServer(FederationBase):
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> Dict[str, Any]: ) -> Dict[str, Any]:
log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys
)
json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {} json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
for result in results: await self._e2e_keys_handler.claim_local_one_time_keys(
for user_id, device_keys in result.items(): query,
for device_id, keys in device_keys.items(): always_include_fallback_keys=always_include_fallback_keys,
for key_id, key in keys.items(): result_dict=json_result,
json_result.setdefault(user_id, {}).setdefault(device_id, {})[ )
key_id
] = key
logger.info( logger.info(
"Claimed one-time-keys: %s", "Claimed one-time-keys: %s",

View File

@ -561,7 +561,8 @@ class E2eKeysHandler:
self, self,
local_query: List[Tuple[str, str, str, int]], local_query: List[Tuple[str, str, str, int]],
always_include_fallback_keys: bool, always_include_fallback_keys: bool,
) -> Iterable[Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]]]: result_dict: Dict[str, Dict[str, Dict[str, JsonSerializable]]],
) -> None:
"""Claim one time keys for local users. """Claim one time keys for local users.
1. Attempt to claim OTKs from the database. 1. Attempt to claim OTKs from the database.
@ -571,11 +572,20 @@ class E2eKeysHandler:
Args: Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm). local_query: An iterable of tuples of (user ID, device ID, algorithm).
always_include_fallback_keys: True to always include fallback keys. always_include_fallback_keys: True to always include fallback keys.
result_dict: A dict to update with the results.
Returns: {user_id -> { device_id -> { key_id -> key string/object }}}
An iterable of maps of user ID -> a map device ID -> a map of key ID -> key.
""" """
def update_result_dict(
results: Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]]
) -> None:
"""Stash results from a store query in `result_dict`"""
for user_id, device_keys in results.items():
user_result_dict = result_dict.setdefault(user_id, {})
for device_id, keys in device_keys.items():
device_result_dict = user_result_dict.setdefault(device_id, {})
device_result_dict.update(keys)
# Cap the number of OTKs that can be claimed at once to avoid abuse. # Cap the number of OTKs that can be claimed at once to avoid abuse.
local_query = [ local_query = [
(user_id, device_id, algorithm, min(count, 5)) (user_id, device_id, algorithm, min(count, 5))
@ -583,6 +593,7 @@ class E2eKeysHandler:
] ]
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
update_result_dict(otk_results)
# If the application services have not provided any keys via the C-S # If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys. # API, query it directly for one-time keys.
@ -593,6 +604,7 @@ class E2eKeysHandler:
appservice_results, appservice_results,
not_found, not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
update_result_dict(appservice_results)
else: else:
appservice_results = {} appservice_results = {}
@ -647,10 +659,7 @@ class E2eKeysHandler:
# For each user that does not have a one-time keys available, see if # For each user that does not have a one-time keys available, see if
# there is a fallback key. # there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query) fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
update_result_dict(fallback_results)
# Return the results in order, each item from the input query should
# only appear once in the combined list.
return (otk_results, appservice_results, fallback_results)
@trace @trace
async def claim_one_time_keys( async def claim_one_time_keys(
@ -676,19 +685,11 @@ class E2eKeysHandler:
set_tag("local_key_query", str(local_query)) set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries)) set_tag("remote_key_query", str(remote_queries))
results = await self.claim_local_one_time_keys(
local_query, always_include_fallback_keys
)
# A map of user ID -> device ID -> key ID -> key. # A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {} json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
for result in results: await self.claim_local_one_time_keys(
for user_id, device_keys in result.items(): local_query, always_include_fallback_keys, json_result
for device_id, keys in device_keys.items(): )
for key_id, key in keys.items():
json_result.setdefault(user_id, {}).setdefault(
device_id, {}
).update({key_id: key})
# Remote failures. # Remote failures.
failures: Dict[str, JsonDict] = {} failures: Dict[str, JsonDict] = {}