mirror of
https://github.com/matrix-org/synapse.git
synced 2025-06-26 21:26:08 +00:00
claim_local_one_time_keys
: pass in result object
Rather than making the caller merge the results together, do it inside `claim_local_one_time_keys`.
This commit is contained in:
parent
dd45ba4d67
commit
27546ac171
@ -1000,18 +1000,12 @@ class FederationServer(FederationBase):
|
||||
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
|
||||
) -> Dict[str, Any]:
|
||||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
||||
results = await self._e2e_keys_handler.claim_local_one_time_keys(
|
||||
query, always_include_fallback_keys=always_include_fallback_keys
|
||||
)
|
||||
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
|
||||
for result in results:
|
||||
for user_id, device_keys in result.items():
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, key in keys.items():
|
||||
json_result.setdefault(user_id, {}).setdefault(device_id, {})[
|
||||
key_id
|
||||
] = key
|
||||
await self._e2e_keys_handler.claim_local_one_time_keys(
|
||||
query,
|
||||
always_include_fallback_keys=always_include_fallback_keys,
|
||||
result_dict=json_result,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Claimed one-time-keys: %s",
|
||||
|
@ -561,7 +561,8 @@ class E2eKeysHandler:
|
||||
self,
|
||||
local_query: List[Tuple[str, str, str, int]],
|
||||
always_include_fallback_keys: bool,
|
||||
) -> Iterable[Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]]]:
|
||||
result_dict: Dict[str, Dict[str, Dict[str, JsonSerializable]]],
|
||||
) -> None:
|
||||
"""Claim one time keys for local users.
|
||||
|
||||
1. Attempt to claim OTKs from the database.
|
||||
@ -571,11 +572,20 @@ class E2eKeysHandler:
|
||||
Args:
|
||||
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
always_include_fallback_keys: True to always include fallback keys.
|
||||
|
||||
Returns:
|
||||
An iterable of maps of user ID -> a map device ID -> a map of key ID -> key.
|
||||
result_dict: A dict to update with the results.
|
||||
{user_id -> { device_id -> { key_id -> key string/object }}}
|
||||
"""
|
||||
|
||||
def update_result_dict(
|
||||
results: Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]]
|
||||
) -> None:
|
||||
"""Stash results from a store query in `result_dict`"""
|
||||
for user_id, device_keys in results.items():
|
||||
user_result_dict = result_dict.setdefault(user_id, {})
|
||||
for device_id, keys in device_keys.items():
|
||||
device_result_dict = user_result_dict.setdefault(device_id, {})
|
||||
device_result_dict.update(keys)
|
||||
|
||||
# Cap the number of OTKs that can be claimed at once to avoid abuse.
|
||||
local_query = [
|
||||
(user_id, device_id, algorithm, min(count, 5))
|
||||
@ -583,6 +593,7 @@ class E2eKeysHandler:
|
||||
]
|
||||
|
||||
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
|
||||
update_result_dict(otk_results)
|
||||
|
||||
# If the application services have not provided any keys via the C-S
|
||||
# API, query it directly for one-time keys.
|
||||
@ -593,6 +604,7 @@ class E2eKeysHandler:
|
||||
appservice_results,
|
||||
not_found,
|
||||
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
|
||||
update_result_dict(appservice_results)
|
||||
else:
|
||||
appservice_results = {}
|
||||
|
||||
@ -647,10 +659,7 @@ class E2eKeysHandler:
|
||||
# For each user that does not have a one-time keys available, see if
|
||||
# there is a fallback key.
|
||||
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
|
||||
|
||||
# Return the results in order, each item from the input query should
|
||||
# only appear once in the combined list.
|
||||
return (otk_results, appservice_results, fallback_results)
|
||||
update_result_dict(fallback_results)
|
||||
|
||||
@trace
|
||||
async def claim_one_time_keys(
|
||||
@ -676,19 +685,11 @@ class E2eKeysHandler:
|
||||
set_tag("local_key_query", str(local_query))
|
||||
set_tag("remote_key_query", str(remote_queries))
|
||||
|
||||
results = await self.claim_local_one_time_keys(
|
||||
local_query, always_include_fallback_keys
|
||||
)
|
||||
|
||||
# A map of user ID -> device ID -> key ID -> key.
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
|
||||
for result in results:
|
||||
for user_id, device_keys in result.items():
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, key in keys.items():
|
||||
json_result.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
).update({key_id: key})
|
||||
await self.claim_local_one_time_keys(
|
||||
local_query, always_include_fallback_keys, json_result
|
||||
)
|
||||
|
||||
# Remote failures.
|
||||
failures: Dict[str, JsonDict] = {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user