Support MSC3814: Dehydrated Devices Part 2 (#16010)

This commit is contained in:
Shay 2023-08-08 12:04:46 -07:00 committed by GitHub
parent 4581809846
commit 0328b56468
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 258 additions and 101 deletions

1
changelog.d/16010.misc Normal file
View File

@ -0,0 +1 @@
Update dehydrated devices implementation.

View File

@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self._account_data_handler = hs.get_account_data_handler() self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers() self._storage_controllers = hs.get_storage_controllers()
self.db_pool = hs.get_datastores().main.db_pool
self.device_list_updater = DeviceListUpdater(hs, self) self.device_list_updater = DeviceListUpdater(hs, self)
@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler):
device_id: Optional[str], device_id: Optional[str],
device_data: JsonDict, device_data: JsonDict,
initial_device_display_name: Optional[str] = None, initial_device_display_name: Optional[str] = None,
keys_for_device: Optional[JsonDict] = None,
) -> str: ) -> str:
"""Store a dehydrated device for a user. If the user had a previous """Store a dehydrated device for a user, optionally storing the keys associated with
dehydrated device, it is removed. it as well. If the user had a previous dehydrated device, it is removed.
Args: Args:
user_id: the user that we are storing the device for user_id: the user that we are storing the device for
device_id: device id supplied by client device_id: device id supplied by client
device_data: the dehydrated device information device_data: the dehydrated device information
initial_device_display_name: The display name to use for the device initial_device_display_name: The display name to use for the device
keys_for_device: keys for the dehydrated device
Returns: Returns:
device id of the dehydrated device device id of the dehydrated device
""" """
@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler):
device_id, device_id,
initial_device_display_name, initial_device_display_name,
) )
time_now = self.clock.time_msec()
old_device_id = await self.store.store_dehydrated_device( old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data user_id, device_id, device_data, time_now, keys_for_device
) )
if old_device_id is not None: if old_device_id is not None:
await self.delete_devices(user_id, [old_device_id]) await self.delete_devices(user_id, [old_device_id])
return device_id return device_id
async def rehydrate_device( async def rehydrate_device(

View File

@ -367,19 +367,6 @@ class DeviceMessageHandler:
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
# if we have a since token, delete any to-device messages before that token
# (since we now know that the device has received them)
deleted = await self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
logger.debug(
"Deleted %d to-device messages up to %d for user_id %s device_id %s",
deleted,
since_stream_id,
user_id,
device_id,
)
to_token = self.event_sources.get_current_token().to_device_key to_token = self.event_sources.get_current_token().to_device_key
messages, stream_id = await self.store.get_messages_for_device( messages, stream_id = await self.store.get_messages_for_device(

View File

@ -29,7 +29,6 @@ from synapse.http.servlet import (
parse_integer, parse_integer,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.rest.client.models import AuthenticationData from synapse.rest.client.models import AuthenticationData
from synapse.rest.models import RequestBodyModel from synapse.rest.models import RequestBodyModel
@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = handler self.device_handler = handler
if hs.config.worker.worker_app is None:
# if main process
self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
else:
# then a worker
self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet):
"Device key(s) not found, these must be provided.", "Device key(s) not found, these must be provided.",
) )
# TODO: Those two operations, creating a device and storing the
# device's keys should be atomic.
device_id = await self.device_handler.store_dehydrated_device( device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(), requester.user.to_string(),
submission.device_id, submission.device_id,
submission.device_data.dict(), submission.device_data.dict(),
submission.initial_device_display_name, submission.initial_device_display_name,
) device_info,
# TODO: Do we need to do something with the result here?
await self.key_uploader(
user_id=user_id, device_id=submission.device_id, keys=submission.dict()
) )
return 200, {"device_id": device_id} return 200, {"device_id": device_id}

View File

@ -28,6 +28,7 @@ from typing import (
cast, cast,
) )
from canonicaljson import encode_canonical_json
from typing_extensions import Literal from typing_extensions import Literal
from synapse.api.constants import EduTypes from synapse.api.constants import EduTypes
@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) )
def _store_dehydrated_device_txn( def _store_dehydrated_device_txn(
self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
device_data: str,
time: int,
keys: Optional[JsonDict] = None,
) -> Optional[str]: ) -> Optional[str]:
# TODO: make keys non-optional once support for msc2697 is dropped
if keys:
device_keys = keys.get("device_keys", None)
if device_keys:
# Type ignore - this function is defined on EndToEndKeyStore which we do
# have access to due to hs.get_datastore() "magic"
self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
txn, user_id, device_id, time, device_keys
)
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append(
(
algorithm,
key_id,
encode_canonical_json(key_obj).decode("ascii"),
)
)
self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
fallback_keys = keys.get("fallback_keys", None)
if fallback_keys:
self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
old_device_id = self.db_pool.simple_select_one_onecol_txn( old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="dehydrated_devices", table="dehydrated_devices",
@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data}, values={"device_id": device_id, "device_data": device_data},
) )
return old_device_id return old_device_id
async def store_dehydrated_device( async def store_dehydrated_device(
self, user_id: str, device_id: str, device_data: JsonDict self,
user_id: str,
device_id: str,
device_data: JsonDict,
time_now: int,
keys: Optional[dict] = None,
) -> Optional[str]: ) -> Optional[str]:
"""Store a dehydrated device for a user. """Store a dehydrated device for a user.
@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id: the user that we are storing the device for user_id: the user that we are storing the device for
device_id: the ID of the dehydrated device device_id: the ID of the dehydrated device
device_data: the dehydrated device information device_data: the dehydrated device information
time_now: current time at the request in milliseconds
keys: keys for the dehydrated device
Returns: Returns:
device id of the user's previous dehydrated device, if any device id of the user's previous dehydrated device, if any
""" """
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"store_dehydrated_device_txn", "store_dehydrated_device_txn",
self._store_dehydrated_device_txn, self._store_dehydrated_device_txn,
user_id, user_id,
device_id, device_id,
json_encoder.encode(device_data), json_encoder.encode(device_data),
time_now,
keys,
) )
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:

View File

@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
new_keys: keys to add - each a tuple of (algorithm, key_id, key json) new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
""" """
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
keys=(
"user_id",
"device_id",
"algorithm",
"key_id",
"ts_added_ms",
"key_json",
),
values=[
(user_id, device_id, algorithm, key_id, time_now, json_bytes)
for algorithm, key_id, json_bytes in new_keys
],
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys "add_e2e_one_time_keys_insert",
self._add_e2e_one_time_keys_txn,
user_id,
device_id,
time_now,
new_keys,
)
def _add_e2e_one_time_keys_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
time_now: int,
new_keys: Iterable[Tuple[str, str, str]],
) -> None:
"""Insert some new one time keys for a device. Errors if any of the keys already exist.
Args:
user_id: id of user to get keys for
device_id: id of device to get keys for
time_now: insertion time to record (ms since epoch)
new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
that the key JSON must be in canonical JSON form
"""
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
keys=(
"user_id",
"device_id",
"algorithm",
"key_id",
"ts_added_ms",
"key_json",
),
values=[
(user_id, device_id, algorithm, key_id, time_now, json_bytes)
for algorithm, key_id, json_bytes in new_keys
],
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
@cached(max_entries=10000) @cached(max_entries=10000)
@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_id: str, device_id: str,
fallback_keys: JsonDict, fallback_keys: JsonDict,
) -> None: ) -> None:
"""Set the user's e2e fallback keys.
Args:
user_id: the user whose keys are being set
device_id: the device whose keys are being set
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
# fallback_keys will usually only have one item in it, so using a for # fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded # FIXME: make sure that only one key per algorithm is uploaded
@ -1304,43 +1333,70 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
) -> bool: ) -> bool:
"""Stores device keys for a device. Returns whether there was a change """Stores device keys for a device. Returns whether there was a change
or the keys were already in the database. or the keys were already in the database.
Args:
user_id: user_id of the user to store keys for
device_id: device_id of the device to store keys for
time_now: time at the request to store the keys
device_keys: the keys to store
""" """
def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
log_kv({"Message": "Device key already stored."})
return False
self.db_pool.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time_now, "key_json": new_key_json},
)
log_kv({"message": "Device keys stored."})
return True
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn "set_e2e_device_keys",
self._set_e2e_device_keys_txn,
user_id,
device_id,
time_now,
device_keys,
) )
def _set_e2e_device_keys_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
time_now: int,
device_keys: JsonDict,
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
Args:
user_id: user_id of the user to store keys for
device_id: device_id of the device to store keys for
time_now: time at the request to store the keys
device_keys: the keys to store
"""
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
log_kv({"Message": "Device key already stored."})
return False
self.db_pool.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time_now, "key_json": new_key_json},
)
log_kv({"message": "Device keys stored."})
return True
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
log_kv( log_kv(

View File

@ -566,15 +566,16 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(res["events"]), 1) self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo") self.assertEqual(res["events"][0]["content"]["body"], "foo")
# Fetch the message of the dehydrated device again, which should return nothing # Fetch the message of the dehydrated device again, which should return
# and delete the old messages # the same message as it has not been deleted
res = self.get_success( res = self.get_success(
self.message_handler.get_events_for_dehydrated_device( self.message_handler.get_events_for_dehydrated_device(
requester=requester, requester=requester,
device_id=stored_dehydrated_device_id, device_id=stored_dehydrated_device_id,
since_token=res["next_batch"], since_token=None,
limit=10, limit=10,
) )
) )
self.assertTrue(len(res["next_batch"]) > 1) self.assertTrue(len(res["next_batch"]) > 1)
self.assertEqual(len(res["events"]), 0) self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")

View File

@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, keys, login, register from synapse.rest.client import account, keys, login, register
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, create_requester from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
"<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"} "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
}, },
}, },
"fallback_keys": {
"alg1:device1": "f4llb4ckk3y",
"signed_<algorithm>:<device_id>": {
"fallback": "true",
"key": "f4llb4ckk3y",
"signatures": {
"<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
},
},
},
"one_time_keys": {"alg1:k1": "0net1m3k3y"},
} }
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
} }
self.assertEqual(device_data, expected_device_data) self.assertEqual(device_data, expected_device_data)
# test that the keys are correctly uploaded
channel = self.make_request(
"POST",
"/_matrix/client/r0/keys/query",
{
"device_keys": {
user: ["device1"],
},
},
token,
)
self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body["device_keys"][user][device_id]["keys"],
content["device_keys"]["keys"],
)
# first claim should return the onetime key we uploaded
res = self.get_success(
self.hs.get_e2e_keys_handler().claim_one_time_keys(
{user: {device_id: {"alg1": 1}}},
UserID.from_string(user),
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
res,
{
"failures": {},
"one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
},
)
# second claim should return fallback key
res2 = self.get_success(
self.hs.get_e2e_keys_handler().claim_one_time_keys(
{user: {device_id: {"alg1": 1}}},
UserID.from_string(user),
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
res2,
{
"failures": {},
"one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
},
)
# create another device for the user # create another device for the user
( (
new_device_id, new_device_id,
@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
expected_content = {"body": "test_message"} expected_content = {"body": "test_message"}
self.assertEqual(channel.json_body["events"][0]["content"], expected_content) self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
# fetch messages again and make sure that the message was not deleted
channel = self.make_request(
"POST",
f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
content={},
access_token=token,
shorthand=False,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
next_batch_token = channel.json_body.get("next_batch") next_batch_token = channel.json_body.get("next_batch")
# fetch messages again and make sure that the message was deleted and we are returned an # make sure fetching messages with next batch token works - there are no unfetched
# empty array # messages so we should receive an empty array
content = {"next_batch": next_batch_token} content = {"next_batch": next_batch_token}
channel = self.make_request( channel = self.make_request(
"POST", "POST",