Implement MSC4072: return result for all /keys/claims
This commit is contained in:
parent
4c586567f6
commit
6dbad83998
|
@ -15,7 +15,6 @@
|
||||||
import enum
|
import enum
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import attr
|
|
||||||
import attr.validators
|
import attr.validators
|
||||||
|
|
||||||
from synapse.api.errors import LimitExceededError
|
from synapse.api.errors import LimitExceededError
|
||||||
|
@ -419,3 +418,9 @@ class ExperimentalConfig(Config):
|
||||||
self.msc4028_push_encrypted_events = experimental.get(
|
self.msc4028_push_encrypted_events = experimental.get(
|
||||||
"msc4028_push_encrypted_events", False
|
"msc4028_push_encrypted_events", False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# MSC4072: Return an empty dict from /keys/claim for unknown devices or those
|
||||||
|
# with exhausted OTKs
|
||||||
|
self.msc4072_empty_dict_for_exhausted_devices = experimental.get(
|
||||||
|
"msc4072_empty_dict_for_exhausted_devices", False
|
||||||
|
)
|
||||||
|
|
|
@ -592,6 +592,12 @@ class E2eKeysHandler:
|
||||||
for user_id, device_id, algorithm, count in local_query
|
for user_id, device_id, algorithm, count in local_query
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# prepopulate the response to make sure that all queried users/devices are
|
||||||
|
# included, even if the user/device is unknown or has run out of OTKs
|
||||||
|
if self.config.experimental.msc4072_empty_dict_for_exhausted_devices:
|
||||||
|
for user_id, device_id, _, _ in local_query:
|
||||||
|
result_dict.setdefault(user_id, {}).setdefault(device_id, {})
|
||||||
|
|
||||||
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
|
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
|
||||||
update_result_dict(otk_results)
|
update_result_dict(otk_results)
|
||||||
|
|
||||||
|
@ -669,6 +675,25 @@ class E2eKeysHandler:
|
||||||
timeout: Optional[int],
|
timeout: Optional[int],
|
||||||
always_include_fallback_keys: bool,
|
always_include_fallback_keys: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
|
"""
|
||||||
|
Handle a /keys/claim request.
|
||||||
|
|
||||||
|
Handles requests for local users with a db lookup, and makes federation
|
||||||
|
requests for remote users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: map from user ID, to map from device ID, to map from algorithm name
|
||||||
|
to number of keys needed
|
||||||
|
(``{user_id: {device_id: {algorithm: number_of keys}}}``)
|
||||||
|
|
||||||
|
user: The user id of the requesting user
|
||||||
|
|
||||||
|
timeout: number of milliseconds to wait for the response from remote servers.
|
||||||
|
``config.federation.client_timeout_ms`` by default.
|
||||||
|
|
||||||
|
always_include_fallback_keys: True to always include fallback keys, even
|
||||||
|
for devices which still have one-time keys.
|
||||||
|
"""
|
||||||
local_query: List[Tuple[str, str, str, int]] = []
|
local_query: List[Tuple[str, str, str, int]] = []
|
||||||
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
|
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
|
||||||
|
|
||||||
|
@ -707,9 +732,18 @@ class E2eKeysHandler:
|
||||||
remote_result = await self.federation.claim_client_keys(
|
remote_result = await self.federation.claim_client_keys(
|
||||||
user, destination, device_keys, timeout=timeout
|
user, destination, device_keys, timeout=timeout
|
||||||
)
|
)
|
||||||
for user_id, keys in remote_result["one_time_keys"].items():
|
try:
|
||||||
if user_id in device_keys:
|
destination_result = filter_remote_claimed_keys(
|
||||||
json_result[user_id] = keys
|
device_keys,
|
||||||
|
remote_result,
|
||||||
|
self.config.experimental.msc4072_empty_dict_for_exhausted_devices,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error parsing /keys/claim response from server {destination}",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
failure = _exception_to_failure(e)
|
failure = _exception_to_failure(e)
|
||||||
|
@ -717,6 +751,11 @@ class E2eKeysHandler:
|
||||||
set_tag("error", True)
|
set_tag("error", True)
|
||||||
set_tag("reason", str(failure))
|
set_tag("reason", str(failure))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# only populate json_result once we know there will not be an entry in
|
||||||
|
# failures for this destination.
|
||||||
|
json_result.update(destination_result)
|
||||||
|
|
||||||
await make_deferred_yieldable(
|
await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
defer.gatherResults(
|
||||||
[
|
[
|
||||||
|
@ -1632,3 +1671,51 @@ class SigningKeyEduUpdater:
|
||||||
device_ids = device_ids + new_device_ids
|
device_ids = device_ids + new_device_ids
|
||||||
|
|
||||||
await self._device_handler.notify_device_update(user_id, device_ids)
|
await self._device_handler.notify_device_update(user_id, device_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_remote_claimed_keys(
|
||||||
|
destination_query: Dict[str, Dict[str, Dict[str, int]]],
|
||||||
|
remote_response: JsonDict,
|
||||||
|
msc4072_empty_dict_for_exhausted_devices: bool,
|
||||||
|
) -> JsonDict:
|
||||||
|
"""
|
||||||
|
Process the response from a federation /keys/claim request
|
||||||
|
|
||||||
|
Checks that there are no redundant entries, and that all the entries that
|
||||||
|
should be there are present.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination_query: user->device->key map that was sent in the request to
|
||||||
|
this server
|
||||||
|
remote_response: response from the remote server
|
||||||
|
msc4072_empty_dict_for_exhausted_devices: true to include an entry in the
|
||||||
|
result for every queried device
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
user->device->key map to be merged into the results
|
||||||
|
"""
|
||||||
|
remote_otks = remote_response["one_time_keys"]
|
||||||
|
|
||||||
|
destination_result: JsonDict = {}
|
||||||
|
|
||||||
|
if msc4072_empty_dict_for_exhausted_devices:
|
||||||
|
# We need to make sure there is an entry in destination_result for
|
||||||
|
# every queried (user, device) even if the remote server did not
|
||||||
|
# populate it; so we iterate the query and populate
|
||||||
|
# destination_result based on the federation result.
|
||||||
|
for user_id, user_query in destination_query.items():
|
||||||
|
remote_user_result = remote_otks.get(user_id, {})
|
||||||
|
destination_user_result = destination_result[user_id] = {}
|
||||||
|
for device_id in user_query.keys():
|
||||||
|
destination_user_result[device_id] = remote_user_result.get(
|
||||||
|
device_id, {}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# We need to make sure that remote servers do not poison the
|
||||||
|
# result with data for users which do not belong to it, so we only
|
||||||
|
# copy data for users that were queried.
|
||||||
|
for user_id, keys in remote_otks.items():
|
||||||
|
if user_id in destination_query:
|
||||||
|
destination_result[user_id] = keys
|
||||||
|
|
||||||
|
return destination_result
|
||||||
|
|
|
@ -144,33 +144,79 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_claim_one_time_key(self) -> None:
|
@parameterized.expand([(True,), (False,)])
|
||||||
local_user = "@boris:" + self.hs.hostname
|
def test_claim_one_time_key(self, msc4072: bool) -> None:
|
||||||
device_id = "xyz"
|
self.hs.config.experimental.msc4072_empty_dict_for_exhausted_devices = msc4072
|
||||||
keys = {"alg1:k1": "key1"}
|
|
||||||
|
|
||||||
|
local_known_user = "@boris:" + self.hs.hostname
|
||||||
|
device_id = "xyz"
|
||||||
|
local_unknown_user = "@charlie:" + self.hs.hostname
|
||||||
|
|
||||||
|
remote_known_user = "@dave:xyz"
|
||||||
|
remote_unknown_user = "@errol:xyz"
|
||||||
|
|
||||||
|
# upload a key for the local user
|
||||||
res = self.get_success(
|
res = self.get_success(
|
||||||
self.handler.upload_keys_for_user(
|
self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {"one_time_keys": keys}
|
local_known_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertDictEqual(
|
self.assertDictEqual(
|
||||||
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
|
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# mock out the response for remote users. We pretend that the remote server
|
||||||
|
# hasn't heard of MSC4072 and returns an incomplete result. (Even once
|
||||||
|
# MSC4072 is stable, we still need to handle incomplete results.)
|
||||||
|
#
|
||||||
|
# we also include a spurious result to check it gets filtered out.
|
||||||
|
self.hs.get_federation_client().claim_client_keys = mock.AsyncMock( # type: ignore[method-assign]
|
||||||
|
return_value={
|
||||||
|
"one_time_keys": {
|
||||||
|
remote_known_user: {"ghi": {"alg1": "keykey"}},
|
||||||
|
"@other:xyz": {"zzz": {"alg1": "dodgykey"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
res2 = self.get_success(
|
res2 = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{
|
||||||
|
local_known_user: {device_id: {"alg1": 1}, "abc": {"alg2": 1}},
|
||||||
|
local_unknown_user: {"def": {"alg1": 1}},
|
||||||
|
remote_known_user: {"ghi": {"alg1": 1}, "jkl": {"alg1": 1}},
|
||||||
|
remote_unknown_user: {"mno": {"alg1": 1}},
|
||||||
|
},
|
||||||
self.requester,
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=False,
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if msc4072:
|
||||||
|
# empty result for each unknown device
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res2,
|
res2,
|
||||||
{
|
{
|
||||||
"failures": {},
|
"failures": {},
|
||||||
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
|
"one_time_keys": {
|
||||||
|
local_known_user: {device_id: {"alg1:k1": "key1"}, "abc": {}},
|
||||||
|
local_unknown_user: {"def": {}},
|
||||||
|
remote_known_user: {"ghi": {"alg1": "keykey"}, "jkl": {}},
|
||||||
|
remote_unknown_user: {"mno": {}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# only known devices
|
||||||
|
self.assertEqual(
|
||||||
|
res2,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {
|
||||||
|
local_known_user: {device_id: {"alg1:k1": "key1"}},
|
||||||
|
remote_known_user: {"ghi": {"alg1": "keykey"}},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue