diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 9f830e7094..568efd641e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -15,7 +15,6 @@ import enum from typing import TYPE_CHECKING, Any, Optional -import attr import attr.validators from synapse.api.errors import LimitExceededError @@ -419,3 +418,9 @@ class ExperimentalConfig(Config): self.msc4028_push_encrypted_events = experimental.get( "msc4028_push_encrypted_events", False ) + + # MSC4072: Return an empty dict from /keys/claim for unknown devices or those + # with exhausted OTKs + self.msc4072_empty_dict_for_exhausted_devices = experimental.get( + "msc4072_empty_dict_for_exhausted_devices", False + ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 7d44127ebf..bb628fdb00 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -592,6 +592,12 @@ class E2eKeysHandler: for user_id, device_id, algorithm, count in local_query ] + # prepopulate the response to make sure that all queried users/devices are + # included, even if the user/device is unknown or has run out of OTKs + if self.config.experimental.msc4072_empty_dict_for_exhausted_devices: + for user_id, device_id, _, _ in local_query: + result_dict.setdefault(user_id, {}).setdefault(device_id, {}) + otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) update_result_dict(otk_results) @@ -669,6 +675,25 @@ class E2eKeysHandler: timeout: Optional[int], always_include_fallback_keys: bool, ) -> JsonDict: + """ + Handle a /keys/claim request. + + Handles requests for local users with a db lookup, and makes federation + requests for remote users. + + Args: + query: map from user ID, to map from device ID, to map from algorithm name + to number of keys needed + (``{user_id: {device_id: {algorithm: number_of keys}}}``) + + user: The user id of the requesting user + + timeout: number of milliseconds to wait for the response from remote servers. + ``config.federation.client_timeout_ms`` by default. + + always_include_fallback_keys: True to always include fallback keys, even + for devices which still have one-time keys. + """ local_query: List[Tuple[str, str, str, int]] = [] remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {} @@ -707,9 +732,18 @@ class E2eKeysHandler: remote_result = await self.federation.claim_client_keys( user, destination, device_keys, timeout=timeout ) - for user_id, keys in remote_result["one_time_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys + try: + destination_result = filter_remote_claimed_keys( + device_keys, + remote_result, + self.config.experimental.msc4072_empty_dict_for_exhausted_devices, + ) + except Exception as e: + logger.warning( + f"Error parsing /keys/claim response from server {destination}", + e, + ) + raise except Exception as e: failure = _exception_to_failure(e) @@ -717,6 +751,11 @@ class E2eKeysHandler: set_tag("error", True) set_tag("reason", str(failure)) + else: + # only populate json_result once we know there will not be an entry in + # failures for this destination. + json_result.update(destination_result) + await make_deferred_yieldable( defer.gatherResults( [ @@ -1632,3 +1671,51 @@ class SigningKeyEduUpdater: device_ids = device_ids + new_device_ids await self._device_handler.notify_device_update(user_id, device_ids) + + +def filter_remote_claimed_keys( + destination_query: Dict[str, Dict[str, Dict[str, int]]], + remote_response: JsonDict, + msc4072_empty_dict_for_exhausted_devices: bool, +) -> JsonDict: + """ + Process the response from a federation /keys/claim request + + Checks that there are no redundant entries, and that all the entries that + should be there are present. + + Args: + destination_query: user->device->key map that was sent in the request to + this server + remote_response: response from the remote server + msc4072_empty_dict_for_exhausted_devices: true to include an entry in the + result for every queried device + + Returns: + user->device->key map to be merged into the results + """ + remote_otks = remote_response["one_time_keys"] + + destination_result: JsonDict = {} + + if msc4072_empty_dict_for_exhausted_devices: + # We need to make sure there is an entry in destination_result for + # every queried (user, device) even if the remote server did not + # populate it; so we iterate the query and populate + # destination_result based on the federation result. + for user_id, user_query in destination_query.items(): + remote_user_result = remote_otks.get(user_id, {}) + destination_user_result = destination_result[user_id] = {} + for device_id in user_query.keys(): + destination_user_result[device_id] = remote_user_result.get( + device_id, {} + ) + else: + # We need to make sure that remote servers do not poison the + # result with data for users which do not belong to it, so we only + # copy data for users that were queried. + for user_id, keys in remote_otks.items(): + if user_id in destination_query: + destination_result[user_id] = keys + + return destination_result diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index c5556f2844..8a105c5712 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -144,35 +144,81 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_claim_one_time_key(self) -> None: - local_user = "@boris:" + self.hs.hostname - device_id = "xyz" - keys = {"alg1:k1": "key1"} + @parameterized.expand([(True,), (False,)]) + def test_claim_one_time_key(self, msc4072: bool) -> None: + self.hs.config.experimental.msc4072_empty_dict_for_exhausted_devices = msc4072 + local_known_user = "@boris:" + self.hs.hostname + device_id = "xyz" + local_unknown_user = "@charlie:" + self.hs.hostname + + remote_known_user = "@dave:xyz" + remote_unknown_user = "@errol:xyz" + + # upload a key for the local user res = self.get_success( self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + local_known_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}} ) ) self.assertDictEqual( res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}} ) + # mock out the response for remote users. We pretend that the remote server + # hasn't heard of MSC4072 and returns an incomplete result. (Even once + # MSC4072 is stable, we still need to handle incomplete results.) + # + # we also include a spurious result to check it gets filtered out. + self.hs.get_federation_client().claim_client_keys = mock.AsyncMock( # type: ignore[method-assign] + return_value={ + "one_time_keys": { + remote_known_user: {"ghi": {"alg1": "keykey"}}, + "@other:xyz": {"zzz": {"alg1": "dodgykey"}}, + } + } + ) + res2 = self.get_success( self.handler.claim_one_time_keys( - {local_user: {device_id: {"alg1": 1}}}, + { + local_known_user: {device_id: {"alg1": 1}, "abc": {"alg2": 1}}, + local_unknown_user: {"def": {"alg1": 1}}, + remote_known_user: {"ghi": {"alg1": 1}, "jkl": {"alg1": 1}}, + remote_unknown_user: {"mno": {"alg1": 1}}, + }, self.requester, timeout=None, always_include_fallback_keys=False, ) ) - self.assertEqual( - res2, - { - "failures": {}, - "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, - }, - ) + + if msc4072: + # empty result for each unknown device + self.assertEqual( + res2, + { + "failures": {}, + "one_time_keys": { + local_known_user: {device_id: {"alg1:k1": "key1"}, "abc": {}}, + local_unknown_user: {"def": {}}, + remote_known_user: {"ghi": {"alg1": "keykey"}, "jkl": {}}, + remote_unknown_user: {"mno": {}}, + }, + }, + ) + else: + # only known devices + self.assertEqual( + res2, + { + "failures": {}, + "one_time_keys": { + local_known_user: {device_id: {"alg1:k1": "key1"}}, + remote_known_user: {"ghi": {"alg1": "keykey"}}, + }, + }, + ) def test_fallback_key(self) -> None: local_user = "@boris:" + self.hs.hostname