Claim local one-time-keys in bulk (#16565)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
parent
91aa52c911
commit
de981ae567
|
@ -0,0 +1 @@
|
|||
Improve the performance of claiming encryption keys.
|
|
@ -753,6 +753,16 @@ class E2eKeysHandler:
|
|||
async def upload_keys_for_user(
|
||||
self, user_id: str, device_id: str, keys: JsonDict
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
user_id: user whose keys are being uploaded.
|
||||
device_id: device whose keys are being uploaded.
|
||||
keys: the body of a /keys/upload request.
|
||||
|
||||
Returns a dictionary with one field:
|
||||
"one_time_keys": A mapping from algorithm to number of keys for that
|
||||
algorithm, including those previously persisted.
|
||||
"""
|
||||
# This can only be called from the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
|
|
|
@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
...
|
||||
|
||||
async def claim_e2e_one_time_keys(
|
||||
self, query_list: Iterable[Tuple[str, str, str, int]]
|
||||
self, query_list: Collection[Tuple[str, str, str, int]]
|
||||
) -> Tuple[
|
||||
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
|
||||
]:
|
||||
|
@ -1121,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
|
||||
Returns:
|
||||
A tuple pf:
|
||||
A tuple (results, missing) of:
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||
|
||||
A copy of the input which has not been fulfilled.
|
||||
A copy of the input which has not been fulfilled. The returned counts
|
||||
may be less than the input counts. In this case, the returned counts
|
||||
are the number of claims that were not fulfilled.
|
||||
"""
|
||||
|
||||
@trace
|
||||
def _claim_e2e_one_time_key_simple(
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
algorithm: str,
|
||||
count: int,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Claim OTK for device for DBs that don't support RETURNING.
|
||||
|
||||
Returns:
|
||||
A tuple of key name (algorithm + key ID) and key JSON, if an
|
||||
OTK was found.
|
||||
"""
|
||||
|
||||
sql = """
|
||||
SELECT key_id, key_json FROM e2e_one_time_keys_json
|
||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, device_id, algorithm, count))
|
||||
otk_rows = list(txn)
|
||||
if not otk_rows:
|
||||
return []
|
||||
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="e2e_one_time_keys_json",
|
||||
column="key_id",
|
||||
values=[otk_row[0] for otk_row in otk_rows],
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
return [
|
||||
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
|
||||
]
|
||||
|
||||
@trace
|
||||
def _claim_e2e_one_time_key_returning(
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
algorithm: str,
|
||||
count: int,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Claim OTK for device for DBs that support RETURNING.
|
||||
|
||||
Returns:
|
||||
A tuple of key name (algorithm + key ID) and key JSON, if an
|
||||
OTK was found.
|
||||
"""
|
||||
|
||||
# We can use RETURNING to do the fetch and DELETE in once step.
|
||||
sql = """
|
||||
DELETE FROM e2e_one_time_keys_json
|
||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
||||
AND key_id IN (
|
||||
SELECT key_id FROM e2e_one_time_keys_json
|
||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
||||
LIMIT ?
|
||||
)
|
||||
RETURNING key_id, key_json
|
||||
"""
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
|
||||
)
|
||||
otk_rows = list(txn)
|
||||
if not otk_rows:
|
||||
return []
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
return [
|
||||
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
|
||||
]
|
||||
|
||||
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
missing: List[Tuple[str, str, str, int]] = []
|
||||
for user_id, device_id, algorithm, count in query_list:
|
||||
if self.database_engine.supports_returning:
|
||||
# If we support RETURNING clause we can use a single query that
|
||||
# allows us to use autocommit mode.
|
||||
_claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
|
||||
db_autocommit = True
|
||||
else:
|
||||
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
|
||||
db_autocommit = False
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# If we can use execute_values we can use a single batch query
|
||||
# in autocommit mode.
|
||||
unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
|
||||
for user_id, device_id, algorithm, count in query_list:
|
||||
unfulfilled_claim_counts[user_id, device_id, algorithm] = count
|
||||
|
||||
claim_rows = await self.db_pool.runInteraction(
|
||||
bulk_claims = await self.db_pool.runInteraction(
|
||||
"claim_e2e_one_time_keys",
|
||||
_claim_e2e_one_time_key,
|
||||
user_id,
|
||||
device_id,
|
||||
algorithm,
|
||||
count,
|
||||
db_autocommit=db_autocommit,
|
||||
self._claim_e2e_one_time_keys_bulk,
|
||||
query_list,
|
||||
db_autocommit=True,
|
||||
)
|
||||
if claim_rows:
|
||||
|
||||
for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
|
||||
device_results = results.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
)
|
||||
for claim_row in claim_rows:
|
||||
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
|
||||
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
|
||||
unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
|
||||
|
||||
# Did we get enough OTKs?
|
||||
count -= len(claim_rows)
|
||||
if count:
|
||||
missing.append((user_id, device_id, algorithm, count))
|
||||
missing = [
|
||||
(user, device, alg, count)
|
||||
for (user, device, alg), count in unfulfilled_claim_counts.items()
|
||||
if count > 0
|
||||
]
|
||||
else:
|
||||
for user_id, device_id, algorithm, count in query_list:
|
||||
claim_rows = await self.db_pool.runInteraction(
|
||||
"claim_e2e_one_time_keys",
|
||||
self._claim_e2e_one_time_key_simple,
|
||||
user_id,
|
||||
device_id,
|
||||
algorithm,
|
||||
count,
|
||||
db_autocommit=False,
|
||||
)
|
||||
if claim_rows:
|
||||
device_results = results.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
)
|
||||
for claim_row in claim_rows:
|
||||
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
|
||||
# Did we get enough OTKs?
|
||||
count -= len(claim_rows)
|
||||
if count:
|
||||
missing.append((user_id, device_id, algorithm, count))
|
||||
|
||||
return results, missing
|
||||
|
||||
|
@ -1362,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
|
||||
return results
|
||||
|
||||
@trace
|
||||
def _claim_e2e_one_time_key_simple(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
algorithm: str,
|
||||
count: int,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Claim OTK for device for DBs that don't support RETURNING.
|
||||
|
||||
Returns:
|
||||
A tuple of key name (algorithm + key ID) and key JSON, if an
|
||||
OTK was found.
|
||||
"""
|
||||
|
||||
sql = """
|
||||
SELECT key_id, key_json FROM e2e_one_time_keys_json
|
||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, device_id, algorithm, count))
|
||||
otk_rows = list(txn)
|
||||
if not otk_rows:
|
||||
return []
|
||||
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="e2e_one_time_keys_json",
|
||||
column="key_id",
|
||||
values=[otk_row[0] for otk_row in otk_rows],
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
|
||||
|
||||
@trace
|
||||
def _claim_e2e_one_time_keys_bulk(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
query_list: Iterable[Tuple[str, str, str, int]],
|
||||
) -> List[Tuple[str, str, str, str, str]]:
|
||||
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
|
||||
|
||||
Args:
|
||||
query_list: Collection of tuples (user_id, device_id, algorithm, count)
|
||||
as passed to claim_e2e_one_time_keys.
|
||||
|
||||
Returns:
|
||||
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
|
||||
for each OTK claimed.
|
||||
"""
|
||||
sql = """
|
||||
WITH claims(user_id, device_id, algorithm, claim_count) AS (
|
||||
VALUES ?
|
||||
), ranked_keys AS (
|
||||
SELECT
|
||||
user_id, device_id, algorithm, key_id, claim_count,
|
||||
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
|
||||
FROM e2e_one_time_keys_json
|
||||
JOIN claims USING (user_id, device_id, algorithm)
|
||||
)
|
||||
DELETE FROM e2e_one_time_keys_json k
|
||||
WHERE (user_id, device_id, algorithm, key_id) IN (
|
||||
SELECT user_id, device_id, algorithm, key_id
|
||||
FROM ranked_keys
|
||||
WHERE r <= claim_count
|
||||
)
|
||||
RETURNING user_id, device_id, algorithm, key_id, key_json;
|
||||
"""
|
||||
otk_rows = cast(
|
||||
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
|
||||
)
|
||||
|
||||
seen_user_device: Set[Tuple[str, str]] = set()
|
||||
for user_id, device_id, _, _, _ in otk_rows:
|
||||
if (user_id, device_id) in seen_user_device:
|
||||
continue
|
||||
seen_user_device.add((user_id, device_id))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
return otk_rows
|
||||
|
||||
|
||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
def __init__(
|
||||
|
|
|
@ -174,6 +174,164 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_claim_one_time_key_bulk(self) -> None:
|
||||
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
|
||||
# Apologies to the reader. This test is a little too verbose. It is particularly
|
||||
# tricky to make assertions neatly with all these nested dictionaries in play.
|
||||
|
||||
# Three users with two devices each. Each device uses two algorithms.
|
||||
# Each algorithm is invoked with two keys.
|
||||
alice = f"@alice:{self.hs.hostname}"
|
||||
brian = f"@brian:{self.hs.hostname}"
|
||||
chris = f"@chris:{self.hs.hostname}"
|
||||
one_time_keys = {
|
||||
alice: {
|
||||
"alice_dev_1": {
|
||||
"alg1:k1": {"dummy_id": 1},
|
||||
"alg1:k2": {"dummy_id": 2},
|
||||
"alg2:k3": {"dummy_id": 3},
|
||||
"alg2:k4": {"dummy_id": 4},
|
||||
},
|
||||
"alice_dev_2": {
|
||||
"alg1:k5": {"dummy_id": 5},
|
||||
"alg1:k6": {"dummy_id": 6},
|
||||
"alg2:k7": {"dummy_id": 7},
|
||||
"alg2:k8": {"dummy_id": 8},
|
||||
},
|
||||
},
|
||||
brian: {
|
||||
"brian_dev_1": {
|
||||
"alg1:k9": {"dummy_id": 9},
|
||||
"alg1:k10": {"dummy_id": 10},
|
||||
"alg2:k11": {"dummy_id": 11},
|
||||
"alg2:k12": {"dummy_id": 12},
|
||||
},
|
||||
"brian_dev_2": {
|
||||
"alg1:k13": {"dummy_id": 13},
|
||||
"alg1:k14": {"dummy_id": 14},
|
||||
"alg2:k15": {"dummy_id": 15},
|
||||
"alg2:k16": {"dummy_id": 16},
|
||||
},
|
||||
},
|
||||
chris: {
|
||||
"chris_dev_1": {
|
||||
"alg1:k17": {"dummy_id": 17},
|
||||
"alg1:k18": {"dummy_id": 18},
|
||||
"alg2:k19": {"dummy_id": 19},
|
||||
"alg2:k20": {"dummy_id": 20},
|
||||
},
|
||||
"chris_dev_2": {
|
||||
"alg1:k21": {"dummy_id": 21},
|
||||
"alg1:k22": {"dummy_id": 22},
|
||||
"alg2:k23": {"dummy_id": 23},
|
||||
"alg2:k24": {"dummy_id": 24},
|
||||
},
|
||||
},
|
||||
}
|
||||
for user_id, devices in one_time_keys.items():
|
||||
for device_id, keys_dict in devices.items():
|
||||
counts = self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
user_id,
|
||||
device_id,
|
||||
{"one_time_keys": keys_dict},
|
||||
)
|
||||
)
|
||||
# The upload should report 2 keys per algorithm.
|
||||
expected_counts = {
|
||||
"one_time_key_counts": {
|
||||
# See count_e2e_one_time_keys for why this is hardcoded.
|
||||
"signed_curve25519": 0,
|
||||
"alg1": 2,
|
||||
"alg2": 2,
|
||||
},
|
||||
}
|
||||
self.assertEqual(counts, expected_counts)
|
||||
|
||||
# Claim a variety of keys.
|
||||
# Raw format, easier to make test assertions about.
|
||||
claims_to_make = {
|
||||
(alice, "alice_dev_1", "alg1"): 1,
|
||||
(alice, "alice_dev_1", "alg2"): 2,
|
||||
(alice, "alice_dev_2", "alg2"): 1,
|
||||
(brian, "brian_dev_1", "alg1"): 2,
|
||||
(brian, "brian_dev_2", "alg2"): 9001,
|
||||
(chris, "chris_dev_2", "alg2"): 1,
|
||||
}
|
||||
# Convert to the format the handler wants.
|
||||
query: Dict[str, Dict[str, Dict[str, int]]] = {}
|
||||
for (user_id, device_id, algorithm), count in claims_to_make.items():
|
||||
query.setdefault(user_id, {}).setdefault(device_id, {})[algorithm] = count
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
query,
|
||||
self.requester,
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
|
||||
# No failures, please!
|
||||
self.assertEqual(claim_res["failures"], {})
|
||||
|
||||
# Check that we get exactly the (user, device, algorithm)s we asked for.
|
||||
got_otks = claim_res["one_time_keys"]
|
||||
claimed_user_device_algorithms = {
|
||||
(user_id, device_id, alg_key_id.split(":")[0])
|
||||
for user_id, devices in got_otks.items()
|
||||
for device_id, key_dict in devices.items()
|
||||
for alg_key_id in key_dict
|
||||
}
|
||||
self.assertEqual(claimed_user_device_algorithms, set(claims_to_make))
|
||||
|
||||
# Now check the keys we got are what we expected.
|
||||
def assertExactlyOneOtk(
|
||||
user_id: str, device_id: str, *alg_key_pairs: str
|
||||
) -> None:
|
||||
key_dict = got_otks[user_id][device_id]
|
||||
found = 0
|
||||
for alg_key in alg_key_pairs:
|
||||
if alg_key in key_dict:
|
||||
expected_key_json = one_time_keys[user_id][device_id][alg_key]
|
||||
self.assertEqual(key_dict[alg_key], expected_key_json)
|
||||
found += 1
|
||||
self.assertEqual(found, 1)
|
||||
|
||||
def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
|
||||
key_dict = got_otks[user_id][device_id]
|
||||
for alg_key in alg_key_pairs:
|
||||
expected_key_json = one_time_keys[user_id][device_id][alg_key]
|
||||
self.assertEqual(key_dict[alg_key], expected_key_json)
|
||||
|
||||
# Expect a single arbitrary key to be returned.
|
||||
assertExactlyOneOtk(alice, "alice_dev_1", "alg1:k1", "alg1:k2")
|
||||
assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg2:k8")
|
||||
assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg2:k24")
|
||||
|
||||
assertAllOtks(alice, "alice_dev_1", "alg2:k3", "alg2:k4")
|
||||
assertAllOtks(brian, "brian_dev_1", "alg1:k9", "alg1:k10")
|
||||
assertAllOtks(brian, "brian_dev_2", "alg2:k15", "alg2:k16")
|
||||
|
||||
# Now check the unused key counts.
|
||||
for user_id, devices in one_time_keys.items():
|
||||
for device_id in devices:
|
||||
counts_by_alg = self.get_success(
|
||||
self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
)
|
||||
# Somewhat fiddley to compute the expected count dict.
|
||||
expected_counts_by_alg = {
|
||||
"signed_curve25519": 0,
|
||||
}
|
||||
for alg in ["alg1", "alg2"]:
|
||||
claim_count = claims_to_make.get((user_id, device_id, alg), 0)
|
||||
remaining_count = max(0, 2 - claim_count)
|
||||
if remaining_count > 0:
|
||||
expected_counts_by_alg[alg] = remaining_count
|
||||
|
||||
self.assertEqual(
|
||||
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
|
||||
)
|
||||
|
||||
def test_fallback_key(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
|
|
Loading…
Reference in New Issue