mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-31 20:28:16 +00:00 
			
		
		
		
	Add unstable /keys/claim endpoint which always returns fallback keys. (#15462)
It can be useful to always return the fallback key when attempting to claim keys. This adds an unstable endpoint for `/keys/claim` which always returns fallback keys in addition to one-time-keys. The fallback key(s) are not marked as "used" unless there are no corresponding OTKs. This is currently defined in MSC3983 (although likely to be split out to a separate MSC). The endpoint shape may change or be requested differently (i.e. a keyword parameter on the current endpoint), but the core logic should be reasonable.
This commit is contained in:
		
							parent
							
								
									b39b02c26e
								
							
						
					
					
						commit
						8e9739449d
					
				
							
								
								
									
										1
									
								
								changelog.d/15462.misc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/15462.misc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | |||||||
|  | Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request. | ||||||
| @ -1005,7 +1005,7 @@ class FederationServer(FederationBase): | |||||||
| 
 | 
 | ||||||
|     @trace |     @trace | ||||||
|     async def on_claim_client_keys( |     async def on_claim_client_keys( | ||||||
|         self, origin: str, content: JsonDict |         self, origin: str, content: JsonDict, always_include_fallback_keys: bool | ||||||
|     ) -> Dict[str, Any]: |     ) -> Dict[str, Any]: | ||||||
|         query = [] |         query = [] | ||||||
|         for user_id, device_keys in content.get("one_time_keys", {}).items(): |         for user_id, device_keys in content.get("one_time_keys", {}).items(): | ||||||
| @ -1013,7 +1013,9 @@ class FederationServer(FederationBase): | |||||||
|                 query.append((user_id, device_id, algorithm)) |                 query.append((user_id, device_id, algorithm)) | ||||||
| 
 | 
 | ||||||
|         log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) |         log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) | ||||||
|         results = await self._e2e_keys_handler.claim_local_one_time_keys(query) |         results = await self._e2e_keys_handler.claim_local_one_time_keys( | ||||||
|  |             query, always_include_fallback_keys=always_include_fallback_keys | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} |         json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} | ||||||
|         for result in results: |         for result in results: | ||||||
|  | |||||||
| @ -25,6 +25,7 @@ from synapse.federation.transport.server._base import ( | |||||||
| from synapse.federation.transport.server.federation import ( | from synapse.federation.transport.server.federation import ( | ||||||
|     FEDERATION_SERVLET_CLASSES, |     FEDERATION_SERVLET_CLASSES, | ||||||
|     FederationAccountStatusServlet, |     FederationAccountStatusServlet, | ||||||
|  |     FederationUnstableClientKeysClaimServlet, | ||||||
| ) | ) | ||||||
| from synapse.http.server import HttpServer, JsonResource | from synapse.http.server import HttpServer, JsonResource | ||||||
| from synapse.http.servlet import ( | from synapse.http.servlet import ( | ||||||
| @ -298,6 +299,11 @@ def register_servlets( | |||||||
|                 and not hs.config.experimental.msc3720_enabled |                 and not hs.config.experimental.msc3720_enabled | ||||||
|             ): |             ): | ||||||
|                 continue |                 continue | ||||||
|  |             if ( | ||||||
|  |                 servletclass == FederationUnstableClientKeysClaimServlet | ||||||
|  |                 and not hs.config.experimental.msc3983_appservice_otk_claims | ||||||
|  |             ): | ||||||
|  |                 continue | ||||||
| 
 | 
 | ||||||
|             servletclass( |             servletclass( | ||||||
|                 hs=hs, |                 hs=hs, | ||||||
|  | |||||||
| @ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet): | |||||||
|     async def on_POST( |     async def on_POST( | ||||||
|         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] |         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] | ||||||
|     ) -> Tuple[int, JsonDict]: |     ) -> Tuple[int, JsonDict]: | ||||||
|         response = await self.handler.on_claim_client_keys(origin, content) |         response = await self.handler.on_claim_client_keys( | ||||||
|  |             origin, content, always_include_fallback_keys=False | ||||||
|  |         ) | ||||||
|  |         return 200, response | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): | ||||||
|  |     """ | ||||||
|  |     Identical to the stable endpoint (FederationClientKeysClaimServlet) except it | ||||||
|  |     always includes fallback keys in the response. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     PREFIX = FEDERATION_UNSTABLE_PREFIX | ||||||
|  |     PATH = "/user/keys/claim" | ||||||
|  |     CATEGORY = "Federation requests" | ||||||
|  | 
 | ||||||
|  |     async def on_POST( | ||||||
|  |         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] | ||||||
|  |     ) -> Tuple[int, JsonDict]: | ||||||
|  |         response = await self.handler.on_claim_client_keys( | ||||||
|  |             origin, content, always_include_fallback_keys=True | ||||||
|  |         ) | ||||||
|         return 200, response |         return 200, response | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -842,9 +842,7 @@ class ApplicationServicesHandler: | |||||||
| 
 | 
 | ||||||
|     async def claim_e2e_one_time_keys( |     async def claim_e2e_one_time_keys( | ||||||
|         self, query: Iterable[Tuple[str, str, str]] |         self, query: Iterable[Tuple[str, str, str]] | ||||||
|     ) -> Tuple[ |     ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: | ||||||
|         Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]] |  | ||||||
|     ]: |  | ||||||
|         """Claim one time keys from application services. |         """Claim one time keys from application services. | ||||||
| 
 | 
 | ||||||
|         Users which are exclusively owned by an application service are sent a |         Users which are exclusively owned by an application service are sent a | ||||||
| @ -856,7 +854,7 @@ class ApplicationServicesHandler: | |||||||
| 
 | 
 | ||||||
|         Returns: |         Returns: | ||||||
|             A tuple of: |             A tuple of: | ||||||
|                 An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. |                 A map of user ID -> a map device ID -> a map of key ID -> JSON. | ||||||
| 
 | 
 | ||||||
|                 A copy of the input which has not been fulfilled (either because |                 A copy of the input which has not been fulfilled (either because | ||||||
|                 they are not appservice users or the appservice does not support |                 they are not appservice users or the appservice does not support | ||||||
| @ -897,12 +895,11 @@ class ApplicationServicesHandler: | |||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Patch together the results -- they are all independent (since they |         # Patch together the results -- they are all independent (since they | ||||||
|         # require exclusive control over the users). They get returned as a list |         # require exclusive control over the users, which is the outermost key). | ||||||
|         # and the caller combines them. |         claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} | ||||||
|         claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = [] |  | ||||||
|         for success, result in results: |         for success, result in results: | ||||||
|             if success: |             if success: | ||||||
|                 claimed_keys.append(result[0]) |                 claimed_keys.update(result[0]) | ||||||
|                 missing.extend(result[1]) |                 missing.extend(result[1]) | ||||||
| 
 | 
 | ||||||
|         return claimed_keys, missing |         return claimed_keys, missing | ||||||
|  | |||||||
| @ -563,7 +563,9 @@ class E2eKeysHandler: | |||||||
|         return ret |         return ret | ||||||
| 
 | 
 | ||||||
|     async def claim_local_one_time_keys( |     async def claim_local_one_time_keys( | ||||||
|         self, local_query: List[Tuple[str, str, str]] |         self, | ||||||
|  |         local_query: List[Tuple[str, str, str]], | ||||||
|  |         always_include_fallback_keys: bool, | ||||||
|     ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: |     ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: | ||||||
|         """Claim one time keys for local users. |         """Claim one time keys for local users. | ||||||
| 
 | 
 | ||||||
| @ -573,6 +575,7 @@ class E2eKeysHandler: | |||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|             local_query: An iterable of tuples of (user ID, device ID, algorithm). |             local_query: An iterable of tuples of (user ID, device ID, algorithm). | ||||||
|  |             always_include_fallback_keys: True to always include fallback keys. | ||||||
| 
 | 
 | ||||||
|         Returns: |         Returns: | ||||||
|             An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. |             An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. | ||||||
| @ -583,24 +586,73 @@ class E2eKeysHandler: | |||||||
|         # If the application services have not provided any keys via the C-S |         # If the application services have not provided any keys via the C-S | ||||||
|         # API, query it directly for one-time keys. |         # API, query it directly for one-time keys. | ||||||
|         if self._query_appservices_for_otks: |         if self._query_appservices_for_otks: | ||||||
|  |             # TODO Should this query for fallback keys of uploaded OTKs if | ||||||
|  |             #      always_include_fallback_keys is True? The MSC is ambiguous. | ||||||
|             ( |             ( | ||||||
|                 appservice_results, |                 appservice_results, | ||||||
|                 not_found, |                 not_found, | ||||||
|             ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) |             ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) | ||||||
|         else: |         else: | ||||||
|             appservice_results = [] |             appservice_results = {} | ||||||
|  | 
 | ||||||
|  |         # Calculate which user ID / device ID / algorithm tuples to get fallback | ||||||
|  |         # keys for. This can be either only missing results *or* all results | ||||||
|  |         # (which don't already have a fallback key). | ||||||
|  |         if always_include_fallback_keys: | ||||||
|  |             # Build the fallback query as any part of the original query where | ||||||
|  |             # the appservice didn't respond with a fallback key. | ||||||
|  |             fallback_query = [] | ||||||
|  | 
 | ||||||
|  |             # Iterate each item in the original query and search the results | ||||||
|  |             # from the appservice for that user ID / device ID. If it is found, | ||||||
|  |             # check if any of the keys match the requested algorithm & are a | ||||||
|  |             # fallback key. | ||||||
|  |             for user_id, device_id, algorithm in local_query: | ||||||
|  |                 # Check if the appservice responded for this query. | ||||||
|  |                 as_result = appservice_results.get(user_id, {}).get(device_id, {}) | ||||||
|  |                 found_otk = False | ||||||
|  |                 for key_id, key_json in as_result.items(): | ||||||
|  |                     if key_id.startswith(f"{algorithm}:"): | ||||||
|  |                         # A OTK or fallback key was found for this query. | ||||||
|  |                         found_otk = True | ||||||
|  |                         # A fallback key was found for this query, no need to | ||||||
|  |                         # query further. | ||||||
|  |                         if key_json.get("fallback", False): | ||||||
|  |                             break | ||||||
|  | 
 | ||||||
|  |                 else: | ||||||
|  |                     # No fallback key was found from appservices, query for it. | ||||||
|  |                     # Only mark the fallback key as used if no OTK was found | ||||||
|  |                     # (from either the database or appservices). | ||||||
|  |                     mark_as_used = not found_otk and not any( | ||||||
|  |                         key_id.startswith(f"{algorithm}:") | ||||||
|  |                         for key_id in otk_results.get(user_id, {}) | ||||||
|  |                         .get(device_id, {}) | ||||||
|  |                         .keys() | ||||||
|  |                     ) | ||||||
|  |                     fallback_query.append((user_id, device_id, algorithm, mark_as_used)) | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             # All fallback keys get marked as used. | ||||||
|  |             fallback_query = [ | ||||||
|  |                 (user_id, device_id, algorithm, True) | ||||||
|  |                 for user_id, device_id, algorithm in not_found | ||||||
|  |             ] | ||||||
| 
 | 
 | ||||||
|         # For each user that does not have a one-time keys available, see if |         # For each user that does not have a one-time keys available, see if | ||||||
|         # there is a fallback key. |         # there is a fallback key. | ||||||
|         fallback_results = await self.store.claim_e2e_fallback_keys(not_found) |         fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query) | ||||||
| 
 | 
 | ||||||
|         # Return the results in order, each item from the input query should |         # Return the results in order, each item from the input query should | ||||||
|         # only appear once in the combined list. |         # only appear once in the combined list. | ||||||
|         return (otk_results, *appservice_results, fallback_results) |         return (otk_results, appservice_results, fallback_results) | ||||||
| 
 | 
 | ||||||
|     @trace |     @trace | ||||||
|     async def claim_one_time_keys( |     async def claim_one_time_keys( | ||||||
|         self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] |         self, | ||||||
|  |         query: Dict[str, Dict[str, Dict[str, str]]], | ||||||
|  |         timeout: Optional[int], | ||||||
|  |         always_include_fallback_keys: bool, | ||||||
|     ) -> JsonDict: |     ) -> JsonDict: | ||||||
|         local_query: List[Tuple[str, str, str]] = [] |         local_query: List[Tuple[str, str, str]] = [] | ||||||
|         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} |         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} | ||||||
| @ -617,7 +669,9 @@ class E2eKeysHandler: | |||||||
|         set_tag("local_key_query", str(local_query)) |         set_tag("local_key_query", str(local_query)) | ||||||
|         set_tag("remote_key_query", str(remote_queries)) |         set_tag("remote_key_query", str(remote_queries)) | ||||||
| 
 | 
 | ||||||
|         results = await self.claim_local_one_time_keys(local_query) |         results = await self.claim_local_one_time_keys( | ||||||
|  |             local_query, always_include_fallback_keys | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         # A map of user ID -> device ID -> key ID -> key. |         # A map of user ID -> device ID -> key ID -> key. | ||||||
|         json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} |         json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} | ||||||
| @ -625,7 +679,9 @@ class E2eKeysHandler: | |||||||
|             for user_id, device_keys in result.items(): |             for user_id, device_keys in result.items(): | ||||||
|                 for device_id, keys in device_keys.items(): |                 for device_id, keys in device_keys.items(): | ||||||
|                     for key_id, key in keys.items(): |                     for key_id, key in keys.items(): | ||||||
|                         json_result.setdefault(user_id, {})[device_id] = {key_id: key} |                         json_result.setdefault(user_id, {}).setdefault( | ||||||
|  |                             device_id, {} | ||||||
|  |                         ).update({key_id: key}) | ||||||
| 
 | 
 | ||||||
|         # Remote failures. |         # Remote failures. | ||||||
|         failures: Dict[str, JsonDict] = {} |         failures: Dict[str, JsonDict] = {} | ||||||
|  | |||||||
| @ -15,6 +15,7 @@ | |||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | 
 | ||||||
| import logging | import logging | ||||||
|  | import re | ||||||
| from typing import TYPE_CHECKING, Any, Optional, Tuple | from typing import TYPE_CHECKING, Any, Optional, Tuple | ||||||
| 
 | 
 | ||||||
| from synapse.api.errors import InvalidAPICallError, SynapseError | from synapse.api.errors import InvalidAPICallError, SynapseError | ||||||
| @ -288,7 +289,33 @@ class OneTimeKeyServlet(RestServlet): | |||||||
|         await self.auth.get_user_by_req(request, allow_guest=True) |         await self.auth.get_user_by_req(request, allow_guest=True) | ||||||
|         timeout = parse_integer(request, "timeout", 10 * 1000) |         timeout = parse_integer(request, "timeout", 10 * 1000) | ||||||
|         body = parse_json_object_from_request(request) |         body = parse_json_object_from_request(request) | ||||||
|         result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) |         result = await self.e2e_keys_handler.claim_one_time_keys( | ||||||
|  |             body, timeout, always_include_fallback_keys=False | ||||||
|  |         ) | ||||||
|  |         return 200, result | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class UnstableOneTimeKeyServlet(RestServlet): | ||||||
|  |     """ | ||||||
|  |     Identical to the stable endpoint (OneTimeKeyServlet) except it always includes | ||||||
|  |     fallback keys in the response. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")] | ||||||
|  |     CATEGORY = "Encryption requests" | ||||||
|  | 
 | ||||||
|  |     def __init__(self, hs: "HomeServer"): | ||||||
|  |         super().__init__() | ||||||
|  |         self.auth = hs.get_auth() | ||||||
|  |         self.e2e_keys_handler = hs.get_e2e_keys_handler() | ||||||
|  | 
 | ||||||
|  |     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | ||||||
|  |         await self.auth.get_user_by_req(request, allow_guest=True) | ||||||
|  |         timeout = parse_integer(request, "timeout", 10 * 1000) | ||||||
|  |         body = parse_json_object_from_request(request) | ||||||
|  |         result = await self.e2e_keys_handler.claim_one_time_keys( | ||||||
|  |             body, timeout, always_include_fallback_keys=True | ||||||
|  |         ) | ||||||
|         return 200, result |         return 200, result | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: | |||||||
|     KeyQueryServlet(hs).register(http_server) |     KeyQueryServlet(hs).register(http_server) | ||||||
|     KeyChangesServlet(hs).register(http_server) |     KeyChangesServlet(hs).register(http_server) | ||||||
|     OneTimeKeyServlet(hs).register(http_server) |     OneTimeKeyServlet(hs).register(http_server) | ||||||
|  |     if hs.config.experimental.msc3983_appservice_otk_claims: | ||||||
|  |         UnstableOneTimeKeyServlet(hs).register(http_server) | ||||||
|     if hs.config.worker.worker_app is None: |     if hs.config.worker.worker_app is None: | ||||||
|         SigningKeyUploadServlet(hs).register(http_server) |         SigningKeyUploadServlet(hs).register(http_server) | ||||||
|         SignaturesUploadServlet(hs).register(http_server) |         SignaturesUploadServlet(hs).register(http_server) | ||||||
|  | |||||||
| @ -1149,18 +1149,19 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||||
|         return results, missing |         return results, missing | ||||||
| 
 | 
 | ||||||
|     async def claim_e2e_fallback_keys( |     async def claim_e2e_fallback_keys( | ||||||
|         self, query_list: Iterable[Tuple[str, str, str]] |         self, query_list: Iterable[Tuple[str, str, str, bool]] | ||||||
|     ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: |     ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: | ||||||
|         """Take a list of fallback keys out of the database. |         """Take a list of fallback keys out of the database. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|             query_list: An iterable of tuples of (user ID, device ID, algorithm). |             query_list: An iterable of tuples of | ||||||
|  |                 (user ID, device ID, algorithm, whether the key should be marked as used). | ||||||
| 
 | 
 | ||||||
|         Returns: |         Returns: | ||||||
|             A map of user ID -> a map device ID -> a map of key ID -> JSON. |             A map of user ID -> a map device ID -> a map of key ID -> JSON. | ||||||
|         """ |         """ | ||||||
|         results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} |         results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} | ||||||
|         for user_id, device_id, algorithm in query_list: |         for user_id, device_id, algorithm, mark_as_used in query_list: | ||||||
|             row = await self.db_pool.simple_select_one( |             row = await self.db_pool.simple_select_one( | ||||||
|                 table="e2e_fallback_keys_json", |                 table="e2e_fallback_keys_json", | ||||||
|                 keyvalues={ |                 keyvalues={ | ||||||
| @ -1180,7 +1181,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||||
|             used = row["used"] |             used = row["used"] | ||||||
| 
 | 
 | ||||||
|             # Mark fallback key as used if not already. |             # Mark fallback key as used if not already. | ||||||
|             if not used: |             if not used and mark_as_used: | ||||||
|                 await self.db_pool.simple_update_one( |                 await self.db_pool.simple_update_one( | ||||||
|                     table="e2e_fallback_keys_json", |                     table="e2e_fallback_keys_json", | ||||||
|                     keyvalues={ |                     keyvalues={ | ||||||
|  | |||||||
| @ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||||
| 
 | 
 | ||||||
|         res2 = self.get_success( |         res2 = self.get_success( | ||||||
|             self.handler.claim_one_time_keys( |             self.handler.claim_one_time_keys( | ||||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None |                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=False, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
| @ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||||
|         # key |         # key | ||||||
|         claim_res = self.get_success( |         claim_res = self.get_success( | ||||||
|             self.handler.claim_one_time_keys( |             self.handler.claim_one_time_keys( | ||||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None |                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=False, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
| @ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||||
|         # claiming an OTK again should return the same fallback key |         # claiming an OTK again should return the same fallback key | ||||||
|         claim_res = self.get_success( |         claim_res = self.get_success( | ||||||
|             self.handler.claim_one_time_keys( |             self.handler.claim_one_time_keys( | ||||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None |                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=False, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
| @ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||||
| 
 | 
 | ||||||
|         claim_res = self.get_success( |         claim_res = self.get_success( | ||||||
|             self.handler.claim_one_time_keys( |             self.handler.claim_one_time_keys( | ||||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None |                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=False, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
| @ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||||
| 
 | 
 | ||||||
|         claim_res = self.get_success( |         claim_res = self.get_success( | ||||||
|             self.handler.claim_one_time_keys( |             self.handler.claim_one_time_keys( | ||||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None |                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=False, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
| @ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||||
| 
 | 
 | ||||||
|         claim_res = self.get_success( |         claim_res = self.get_success( | ||||||
|             self.handler.claim_one_time_keys( |             self.handler.claim_one_time_keys( | ||||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None |                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=False, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
| @ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||||
|             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, |             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     def test_fallback_key_always_returned(self) -> None: | ||||||
|  |         local_user = "@boris:" + self.hs.hostname | ||||||
|  |         device_id = "xyz" | ||||||
|  |         fallback_key = {"alg1:k1": "fallback_key1"} | ||||||
|  |         otk = {"alg1:k2": "key2"} | ||||||
|  | 
 | ||||||
|  |         # we shouldn't have any unused fallback keys yet | ||||||
|  |         res = self.get_success( | ||||||
|  |             self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(res, []) | ||||||
|  | 
 | ||||||
|  |         # Upload a OTK & fallback key. | ||||||
|  |         self.get_success( | ||||||
|  |             self.handler.upload_keys_for_user( | ||||||
|  |                 local_user, | ||||||
|  |                 device_id, | ||||||
|  |                 {"one_time_keys": otk, "fallback_keys": fallback_key}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # we should now have an unused alg1 key | ||||||
|  |         fallback_res = self.get_success( | ||||||
|  |             self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(fallback_res, ["alg1"]) | ||||||
|  | 
 | ||||||
|  |         # Claiming an OTK and requesting to always return the fallback key should | ||||||
|  |         # return both. | ||||||
|  |         claim_res = self.get_success( | ||||||
|  |             self.handler.claim_one_time_keys( | ||||||
|  |                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=True, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             claim_res, | ||||||
|  |             { | ||||||
|  |                 "failures": {}, | ||||||
|  |                 "one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}}, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # This should not mark the key as used. | ||||||
|  |         fallback_res = self.get_success( | ||||||
|  |             self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(fallback_res, ["alg1"]) | ||||||
|  | 
 | ||||||
|  |         # Claiming an OTK again should return only the fallback key. | ||||||
|  |         claim_res = self.get_success( | ||||||
|  |             self.handler.claim_one_time_keys( | ||||||
|  |                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=True, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             claim_res, | ||||||
|  |             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # And mark it as used. | ||||||
|  |         fallback_res = self.get_success( | ||||||
|  |             self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(fallback_res, []) | ||||||
|  | 
 | ||||||
|     def test_replace_master_key(self) -> None: |     def test_replace_master_key(self) -> None: | ||||||
|         """uploading a new signing key should make the old signing key unavailable""" |         """uploading a new signing key should make the old signing key unavailable""" | ||||||
|         local_user = "@boris:" + self.hs.hostname |         local_user = "@boris:" + self.hs.hostname | ||||||
| @ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||||
|                     } |                     } | ||||||
|                 }, |                 }, | ||||||
|                 timeout=None, |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=False, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
| @ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}}) | ||||||
|  |     def test_query_appservice_with_fallback(self) -> None: | ||||||
|  |         local_user = "@boris:" + self.hs.hostname | ||||||
|  |         device_id_1 = "xyz" | ||||||
|  |         fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}} | ||||||
|  |         otk = {"alg1:k2": {"desc": "key2"}} | ||||||
|  |         as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}} | ||||||
|  |         as_otk = {"alg1:k4": {"desc": "key4"}} | ||||||
|  | 
 | ||||||
|  |         # Inject an appservice interested in this user. | ||||||
|  |         appservice = ApplicationService( | ||||||
|  |             token="i_am_an_app_service", | ||||||
|  |             id="1234", | ||||||
|  |             namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]}, | ||||||
|  |             # Note: this user does not have to match the regex above | ||||||
|  |             sender="@as_main:test", | ||||||
|  |         ) | ||||||
|  |         self.hs.get_datastores().main.services_cache = [appservice] | ||||||
|  |         self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( | ||||||
|  |             [appservice] | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Setup a response. | ||||||
|  |         self.appservice_api.claim_client_keys.return_value = make_awaitable( | ||||||
|  |             ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, []) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Claim OTKs, which will ask the appservice and do nothing else. | ||||||
|  |         claim_res = self.get_success( | ||||||
|  |             self.handler.claim_one_time_keys( | ||||||
|  |                 {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=True, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             claim_res, | ||||||
|  |             { | ||||||
|  |                 "failures": {}, | ||||||
|  |                 "one_time_keys": { | ||||||
|  |                     local_user: {device_id_1: {**as_otk, **as_fallback_key}} | ||||||
|  |                 }, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Now upload a fallback key. | ||||||
|  |         res = self.get_success( | ||||||
|  |             self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(res, []) | ||||||
|  | 
 | ||||||
|  |         self.get_success( | ||||||
|  |             self.handler.upload_keys_for_user( | ||||||
|  |                 local_user, | ||||||
|  |                 device_id_1, | ||||||
|  |                 {"fallback_keys": fallback_key}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # we should now have an unused alg1 key | ||||||
|  |         fallback_res = self.get_success( | ||||||
|  |             self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(fallback_res, ["alg1"]) | ||||||
|  | 
 | ||||||
|  |         # The appservice will return only the OTK. | ||||||
|  |         self.appservice_api.claim_client_keys.return_value = make_awaitable( | ||||||
|  |             ({local_user: {device_id_1: as_otk}}, []) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Claim OTKs, which should return the OTK from the appservice and the | ||||||
|  |         # uploaded fallback key. | ||||||
|  |         claim_res = self.get_success( | ||||||
|  |             self.handler.claim_one_time_keys( | ||||||
|  |                 {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=True, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             claim_res, | ||||||
|  |             { | ||||||
|  |                 "failures": {}, | ||||||
|  |                 "one_time_keys": { | ||||||
|  |                     local_user: {device_id_1: {**as_otk, **fallback_key}} | ||||||
|  |                 }, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # But the fallback key should not be marked as used. | ||||||
|  |         fallback_res = self.get_success( | ||||||
|  |             self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(fallback_res, ["alg1"]) | ||||||
|  | 
 | ||||||
|  |         # Now upload a OTK. | ||||||
|  |         self.get_success( | ||||||
|  |             self.handler.upload_keys_for_user( | ||||||
|  |                 local_user, | ||||||
|  |                 device_id_1, | ||||||
|  |                 {"one_time_keys": otk}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Claim OTKs, which will return information only from the database. | ||||||
|  |         claim_res = self.get_success( | ||||||
|  |             self.handler.claim_one_time_keys( | ||||||
|  |                 {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=True, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             claim_res, | ||||||
|  |             { | ||||||
|  |                 "failures": {}, | ||||||
|  |                 "one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}}, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # But the fallback key should not be marked as used. | ||||||
|  |         fallback_res = self.get_success( | ||||||
|  |             self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(fallback_res, ["alg1"]) | ||||||
|  | 
 | ||||||
|  |         # Finally, return only the fallback key from the appservice. | ||||||
|  |         self.appservice_api.claim_client_keys.return_value = make_awaitable( | ||||||
|  |             ({local_user: {device_id_1: as_fallback_key}}, []) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Claim OTKs, which will return only the fallback key from the database. | ||||||
|  |         claim_res = self.get_success( | ||||||
|  |             self.handler.claim_one_time_keys( | ||||||
|  |                 {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, | ||||||
|  |                 timeout=None, | ||||||
|  |                 always_include_fallback_keys=True, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             claim_res, | ||||||
|  |             { | ||||||
|  |                 "failures": {}, | ||||||
|  |                 "one_time_keys": {local_user: {device_id_1: as_fallback_key}}, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|     @override_config({"experimental_features": {"msc3984_appservice_key_query": True}}) |     @override_config({"experimental_features": {"msc3984_appservice_key_query": True}}) | ||||||
|     def test_query_local_devices_appservice(self) -> None: |     def test_query_local_devices_appservice(self) -> None: | ||||||
|         """Test that querying of appservices for keys overrides responses from the database.""" |         """Test that querying of appservices for keys overrides responses from the database.""" | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user