diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d4dbfd3dca..54c25dfe3e 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1000,18 +1000,12 @@ class FederationServer(FederationBase): self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool ) -> Dict[str, Any]: log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = await self._e2e_keys_handler.claim_local_one_time_keys( - query, always_include_fallback_keys=always_include_fallback_keys - ) - json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {} - for result in results: - for user_id, device_keys in result.items(): - for device_id, keys in device_keys.items(): - for key_id, key in keys.items(): - json_result.setdefault(user_id, {}).setdefault(device_id, {})[ - key_id - ] = key + await self._e2e_keys_handler.claim_local_one_time_keys( + query, + always_include_fallback_keys=always_include_fallback_keys, + result_dict=json_result, + ) logger.info( "Claimed one-time-keys: %s", diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9e51a34d70..5c356b0cff 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -561,7 +561,8 @@ class E2eKeysHandler: self, local_query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool, - ) -> Iterable[Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]]]: + result_dict: Dict[str, Dict[str, Dict[str, JsonSerializable]]], + ) -> None: """Claim one time keys for local users. 1. Attempt to claim OTKs from the database. @@ -571,11 +572,20 @@ class E2eKeysHandler: Args: local_query: An iterable of tuples of (user ID, device ID, algorithm). always_include_fallback_keys: True to always include fallback keys. - - Returns: - An iterable of maps of user ID -> a map device ID -> a map of key ID -> key. + result_dict: A dict to update with the results. + {user_id -> { device_id -> { key_id -> key string/object }}} """ + def update_result_dict( + results: Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]] + ) -> None: + """Stash results from a store query in `result_dict`""" + for user_id, device_keys in results.items(): + user_result_dict = result_dict.setdefault(user_id, {}) + for device_id, keys in device_keys.items(): + device_result_dict = user_result_dict.setdefault(device_id, {}) + device_result_dict.update(keys) + # Cap the number of OTKs that can be claimed at once to avoid abuse. local_query = [ (user_id, device_id, algorithm, min(count, 5)) @@ -583,6 +593,7 @@ class E2eKeysHandler: ] otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) + update_result_dict(otk_results) # If the application services have not provided any keys via the C-S # API, query it directly for one-time keys. @@ -593,6 +604,7 @@ class E2eKeysHandler: appservice_results, not_found, ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) + update_result_dict(appservice_results) else: appservice_results = {} @@ -647,10 +659,7 @@ class E2eKeysHandler: # For each user that does not have a one-time keys available, see if # there is a fallback key. fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query) - - # Return the results in order, each item from the input query should - # only appear once in the combined list. - return (otk_results, appservice_results, fallback_results) + update_result_dict(fallback_results) @trace async def claim_one_time_keys( @@ -676,19 +685,11 @@ class E2eKeysHandler: set_tag("local_key_query", str(local_query)) set_tag("remote_key_query", str(remote_queries)) - results = await self.claim_local_one_time_keys( - local_query, always_include_fallback_keys - ) - # A map of user ID -> device ID -> key ID -> key. json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {} - for result in results: - for user_id, device_keys in result.items(): - for device_id, keys in device_keys.items(): - for key_id, key in keys.items(): - json_result.setdefault(user_id, {}).setdefault( - device_id, {} - ).update({key_id: key}) + await self.claim_local_one_time_keys( + local_query, always_include_fallback_keys, json_result + ) # Remote failures. failures: Dict[str, JsonDict] = {}