mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-31 20:28:16 +00:00 
			
		
		
		
	Add experimental support for MSC3391: deleting account data (#14714)
This commit is contained in:
		
							parent
							
								
									044fa1a1de
								
							
						
					
					
						commit
						c4456114e1
					
				
							
								
								
									
										1
									
								
								changelog.d/14714.feature
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/14714.feature
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | ||||
| Add experimental support for [MSC3391](https://github.com/matrix-org/matrix-spec-proposals/pull/3391) (removing account data). | ||||
| @ -102,6 +102,8 @@ experimental_features: | ||||
|   {% endif %} | ||||
|   # Filtering /messages by relation type. | ||||
|   msc3874_enabled: true | ||||
|   # Enable removing account data support | ||||
|   msc3391_enabled: true | ||||
| 
 | ||||
| server_notices: | ||||
|   system_mxid_localpart: _server | ||||
|  | ||||
| @ -190,7 +190,7 @@ fi | ||||
| 
 | ||||
| extra_test_args=() | ||||
| 
 | ||||
| test_tags="synapse_blacklist,msc3787,msc3874" | ||||
| test_tags="synapse_blacklist,msc3787,msc3874,msc3391" | ||||
| 
 | ||||
| # All environment variables starting with PASS_ will be shared. | ||||
| # (The prefix is stripped off before reaching the container.) | ||||
|  | ||||
| @ -136,3 +136,6 @@ class ExperimentalConfig(Config): | ||||
|             # Enable room version (and thus applicable push rules from MSC3931/3932) | ||||
|             version_id = RoomVersions.MSC1767v10.identifier | ||||
|             KNOWN_ROOM_VERSIONS[version_id] = RoomVersions.MSC1767v10 | ||||
| 
 | ||||
|         # MSC3391: Removing account data. | ||||
|         self.msc3391_enabled = experimental.get("msc3391_enabled", False) | ||||
|  | ||||
| @ -17,10 +17,12 @@ import random | ||||
| from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple | ||||
| 
 | ||||
| from synapse.replication.http.account_data import ( | ||||
|     ReplicationAddRoomAccountDataRestServlet, | ||||
|     ReplicationAddTagRestServlet, | ||||
|     ReplicationAddUserAccountDataRestServlet, | ||||
|     ReplicationRemoveRoomAccountDataRestServlet, | ||||
|     ReplicationRemoveTagRestServlet, | ||||
|     ReplicationRoomAccountDataRestServlet, | ||||
|     ReplicationUserAccountDataRestServlet, | ||||
|     ReplicationRemoveUserAccountDataRestServlet, | ||||
| ) | ||||
| from synapse.streams import EventSource | ||||
| from synapse.types import JsonDict, StreamKeyType, UserID | ||||
| @ -41,8 +43,18 @@ class AccountDataHandler: | ||||
|         self._instance_name = hs.get_instance_name() | ||||
|         self._notifier = hs.get_notifier() | ||||
| 
 | ||||
|         self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) | ||||
|         self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) | ||||
|         self._add_user_data_client = ( | ||||
|             ReplicationAddUserAccountDataRestServlet.make_client(hs) | ||||
|         ) | ||||
|         self._remove_user_data_client = ( | ||||
|             ReplicationRemoveUserAccountDataRestServlet.make_client(hs) | ||||
|         ) | ||||
|         self._add_room_data_client = ( | ||||
|             ReplicationAddRoomAccountDataRestServlet.make_client(hs) | ||||
|         ) | ||||
|         self._remove_room_data_client = ( | ||||
|             ReplicationRemoveRoomAccountDataRestServlet.make_client(hs) | ||||
|         ) | ||||
|         self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) | ||||
|         self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) | ||||
|         self._account_data_writers = hs.config.worker.writers.account_data | ||||
| @ -112,7 +124,7 @@ class AccountDataHandler: | ||||
| 
 | ||||
|             return max_stream_id | ||||
|         else: | ||||
|             response = await self._room_data_client( | ||||
|             response = await self._add_room_data_client( | ||||
|                 instance_name=random.choice(self._account_data_writers), | ||||
|                 user_id=user_id, | ||||
|                 room_id=room_id, | ||||
| @ -121,15 +133,59 @@ class AccountDataHandler: | ||||
|             ) | ||||
|             return response["max_stream_id"] | ||||
| 
 | ||||
|     async def remove_account_data_for_room( | ||||
|         self, user_id: str, room_id: str, account_data_type: str | ||||
|     ) -> Optional[int]: | ||||
|         """ | ||||
|         Deletes the room account data for the given user and account data type. | ||||
| 
 | ||||
|         "Deleting" account data merely means setting the content of the account data | ||||
|         to an empty JSON object: {}. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id: The user ID to remove room account data for. | ||||
|             room_id: The room ID to target. | ||||
|             account_data_type: The account data type to remove. | ||||
| 
 | ||||
|         Returns: | ||||
|             The maximum stream ID, or None if the room account data item did not exist. | ||||
|         """ | ||||
|         if self._instance_name in self._account_data_writers: | ||||
|             max_stream_id = await self._store.remove_account_data_for_room( | ||||
|                 user_id, room_id, account_data_type | ||||
|             ) | ||||
|             if max_stream_id is None: | ||||
|                 # The referenced account data did not exist, so no delete occurred. | ||||
|                 return None | ||||
| 
 | ||||
|             self._notifier.on_new_event( | ||||
|                 StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] | ||||
|             ) | ||||
| 
 | ||||
|             # Notify Synapse modules that the content of the type has changed to an | ||||
|             # empty dictionary. | ||||
|             await self._notify_modules(user_id, room_id, account_data_type, {}) | ||||
| 
 | ||||
|             return max_stream_id | ||||
|         else: | ||||
|             response = await self._remove_room_data_client( | ||||
|                 instance_name=random.choice(self._account_data_writers), | ||||
|                 user_id=user_id, | ||||
|                 room_id=room_id, | ||||
|                 account_data_type=account_data_type, | ||||
|                 content={}, | ||||
|             ) | ||||
|             return response["max_stream_id"] | ||||
| 
 | ||||
|     async def add_account_data_for_user( | ||||
|         self, user_id: str, account_data_type: str, content: JsonDict | ||||
|     ) -> int: | ||||
|         """Add some global account_data for a user. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id: The user to add a tag for. | ||||
|             user_id: The user to add some account data for. | ||||
|             account_data_type: The type of account_data to add. | ||||
|             content: A json object to associate with the tag. | ||||
|             content: The content json dictionary. | ||||
| 
 | ||||
|         Returns: | ||||
|             The maximum stream ID. | ||||
| @ -148,7 +204,7 @@ class AccountDataHandler: | ||||
| 
 | ||||
|             return max_stream_id | ||||
|         else: | ||||
|             response = await self._user_data_client( | ||||
|             response = await self._add_user_data_client( | ||||
|                 instance_name=random.choice(self._account_data_writers), | ||||
|                 user_id=user_id, | ||||
|                 account_data_type=account_data_type, | ||||
| @ -156,6 +212,45 @@ class AccountDataHandler: | ||||
|             ) | ||||
|             return response["max_stream_id"] | ||||
| 
 | ||||
|     async def remove_account_data_for_user( | ||||
|         self, user_id: str, account_data_type: str | ||||
|     ) -> Optional[int]: | ||||
|         """Removes a piece of global account_data for a user. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id: The user to remove account data for. | ||||
|             account_data_type: The type of account_data to remove. | ||||
| 
 | ||||
|         Returns: | ||||
|             The maximum stream ID, or None if the room account data item did not exist. | ||||
|         """ | ||||
| 
 | ||||
|         if self._instance_name in self._account_data_writers: | ||||
|             max_stream_id = await self._store.remove_account_data_for_user( | ||||
|                 user_id, account_data_type | ||||
|             ) | ||||
|             if max_stream_id is None: | ||||
|                 # The referenced account data did not exist, so no delete occurred. | ||||
|                 return None | ||||
| 
 | ||||
|             self._notifier.on_new_event( | ||||
|                 StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] | ||||
|             ) | ||||
| 
 | ||||
|             # Notify Synapse modules that the content of the type has changed to an | ||||
|             # empty dictionary. | ||||
|             await self._notify_modules(user_id, None, account_data_type, {}) | ||||
| 
 | ||||
|             return max_stream_id | ||||
|         else: | ||||
|             response = await self._remove_user_data_client( | ||||
|                 instance_name=random.choice(self._account_data_writers), | ||||
|                 user_id=user_id, | ||||
|                 account_data_type=account_data_type, | ||||
|                 content={}, | ||||
|             ) | ||||
|             return response["max_stream_id"] | ||||
| 
 | ||||
|     async def add_tag_to_room( | ||||
|         self, user_id: str, room_id: str, tag: str, content: JsonDict | ||||
|     ) -> int: | ||||
|  | ||||
| @ -28,7 +28,7 @@ if TYPE_CHECKING: | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): | ||||
| class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint): | ||||
|     """Add user account data on the appropriate account data worker. | ||||
| 
 | ||||
|     Request format: | ||||
| @ -49,7 +49,6 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): | ||||
|         super().__init__(hs) | ||||
| 
 | ||||
|         self.handler = hs.get_account_data_handler() | ||||
|         self.clock = hs.get_clock() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     async def _serialize_payload(  # type: ignore[override] | ||||
| @ -73,7 +72,45 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): | ||||
|         return 200, {"max_stream_id": max_stream_id} | ||||
| 
 | ||||
| 
 | ||||
| class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): | ||||
| class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint): | ||||
|     """Remove user account data on the appropriate account data worker. | ||||
| 
 | ||||
|     Request format: | ||||
| 
 | ||||
|         POST /_synapse/replication/remove_user_account_data/:user_id/:type | ||||
| 
 | ||||
|         { | ||||
|             "content": { ... }, | ||||
|         } | ||||
| 
 | ||||
|     """ | ||||
| 
 | ||||
|     NAME = "remove_user_account_data" | ||||
|     PATH_ARGS = ("user_id", "account_data_type") | ||||
|     CACHE = False | ||||
| 
 | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         super().__init__(hs) | ||||
| 
 | ||||
|         self.handler = hs.get_account_data_handler() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     async def _serialize_payload(  # type: ignore[override] | ||||
|         user_id: str, account_data_type: str | ||||
|     ) -> JsonDict: | ||||
|         return {} | ||||
| 
 | ||||
|     async def _handle_request(  # type: ignore[override] | ||||
|         self, request: Request, user_id: str, account_data_type: str | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         max_stream_id = await self.handler.remove_account_data_for_user( | ||||
|             user_id, account_data_type | ||||
|         ) | ||||
| 
 | ||||
|         return 200, {"max_stream_id": max_stream_id} | ||||
| 
 | ||||
| 
 | ||||
| class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint): | ||||
|     """Add room account data on the appropriate account data worker. | ||||
| 
 | ||||
|     Request format: | ||||
| @ -94,7 +131,6 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): | ||||
|         super().__init__(hs) | ||||
| 
 | ||||
|         self.handler = hs.get_account_data_handler() | ||||
|         self.clock = hs.get_clock() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     async def _serialize_payload(  # type: ignore[override] | ||||
| @ -118,6 +154,44 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): | ||||
|         return 200, {"max_stream_id": max_stream_id} | ||||
| 
 | ||||
| 
 | ||||
| class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint): | ||||
|     """Remove room account data on the appropriate account data worker. | ||||
| 
 | ||||
|     Request format: | ||||
| 
 | ||||
|         POST /_synapse/replication/remove_room_account_data/:user_id/:room_id/:account_data_type | ||||
| 
 | ||||
|         { | ||||
|             "content": { ... }, | ||||
|         } | ||||
| 
 | ||||
|     """ | ||||
| 
 | ||||
|     NAME = "remove_room_account_data" | ||||
|     PATH_ARGS = ("user_id", "room_id", "account_data_type") | ||||
|     CACHE = False | ||||
| 
 | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         super().__init__(hs) | ||||
| 
 | ||||
|         self.handler = hs.get_account_data_handler() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     async def _serialize_payload(  # type: ignore[override] | ||||
|         user_id: str, room_id: str, account_data_type: str, content: JsonDict | ||||
|     ) -> JsonDict: | ||||
|         return {} | ||||
| 
 | ||||
|     async def _handle_request(  # type: ignore[override] | ||||
|         self, request: Request, user_id: str, room_id: str, account_data_type: str | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         max_stream_id = await self.handler.remove_account_data_for_room( | ||||
|             user_id, room_id, account_data_type | ||||
|         ) | ||||
| 
 | ||||
|         return 200, {"max_stream_id": max_stream_id} | ||||
| 
 | ||||
| 
 | ||||
| class ReplicationAddTagRestServlet(ReplicationEndpoint): | ||||
|     """Add tag on the appropriate account data worker. | ||||
| 
 | ||||
| @ -139,7 +213,6 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): | ||||
|         super().__init__(hs) | ||||
| 
 | ||||
|         self.handler = hs.get_account_data_handler() | ||||
|         self.clock = hs.get_clock() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     async def _serialize_payload(  # type: ignore[override] | ||||
| @ -186,7 +259,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): | ||||
|         super().__init__(hs) | ||||
| 
 | ||||
|         self.handler = hs.get_account_data_handler() | ||||
|         self.clock = hs.get_clock() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict:  # type: ignore[override] | ||||
| @ -206,7 +278,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: | ||||
|     ReplicationUserAccountDataRestServlet(hs).register(http_server) | ||||
|     ReplicationRoomAccountDataRestServlet(hs).register(http_server) | ||||
|     ReplicationAddUserAccountDataRestServlet(hs).register(http_server) | ||||
|     ReplicationAddRoomAccountDataRestServlet(hs).register(http_server) | ||||
|     ReplicationAddTagRestServlet(hs).register(http_server) | ||||
|     ReplicationRemoveTagRestServlet(hs).register(http_server) | ||||
| 
 | ||||
|     if hs.config.experimental.msc3391_enabled: | ||||
|         ReplicationRemoveUserAccountDataRestServlet(hs).register(http_server) | ||||
|         ReplicationRemoveRoomAccountDataRestServlet(hs).register(http_server) | ||||
|  | ||||
| @ -41,6 +41,7 @@ class AccountDataServlet(RestServlet): | ||||
| 
 | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         super().__init__() | ||||
|         self._hs = hs | ||||
|         self.auth = hs.get_auth() | ||||
|         self.store = hs.get_datastores().main | ||||
|         self.handler = hs.get_account_data_handler() | ||||
| @ -54,6 +55,16 @@ class AccountDataServlet(RestServlet): | ||||
| 
 | ||||
|         body = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         # If experimental support for MSC3391 is enabled, then providing an empty dict | ||||
|         # as the value for an account data type should be functionally equivalent to | ||||
|         # calling the DELETE method on the same type. | ||||
|         if self._hs.config.experimental.msc3391_enabled: | ||||
|             if body == {}: | ||||
|                 await self.handler.remove_account_data_for_user( | ||||
|                     user_id, account_data_type | ||||
|                 ) | ||||
|                 return 200, {} | ||||
| 
 | ||||
|         await self.handler.add_account_data_for_user(user_id, account_data_type, body) | ||||
| 
 | ||||
|         return 200, {} | ||||
| @ -72,9 +83,48 @@ class AccountDataServlet(RestServlet): | ||||
|         if event is None: | ||||
|             raise NotFoundError("Account data not found") | ||||
| 
 | ||||
|         # If experimental support for MSC3391 is enabled, then this endpoint should | ||||
|         # return a 404 if the content for an account data type is an empty dict. | ||||
|         if self._hs.config.experimental.msc3391_enabled and event == {}: | ||||
|             raise NotFoundError("Account data not found") | ||||
| 
 | ||||
|         return 200, event | ||||
| 
 | ||||
| 
 | ||||
| class UnstableAccountDataServlet(RestServlet): | ||||
|     """ | ||||
|     Contains an unstable endpoint for removing user account data, as specified by | ||||
|     MSC3391. If that MSC is accepted, this code should have unstable prefixes removed | ||||
|     and become incorporated into AccountDataServlet above. | ||||
|     """ | ||||
| 
 | ||||
|     PATTERNS = client_patterns( | ||||
|         "/org.matrix.msc3391/user/(?P<user_id>[^/]*)" | ||||
|         "/account_data/(?P<account_data_type>[^/]*)", | ||||
|         unstable=True, | ||||
|         releases=(), | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         super().__init__() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.handler = hs.get_account_data_handler() | ||||
| 
 | ||||
|     async def on_DELETE( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         user_id: str, | ||||
|         account_data_type: str, | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         if user_id != requester.user.to_string(): | ||||
|             raise AuthError(403, "Cannot delete account data for other users.") | ||||
| 
 | ||||
|         await self.handler.remove_account_data_for_user(user_id, account_data_type) | ||||
| 
 | ||||
|         return 200, {} | ||||
| 
 | ||||
| 
 | ||||
| class RoomAccountDataServlet(RestServlet): | ||||
|     """ | ||||
|     PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 | ||||
| @ -89,6 +139,7 @@ class RoomAccountDataServlet(RestServlet): | ||||
| 
 | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         super().__init__() | ||||
|         self._hs = hs | ||||
|         self.auth = hs.get_auth() | ||||
|         self.store = hs.get_datastores().main | ||||
|         self.handler = hs.get_account_data_handler() | ||||
| @ -121,6 +172,16 @@ class RoomAccountDataServlet(RestServlet): | ||||
|                 Codes.BAD_JSON, | ||||
|             ) | ||||
| 
 | ||||
|         # If experimental support for MSC3391 is enabled, then providing an empty dict | ||||
|         # as the value for an account data type should be functionally equivalent to | ||||
|         # calling the DELETE method on the same type. | ||||
|         if self._hs.config.experimental.msc3391_enabled: | ||||
|             if body == {}: | ||||
|                 await self.handler.remove_account_data_for_room( | ||||
|                     user_id, room_id, account_data_type | ||||
|                 ) | ||||
|                 return 200, {} | ||||
| 
 | ||||
|         await self.handler.add_account_data_to_room( | ||||
|             user_id, room_id, account_data_type, body | ||||
|         ) | ||||
| @ -152,9 +213,63 @@ class RoomAccountDataServlet(RestServlet): | ||||
|         if event is None: | ||||
|             raise NotFoundError("Room account data not found") | ||||
| 
 | ||||
|         # If experimental support for MSC3391 is enabled, then this endpoint should | ||||
|         # return a 404 if the content for an account data type is an empty dict. | ||||
|         if self._hs.config.experimental.msc3391_enabled and event == {}: | ||||
|             raise NotFoundError("Room account data not found") | ||||
| 
 | ||||
|         return 200, event | ||||
| 
 | ||||
| 
 | ||||
| class UnstableRoomAccountDataServlet(RestServlet): | ||||
|     """ | ||||
|     Contains an unstable endpoint for removing room account data, as specified by | ||||
|     MSC3391. If that MSC is accepted, this code should have unstable prefixes removed | ||||
|     and become incorporated into RoomAccountDataServlet above. | ||||
|     """ | ||||
| 
 | ||||
|     PATTERNS = client_patterns( | ||||
|         "/org.matrix.msc3391/user/(?P<user_id>[^/]*)" | ||||
|         "/rooms/(?P<room_id>[^/]*)" | ||||
|         "/account_data/(?P<account_data_type>[^/]*)", | ||||
|         unstable=True, | ||||
|         releases=(), | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         super().__init__() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.handler = hs.get_account_data_handler() | ||||
| 
 | ||||
|     async def on_DELETE( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         user_id: str, | ||||
|         room_id: str, | ||||
|         account_data_type: str, | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         if user_id != requester.user.to_string(): | ||||
|             raise AuthError(403, "Cannot delete account data for other users.") | ||||
| 
 | ||||
|         if not RoomID.is_valid(room_id): | ||||
|             raise SynapseError( | ||||
|                 400, | ||||
|                 f"{room_id} is not a valid room ID", | ||||
|                 Codes.INVALID_PARAM, | ||||
|             ) | ||||
| 
 | ||||
|         await self.handler.remove_account_data_for_room( | ||||
|             user_id, room_id, account_data_type | ||||
|         ) | ||||
| 
 | ||||
|         return 200, {} | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: | ||||
|     AccountDataServlet(hs).register(http_server) | ||||
|     RoomAccountDataServlet(hs).register(http_server) | ||||
| 
 | ||||
|     if hs.config.experimental.msc3391_enabled: | ||||
|         UnstableAccountDataServlet(hs).register(http_server) | ||||
|         UnstableRoomAccountDataServlet(hs).register(http_server) | ||||
|  | ||||
| @ -1762,7 +1762,8 @@ class DatabasePool: | ||||
|             desc: description of the transaction, for logging and metrics | ||||
| 
 | ||||
|         Returns: | ||||
|             A list of dictionaries. | ||||
|             A list of dictionaries, one per result row, each a mapping between the | ||||
|             column names from `retcols` and that column's value for the row. | ||||
|         """ | ||||
|         return await self.runInteraction( | ||||
|             desc, | ||||
| @ -1791,6 +1792,10 @@ class DatabasePool: | ||||
|                 column names and values to select the rows with, or None to not | ||||
|                 apply a WHERE clause. | ||||
|             retcols: the names of the columns to return | ||||
| 
 | ||||
|         Returns: | ||||
|             A list of dictionaries, one per result row, each a mapping between the | ||||
|             column names from `retcols` and that column's value for the row. | ||||
|         """ | ||||
|         if keyvalues: | ||||
|             sql = "SELECT %s FROM %s WHERE %s" % ( | ||||
| @ -1898,6 +1903,19 @@ class DatabasePool: | ||||
|         updatevalues: Dict[str, Any], | ||||
|         desc: str, | ||||
|     ) -> int: | ||||
|         """ | ||||
|         Update rows in the given database table. | ||||
|         If the given keyvalues don't match anything, nothing will be updated. | ||||
| 
 | ||||
|         Args: | ||||
|             table: The database table to update. | ||||
|             keyvalues: A mapping of column name to value to match rows on. | ||||
|             updatevalues: A mapping of column name to value to replace in any matched rows. | ||||
|             desc: description of the transaction, for logging and metrics. | ||||
| 
 | ||||
|         Returns: | ||||
|             The number of rows that were updated. Will be 0 if no matching rows were found. | ||||
|         """ | ||||
|         return await self.runInteraction( | ||||
|             desc, self.simple_update_txn, table, keyvalues, updatevalues | ||||
|         ) | ||||
| @ -1909,6 +1927,19 @@ class DatabasePool: | ||||
|         keyvalues: Dict[str, Any], | ||||
|         updatevalues: Dict[str, Any], | ||||
|     ) -> int: | ||||
|         """ | ||||
|         Update rows in the given database table. | ||||
|         If the given keyvalues don't match anything, nothing will be updated. | ||||
| 
 | ||||
|         Args: | ||||
|             txn: The database transaction object. | ||||
|             table: The database table to update. | ||||
|             keyvalues: A mapping of column name to value to match rows on. | ||||
|             updatevalues: A mapping of column name to value to replace in any matched rows. | ||||
| 
 | ||||
|         Returns: | ||||
|             The number of rows that were updated. Will be 0 if no matching rows were found. | ||||
|         """ | ||||
|         if keyvalues: | ||||
|             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) | ||||
|         else: | ||||
|  | ||||
| @ -123,7 +123,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | ||||
|     async def get_account_data_for_user( | ||||
|         self, user_id: str | ||||
|     ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: | ||||
|         """Get all the client account_data for a user. | ||||
|         """ | ||||
|         Get all the client account_data for a user. | ||||
| 
 | ||||
|         If experimental MSC3391 support is enabled, any entries with an empty | ||||
|         content body are excluded; as this means they have been deleted. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id: The user to get the account_data for. | ||||
| @ -135,27 +139,48 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | ||||
|         def get_account_data_for_user_txn( | ||||
|             txn: LoggingTransaction, | ||||
|         ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: | ||||
|             rows = self.db_pool.simple_select_list_txn( | ||||
|                 txn, | ||||
|                 "account_data", | ||||
|                 {"user_id": user_id}, | ||||
|                 ["account_data_type", "content"], | ||||
|             ) | ||||
|             # The 'content != '{}' condition below prevents us from using | ||||
|             # `simple_select_list_txn` here, as it doesn't support conditions | ||||
|             # other than 'equals'. | ||||
|             sql = """ | ||||
|                 SELECT account_data_type, content FROM account_data | ||||
|                 WHERE user_id = ? | ||||
|             """ | ||||
| 
 | ||||
|             # If experimental MSC3391 support is enabled, then account data entries | ||||
|             # with an empty content are considered "deleted". So skip adding them to | ||||
|             # the results. | ||||
|             if self.hs.config.experimental.msc3391_enabled: | ||||
|                 sql += " AND content != '{}'" | ||||
| 
 | ||||
|             txn.execute(sql, (user_id,)) | ||||
|             rows = self.db_pool.cursor_to_dict(txn) | ||||
| 
 | ||||
|             global_account_data = { | ||||
|                 row["account_data_type"]: db_to_json(row["content"]) for row in rows | ||||
|             } | ||||
| 
 | ||||
|             rows = self.db_pool.simple_select_list_txn( | ||||
|                 txn, | ||||
|                 "room_account_data", | ||||
|                 {"user_id": user_id}, | ||||
|                 ["room_id", "account_data_type", "content"], | ||||
|             ) | ||||
|             # The 'content != '{}' condition below prevents us from using | ||||
|             # `simple_select_list_txn` here, as it doesn't support conditions | ||||
|             # other than 'equals'. | ||||
|             sql = """ | ||||
|                 SELECT room_id, account_data_type, content FROM room_account_data | ||||
|                 WHERE user_id = ? | ||||
|             """ | ||||
| 
 | ||||
|             # If experimental MSC3391 support is enabled, then account data entries | ||||
|             # with an empty content are considered "deleted". So skip adding them to | ||||
|             # the results. | ||||
|             if self.hs.config.experimental.msc3391_enabled: | ||||
|                 sql += " AND content != '{}'" | ||||
| 
 | ||||
|             txn.execute(sql, (user_id,)) | ||||
|             rows = self.db_pool.cursor_to_dict(txn) | ||||
| 
 | ||||
|             by_room: Dict[str, Dict[str, JsonDict]] = {} | ||||
|             for row in rows: | ||||
|                 room_data = by_room.setdefault(row["room_id"], {}) | ||||
| 
 | ||||
|                 room_data[row["account_data_type"]] = db_to_json(row["content"]) | ||||
| 
 | ||||
|             return global_account_data, by_room | ||||
| @ -469,6 +494,72 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | ||||
| 
 | ||||
|         return self._account_data_id_gen.get_current_token() | ||||
| 
 | ||||
|     async def remove_account_data_for_room( | ||||
|         self, user_id: str, room_id: str, account_data_type: str | ||||
|     ) -> Optional[int]: | ||||
|         """Delete the room account data for the user of a given type. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id: The user to remove account_data for. | ||||
|             room_id: The room ID to scope the request to. | ||||
|             account_data_type: The account data type to delete. | ||||
| 
 | ||||
|         Returns: | ||||
|             The maximum stream position, or None if there was no matching room account | ||||
|             data to delete. | ||||
|         """ | ||||
|         assert self._can_write_to_account_data | ||||
|         assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) | ||||
| 
 | ||||
|         def _remove_account_data_for_room_txn( | ||||
|             txn: LoggingTransaction, next_id: int | ||||
|         ) -> bool: | ||||
|             """ | ||||
|             Args: | ||||
|                 txn: The transaction object. | ||||
|                 next_id: The stream_id to update any existing rows to. | ||||
| 
 | ||||
|             Returns: | ||||
|                 True if an entry in room_account_data had its content set to '{}', | ||||
|                 otherwise False. This informs callers of whether there actually was an | ||||
|                 existing room account data entry to delete, or if the call was a no-op. | ||||
|             """ | ||||
|             # We can't use `simple_update` as it doesn't have the ability to specify | ||||
|             # where clauses other than '=', which we need for `content != '{}'` below. | ||||
|             sql = """ | ||||
|                 UPDATE room_account_data | ||||
|                     SET stream_id = ?, content = '{}' | ||||
|                 WHERE user_id = ? | ||||
|                     AND room_id = ? | ||||
|                     AND account_data_type = ? | ||||
|                     AND content != '{}' | ||||
|             """ | ||||
|             txn.execute( | ||||
|                 sql, | ||||
|                 (next_id, user_id, room_id, account_data_type), | ||||
|             ) | ||||
|             # Return true if any rows were updated. | ||||
|             return txn.rowcount != 0 | ||||
| 
 | ||||
|         async with self._account_data_id_gen.get_next() as next_id: | ||||
|             row_updated = await self.db_pool.runInteraction( | ||||
|                 "remove_account_data_for_room", | ||||
|                 _remove_account_data_for_room_txn, | ||||
|                 next_id, | ||||
|             ) | ||||
| 
 | ||||
|             if not row_updated: | ||||
|                 return None | ||||
| 
 | ||||
|             self._account_data_stream_cache.entity_has_changed(user_id, next_id) | ||||
|             self.get_account_data_for_user.invalidate((user_id,)) | ||||
|             self.get_account_data_for_room.invalidate((user_id, room_id)) | ||||
|             self.get_account_data_for_room_and_type.prefill( | ||||
|                 (user_id, room_id, account_data_type), {} | ||||
|             ) | ||||
| 
 | ||||
|         return self._account_data_id_gen.get_current_token() | ||||
| 
 | ||||
|     async def add_account_data_for_user( | ||||
|         self, user_id: str, account_data_type: str, content: JsonDict | ||||
|     ) -> int: | ||||
| @ -569,6 +660,108 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | ||||
|             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) | ||||
|         self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) | ||||
| 
 | ||||
|     async def remove_account_data_for_user( | ||||
|         self, | ||||
|         user_id: str, | ||||
|         account_data_type: str, | ||||
|     ) -> Optional[int]: | ||||
|         """ | ||||
|         Delete a single piece of user account data by type. | ||||
| 
 | ||||
|         A "delete" is performed by updating a potentially existing row in the | ||||
|         "account_data" database table for (user_id, account_data_type) and | ||||
|         setting its content to "{}". | ||||
| 
 | ||||
|         Args: | ||||
|             user_id: The user ID to modify the account data of. | ||||
|             account_data_type: The type to remove. | ||||
| 
 | ||||
|         Returns: | ||||
|             The maximum stream position, or None if there was no matching account data | ||||
|             to delete. | ||||
|         """ | ||||
|         assert self._can_write_to_account_data | ||||
|         assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) | ||||
| 
 | ||||
|         def _remove_account_data_for_user_txn( | ||||
|             txn: LoggingTransaction, next_id: int | ||||
|         ) -> bool: | ||||
|             """ | ||||
|             Args: | ||||
|                 txn: The transaction object. | ||||
|                 next_id: The stream_id to update any existing rows to. | ||||
| 
 | ||||
|             Returns: | ||||
|                 True if an entry in account_data had its content set to '{}', otherwise | ||||
|                 False. This informs callers of whether there actually was an existing | ||||
|                 account data entry to delete, or if the call was a no-op. | ||||
|             """ | ||||
|             # We can't use `simple_update` as it doesn't have the ability to specify | ||||
|             # where clauses other than '=', which we need for `content != '{}'` below. | ||||
|             sql = """ | ||||
|                 UPDATE account_data | ||||
|                     SET stream_id = ?, content = '{}' | ||||
|                 WHERE user_id = ? | ||||
|                     AND account_data_type = ? | ||||
|                     AND content != '{}' | ||||
|             """ | ||||
|             txn.execute(sql, (next_id, user_id, account_data_type)) | ||||
|             if txn.rowcount == 0: | ||||
|                 # We didn't update any rows. This means that there was no matching room | ||||
|                 # account data entry to delete in the first place. | ||||
|                 return False | ||||
| 
 | ||||
|             # Ignored users get denormalized into a separate table as an optimisation. | ||||
|             if account_data_type == AccountDataTypes.IGNORED_USER_LIST: | ||||
|                 # If this method was called with the ignored users account data type, we | ||||
|                 # simply delete all ignored users. | ||||
| 
 | ||||
|                 # First pull all the users that this user ignores. | ||||
|                 previously_ignored_users = set( | ||||
|                     self.db_pool.simple_select_onecol_txn( | ||||
|                         txn, | ||||
|                         table="ignored_users", | ||||
|                         keyvalues={"ignorer_user_id": user_id}, | ||||
|                         retcol="ignored_user_id", | ||||
|                     ) | ||||
|                 ) | ||||
| 
 | ||||
|                 # Then delete them from the database. | ||||
|                 self.db_pool.simple_delete_txn( | ||||
|                     txn, | ||||
|                     table="ignored_users", | ||||
|                     keyvalues={"ignorer_user_id": user_id}, | ||||
|                 ) | ||||
| 
 | ||||
|                 # Invalidate the cache for ignored users which were removed. | ||||
|                 for ignored_user_id in previously_ignored_users: | ||||
|                     self._invalidate_cache_and_stream( | ||||
|                         txn, self.ignored_by, (ignored_user_id,) | ||||
|                     ) | ||||
| 
 | ||||
|                 # Invalidate for this user the cache tracking ignored users. | ||||
|                 self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) | ||||
| 
 | ||||
|             return True | ||||
| 
 | ||||
|         async with self._account_data_id_gen.get_next() as next_id: | ||||
|             row_updated = await self.db_pool.runInteraction( | ||||
|                 "remove_account_data_for_user", | ||||
|                 _remove_account_data_for_user_txn, | ||||
|                 next_id, | ||||
|             ) | ||||
| 
 | ||||
|             if not row_updated: | ||||
|                 return None | ||||
| 
 | ||||
|             self._account_data_stream_cache.entity_has_changed(user_id, next_id) | ||||
|             self.get_account_data_for_user.invalidate((user_id,)) | ||||
|             self.get_global_account_data_by_type_for_user.prefill( | ||||
|                 (user_id, account_data_type), {} | ||||
|             ) | ||||
| 
 | ||||
|         return self._account_data_id_gen.get_current_token() | ||||
| 
 | ||||
|     async def purge_account_data_for_user(self, user_id: str) -> None: | ||||
|         """ | ||||
|         Removes ALL the account data for a user. | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user