mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-31 04:08:21 +00:00 
			
		
		
		
	Delete device messages asynchronously and in staged batches (#16240)
This commit is contained in:
		
							parent
							
								
									1e571cd664
								
							
						
					
					
						commit
						4f1840a88a
					
				
							
								
								
									
										1
									
								
								changelog.d/16240.misc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/16240.misc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | |||||||
|  | Delete device messages asynchronously and in staged batches using the task scheduler. | ||||||
| @ -43,9 +43,12 @@ from synapse.metrics.background_process_metrics import ( | |||||||
| ) | ) | ||||||
| from synapse.types import ( | from synapse.types import ( | ||||||
|     JsonDict, |     JsonDict, | ||||||
|  |     JsonMapping, | ||||||
|  |     ScheduledTask, | ||||||
|     StrCollection, |     StrCollection, | ||||||
|     StreamKeyType, |     StreamKeyType, | ||||||
|     StreamToken, |     StreamToken, | ||||||
|  |     TaskStatus, | ||||||
|     UserID, |     UserID, | ||||||
|     get_domain_from_id, |     get_domain_from_id, | ||||||
|     get_verify_key_from_cross_signing_key, |     get_verify_key_from_cross_signing_key, | ||||||
| @ -62,6 +65,7 @@ if TYPE_CHECKING: | |||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
|  | DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages" | ||||||
| MAX_DEVICE_DISPLAY_NAME_LEN = 100 | MAX_DEVICE_DISPLAY_NAME_LEN = 100 | ||||||
| DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 | DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 | ||||||
| 
 | 
 | ||||||
| @ -78,6 +82,7 @@ class DeviceWorkerHandler: | |||||||
|         self._appservice_handler = hs.get_application_service_handler() |         self._appservice_handler = hs.get_application_service_handler() | ||||||
|         self._state_storage = hs.get_storage_controllers().state |         self._state_storage = hs.get_storage_controllers().state | ||||||
|         self._auth_handler = hs.get_auth_handler() |         self._auth_handler = hs.get_auth_handler() | ||||||
|  |         self._event_sources = hs.get_event_sources() | ||||||
|         self.server_name = hs.hostname |         self.server_name = hs.hostname | ||||||
|         self._msc3852_enabled = hs.config.experimental.msc3852_enabled |         self._msc3852_enabled = hs.config.experimental.msc3852_enabled | ||||||
|         self._query_appservices_for_keys = ( |         self._query_appservices_for_keys = ( | ||||||
| @ -386,6 +391,7 @@ class DeviceHandler(DeviceWorkerHandler): | |||||||
|         self._account_data_handler = hs.get_account_data_handler() |         self._account_data_handler = hs.get_account_data_handler() | ||||||
|         self._storage_controllers = hs.get_storage_controllers() |         self._storage_controllers = hs.get_storage_controllers() | ||||||
|         self.db_pool = hs.get_datastores().main.db_pool |         self.db_pool = hs.get_datastores().main.db_pool | ||||||
|  |         self._task_scheduler = hs.get_task_scheduler() | ||||||
| 
 | 
 | ||||||
|         self.device_list_updater = DeviceListUpdater(hs, self) |         self.device_list_updater = DeviceListUpdater(hs, self) | ||||||
| 
 | 
 | ||||||
| @ -419,6 +425,10 @@ class DeviceHandler(DeviceWorkerHandler): | |||||||
|                 self._delete_stale_devices, |                 self._delete_stale_devices, | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  |         self._task_scheduler.register_action( | ||||||
|  |             self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|     def _check_device_name_length(self, name: Optional[str]) -> None: |     def _check_device_name_length(self, name: Optional[str]) -> None: | ||||||
|         """ |         """ | ||||||
|         Checks whether a device name is longer than the maximum allowed length. |         Checks whether a device name is longer than the maximum allowed length. | ||||||
| @ -530,6 +540,7 @@ class DeviceHandler(DeviceWorkerHandler): | |||||||
|             user_id: The user to delete devices from. |             user_id: The user to delete devices from. | ||||||
|             device_ids: The list of device IDs to delete |             device_ids: The list of device IDs to delete | ||||||
|         """ |         """ | ||||||
|  |         to_device_stream_id = self._event_sources.get_current_token().to_device_key | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             await self.store.delete_devices(user_id, device_ids) |             await self.store.delete_devices(user_id, device_ids) | ||||||
| @ -559,12 +570,49 @@ class DeviceHandler(DeviceWorkerHandler): | |||||||
|                     f"org.matrix.msc3890.local_notification_settings.{device_id}", |                     f"org.matrix.msc3890.local_notification_settings.{device_id}", | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|  |             # Delete device messages asynchronously and in batches using the task scheduler | ||||||
|  |             await self._task_scheduler.schedule_task( | ||||||
|  |                 DELETE_DEVICE_MSGS_TASK_NAME, | ||||||
|  |                 resource_id=device_id, | ||||||
|  |                 params={ | ||||||
|  |                     "user_id": user_id, | ||||||
|  |                     "device_id": device_id, | ||||||
|  |                     "up_to_stream_id": to_device_stream_id, | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|         # Pushers are deleted after `delete_access_tokens_for_user` is called so that |         # Pushers are deleted after `delete_access_tokens_for_user` is called so that | ||||||
|         # modules using `on_logged_out` hook can use them if needed. |         # modules using `on_logged_out` hook can use them if needed. | ||||||
|         await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids) |         await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids) | ||||||
| 
 | 
 | ||||||
|         await self.notify_device_update(user_id, device_ids) |         await self.notify_device_update(user_id, device_ids) | ||||||
| 
 | 
 | ||||||
|  |     DEVICE_MSGS_DELETE_BATCH_LIMIT = 100 | ||||||
|  | 
 | ||||||
|  |     async def _delete_device_messages( | ||||||
|  |         self, | ||||||
|  |         task: ScheduledTask, | ||||||
|  |     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: | ||||||
|  |         """Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`.""" | ||||||
|  |         assert task.params is not None | ||||||
|  |         user_id = task.params["user_id"] | ||||||
|  |         device_id = task.params["device_id"] | ||||||
|  |         up_to_stream_id = task.params["up_to_stream_id"] | ||||||
|  | 
 | ||||||
|  |         res = await self.store.delete_messages_for_device( | ||||||
|  |             user_id=user_id, | ||||||
|  |             device_id=device_id, | ||||||
|  |             up_to_stream_id=up_to_stream_id, | ||||||
|  |             limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT: | ||||||
|  |             return TaskStatus.COMPLETE, None, None | ||||||
|  |         else: | ||||||
|  |             # There is probably still device messages to be deleted, let's keep the task active and it will be run | ||||||
|  |             # again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running). | ||||||
|  |             return TaskStatus.ACTIVE, None, None | ||||||
|  | 
 | ||||||
|     async def update_device(self, user_id: str, device_id: str, content: dict) -> None: |     async def update_device(self, user_id: str, device_id: str, content: dict) -> None: | ||||||
|         """Update the given device |         """Update the given device | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -183,6 +183,7 @@ class BasePresenceHandler(abc.ABC): | |||||||
|     writer""" |     writer""" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, hs: "HomeServer"): |     def __init__(self, hs: "HomeServer"): | ||||||
|  |         self.hs = hs | ||||||
|         self.clock = hs.get_clock() |         self.clock = hs.get_clock() | ||||||
|         self.store = hs.get_datastores().main |         self.store = hs.get_datastores().main | ||||||
|         self._storage_controllers = hs.get_storage_controllers() |         self._storage_controllers = hs.get_storage_controllers() | ||||||
| @ -473,8 +474,6 @@ class _NullContextManager(ContextManager[None]): | |||||||
| class WorkerPresenceHandler(BasePresenceHandler): | class WorkerPresenceHandler(BasePresenceHandler): | ||||||
|     def __init__(self, hs: "HomeServer"): |     def __init__(self, hs: "HomeServer"): | ||||||
|         super().__init__(hs) |         super().__init__(hs) | ||||||
|         self.hs = hs |  | ||||||
| 
 |  | ||||||
|         self._presence_writer_instance = hs.config.worker.writers.presence[0] |         self._presence_writer_instance = hs.config.worker.writers.presence[0] | ||||||
| 
 | 
 | ||||||
|         # Route presence EDUs to the right worker |         # Route presence EDUs to the right worker | ||||||
| @ -738,7 +737,6 @@ class WorkerPresenceHandler(BasePresenceHandler): | |||||||
| class PresenceHandler(BasePresenceHandler): | class PresenceHandler(BasePresenceHandler): | ||||||
|     def __init__(self, hs: "HomeServer"): |     def __init__(self, hs: "HomeServer"): | ||||||
|         super().__init__(hs) |         super().__init__(hs) | ||||||
|         self.hs = hs |  | ||||||
|         self.wheel_timer: WheelTimer[str] = WheelTimer() |         self.wheel_timer: WheelTimer[str] = WheelTimer() | ||||||
|         self.notifier = hs.get_notifier() |         self.notifier = hs.get_notifier() | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -40,6 +40,7 @@ from synapse.api.filtering import FilterCollection | |||||||
| from synapse.api.presence import UserPresenceState | from synapse.api.presence import UserPresenceState | ||||||
| from synapse.api.room_versions import KNOWN_ROOM_VERSIONS | from synapse.api.room_versions import KNOWN_ROOM_VERSIONS | ||||||
| from synapse.events import EventBase | from synapse.events import EventBase | ||||||
|  | from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME | ||||||
| from synapse.handlers.relations import BundledAggregations | from synapse.handlers.relations import BundledAggregations | ||||||
| from synapse.logging import issue9533_logger | from synapse.logging import issue9533_logger | ||||||
| from synapse.logging.context import current_context | from synapse.logging.context import current_context | ||||||
| @ -268,6 +269,7 @@ class SyncHandler: | |||||||
|         self._storage_controllers = hs.get_storage_controllers() |         self._storage_controllers = hs.get_storage_controllers() | ||||||
|         self._state_storage_controller = self._storage_controllers.state |         self._state_storage_controller = self._storage_controllers.state | ||||||
|         self._device_handler = hs.get_device_handler() |         self._device_handler = hs.get_device_handler() | ||||||
|  |         self._task_scheduler = hs.get_task_scheduler() | ||||||
| 
 | 
 | ||||||
|         self.should_calculate_push_rules = hs.config.push.enable_push |         self.should_calculate_push_rules = hs.config.push.enable_push | ||||||
| 
 | 
 | ||||||
| @ -360,11 +362,19 @@ class SyncHandler: | |||||||
|         # (since we now know that the device has received them) |         # (since we now know that the device has received them) | ||||||
|         if since_token is not None: |         if since_token is not None: | ||||||
|             since_stream_id = since_token.to_device_key |             since_stream_id = since_token.to_device_key | ||||||
|             deleted = await self.store.delete_messages_for_device( |             # Delete device messages asynchronously and in batches using the task scheduler | ||||||
|                 sync_config.user.to_string(), sync_config.device_id, since_stream_id |             await self._task_scheduler.schedule_task( | ||||||
|  |                 DELETE_DEVICE_MSGS_TASK_NAME, | ||||||
|  |                 resource_id=sync_config.device_id, | ||||||
|  |                 params={ | ||||||
|  |                     "user_id": sync_config.user.to_string(), | ||||||
|  |                     "device_id": sync_config.device_id, | ||||||
|  |                     "up_to_stream_id": since_stream_id, | ||||||
|  |                 }, | ||||||
|             ) |             ) | ||||||
|             logger.debug( |             logger.debug( | ||||||
|                 "Deleted %d to-device messages up to %d", deleted, since_stream_id |                 "Deletion of to-device messages up to %d scheduled", | ||||||
|  |                 since_stream_id, | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         if timeout == 0 or since_token is None or full_state: |         if timeout == 0 or since_token is None or full_state: | ||||||
|  | |||||||
| @ -445,13 +445,18 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|     @trace |     @trace | ||||||
|     async def delete_messages_for_device( |     async def delete_messages_for_device( | ||||||
|         self, user_id: str, device_id: Optional[str], up_to_stream_id: int |         self, | ||||||
|  |         user_id: str, | ||||||
|  |         device_id: Optional[str], | ||||||
|  |         up_to_stream_id: int, | ||||||
|  |         limit: int, | ||||||
|     ) -> int: |     ) -> int: | ||||||
|         """ |         """ | ||||||
|         Args: |         Args: | ||||||
|             user_id: The recipient user_id. |             user_id: The recipient user_id. | ||||||
|             device_id: The recipient device_id. |             device_id: The recipient device_id. | ||||||
|             up_to_stream_id: Where to delete messages up to. |             up_to_stream_id: Where to delete messages up to. | ||||||
|  |             limit: maximum number of messages to delete | ||||||
| 
 | 
 | ||||||
|         Returns: |         Returns: | ||||||
|             The number of messages deleted. |             The number of messages deleted. | ||||||
| @ -472,12 +477,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||||||
|                 log_kv({"message": "No changes in cache since last check"}) |                 log_kv({"message": "No changes in cache since last check"}) | ||||||
|                 return 0 |                 return 0 | ||||||
| 
 | 
 | ||||||
|  |         ROW_ID_NAME = self.database_engine.row_id_name | ||||||
|  | 
 | ||||||
|         def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: |         def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: | ||||||
|             sql = ( |             sql = f""" | ||||||
|                 "DELETE FROM device_inbox" |                 DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN ( | ||||||
|                 " WHERE user_id = ? AND device_id = ?" |                   SELECT {ROW_ID_NAME} FROM device_inbox | ||||||
|                 " AND stream_id <= ?" |                   WHERE user_id = ? AND device_id = ? AND stream_id <= ? | ||||||
|             ) |                   LIMIT {limit} | ||||||
|  |                 ) | ||||||
|  |                 """ | ||||||
|             txn.execute(sql, (user_id, device_id, up_to_stream_id)) |             txn.execute(sql, (user_id, device_id, up_to_stream_id)) | ||||||
|             return txn.rowcount |             return txn.rowcount | ||||||
| 
 | 
 | ||||||
| @ -487,6 +496,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         log_kv({"message": f"deleted {count} messages for device", "count": count}) |         log_kv({"message": f"deleted {count} messages for device", "count": count}) | ||||||
| 
 | 
 | ||||||
|  |         # In this case we don't know if we hit the limit or the delete is complete | ||||||
|  |         # so let's not update the cache. | ||||||
|  |         if count == limit: | ||||||
|  |             return count | ||||||
|  | 
 | ||||||
|         # Update the cache, ensuring that we only ever increase the value |         # Update the cache, ensuring that we only ever increase the value | ||||||
|         updated_last_deleted_stream_id = self._last_device_delete_cache.get( |         updated_last_deleted_stream_id = self._last_device_delete_cache.get( | ||||||
|             (user_id, device_id), 0 |             (user_id, device_id), 0 | ||||||
|  | |||||||
| @ -1766,14 +1766,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||||||
|                 keyvalues={"user_id": user_id, "hidden": False}, |                 keyvalues={"user_id": user_id, "hidden": False}, | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             self.db_pool.simple_delete_many_txn( |  | ||||||
|                 txn, |  | ||||||
|                 table="device_inbox", |  | ||||||
|                 column="device_id", |  | ||||||
|                 values=device_ids, |  | ||||||
|                 keyvalues={"user_id": user_id}, |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             self.db_pool.simple_delete_many_txn( |             self.db_pool.simple_delete_many_txn( | ||||||
|                 txn, |                 txn, | ||||||
|                 table="device_auth_providers", |                 table="device_auth_providers", | ||||||
|  | |||||||
| @ -939,11 +939,7 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): | |||||||
|         receipts.""" |         receipts.""" | ||||||
| 
 | 
 | ||||||
|         def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: |         def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: | ||||||
|             if isinstance(self.database_engine, PostgresEngine): |             ROW_ID_NAME = self.database_engine.row_id_name | ||||||
|                 ROW_ID_NAME = "ctid" |  | ||||||
|             else: |  | ||||||
|                 ROW_ID_NAME = "rowid" |  | ||||||
| 
 |  | ||||||
|             # Identify any duplicate receipts arising from |             # Identify any duplicate receipts arising from | ||||||
|             # https://github.com/matrix-org/synapse/issues/14406. |             # https://github.com/matrix-org/synapse/issues/14406. | ||||||
|             # The following query takes less than a minute on matrix.org. |             # The following query takes less than a minute on matrix.org. | ||||||
|  | |||||||
| @ -100,6 +100,12 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM | |||||||
|         """Gets a string giving the server version. For example: '3.22.0'""" |         """Gets a string giving the server version. For example: '3.22.0'""" | ||||||
|         ... |         ... | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     @abc.abstractmethod | ||||||
|  |     def row_id_name(self) -> str: | ||||||
|  |         """Gets the literal name representing a row id for this engine.""" | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def in_transaction(self, conn: ConnectionType) -> bool: |     def in_transaction(self, conn: ConnectionType) -> bool: | ||||||
|         """Whether the connection is currently in a transaction.""" |         """Whether the connection is currently in a transaction.""" | ||||||
|  | |||||||
| @ -211,6 +211,10 @@ class PostgresEngine( | |||||||
|         else: |         else: | ||||||
|             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) |             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def row_id_name(self) -> str: | ||||||
|  |         return "ctid" | ||||||
|  | 
 | ||||||
|     def in_transaction(self, conn: psycopg2.extensions.connection) -> bool: |     def in_transaction(self, conn: psycopg2.extensions.connection) -> bool: | ||||||
|         return conn.status != psycopg2.extensions.STATUS_READY |         return conn.status != psycopg2.extensions.STATUS_READY | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -123,6 +123,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): | |||||||
|         """Gets a string giving the server version. For example: '3.22.0'.""" |         """Gets a string giving the server version. For example: '3.22.0'.""" | ||||||
|         return "%i.%i.%i" % sqlite3.sqlite_version_info |         return "%i.%i.%i" % sqlite3.sqlite_version_info | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def row_id_name(self) -> str: | ||||||
|  |         return "rowid" | ||||||
|  | 
 | ||||||
|     def in_transaction(self, conn: sqlite3.Connection) -> bool: |     def in_transaction(self, conn: sqlite3.Connection) -> bool: | ||||||
|         return conn.in_transaction |         return conn.in_transaction | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -14,7 +14,7 @@ | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| from synapse.storage.database import LoggingTransaction | from synapse.storage.database import LoggingTransaction | ||||||
| from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine | from synapse.storage.engines import BaseDatabaseEngine | ||||||
| from synapse.storage.prepare_database import get_statements | from synapse.storage.prepare_database import get_statements | ||||||
| 
 | 
 | ||||||
| FIX_INDEXES = """ | FIX_INDEXES = """ | ||||||
| @ -37,7 +37,7 @@ CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: | def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: | ||||||
|     rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" |     rowid = database_engine.row_id_name | ||||||
| 
 | 
 | ||||||
|     # remove duplicates from group_users & group_invites tables |     # remove duplicates from group_users & group_invites tables | ||||||
|     cur.execute( |     cur.execute( | ||||||
|  | |||||||
| @ -77,6 +77,7 @@ class TaskScheduler: | |||||||
|     LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000  # 24hrs |     LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000  # 24hrs | ||||||
| 
 | 
 | ||||||
|     def __init__(self, hs: "HomeServer"): |     def __init__(self, hs: "HomeServer"): | ||||||
|  |         self._hs = hs | ||||||
|         self._store = hs.get_datastores().main |         self._store = hs.get_datastores().main | ||||||
|         self._clock = hs.get_clock() |         self._clock = hs.get_clock() | ||||||
|         self._running_tasks: Set[str] = set() |         self._running_tasks: Set[str] = set() | ||||||
| @ -97,8 +98,6 @@ class TaskScheduler: | |||||||
|                 "handle_scheduled_tasks", |                 "handle_scheduled_tasks", | ||||||
|                 self._handle_scheduled_tasks, |                 self._handle_scheduled_tasks, | ||||||
|             ) |             ) | ||||||
|         else: |  | ||||||
|             self.replication_client = hs.get_replication_command_handler() |  | ||||||
| 
 | 
 | ||||||
|     def register_action( |     def register_action( | ||||||
|         self, |         self, | ||||||
| @ -133,7 +132,7 @@ class TaskScheduler: | |||||||
|         params: Optional[JsonMapping] = None, |         params: Optional[JsonMapping] = None, | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         """Schedule a new potentially resumable task. A function matching the specified |         """Schedule a new potentially resumable task. A function matching the specified | ||||||
|         `action` should have been previously registered with `register_action`. |         `action` should have be registered with `register_action` before the task is run. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|             action: the name of a previously registered action |             action: the name of a previously registered action | ||||||
| @ -149,11 +148,6 @@ class TaskScheduler: | |||||||
|         Returns: |         Returns: | ||||||
|             The id of the scheduled task |             The id of the scheduled task | ||||||
|         """ |         """ | ||||||
|         if action not in self._actions: |  | ||||||
|             raise Exception( |  | ||||||
|                 f"No function associated with action {action} of the scheduled task" |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         status = TaskStatus.SCHEDULED |         status = TaskStatus.SCHEDULED | ||||||
|         if timestamp is None or timestamp < self._clock.time_msec(): |         if timestamp is None or timestamp < self._clock.time_msec(): | ||||||
|             timestamp = self._clock.time_msec() |             timestamp = self._clock.time_msec() | ||||||
| @ -175,7 +169,7 @@ class TaskScheduler: | |||||||
|             if self._run_background_tasks: |             if self._run_background_tasks: | ||||||
|                 await self._launch_task(task) |                 await self._launch_task(task) | ||||||
|             else: |             else: | ||||||
|                 self.replication_client.send_new_active_task(task.id) |                 self._hs.get_replication_command_handler().send_new_active_task(task.id) | ||||||
| 
 | 
 | ||||||
|         return task.id |         return task.id | ||||||
| 
 | 
 | ||||||
| @ -315,7 +309,10 @@ class TaskScheduler: | |||||||
|         """ |         """ | ||||||
|         assert self._run_background_tasks |         assert self._run_background_tasks | ||||||
| 
 | 
 | ||||||
|         assert task.action in self._actions |         if task.action not in self._actions: | ||||||
|  |             raise Exception( | ||||||
|  |                 f"No function associated with action {task.action} of the scheduled task {task.id}" | ||||||
|  |             ) | ||||||
|         function = self._actions[task.action] |         function = self._actions[task.action] | ||||||
| 
 | 
 | ||||||
|         async def wrapper() -> None: |         async def wrapper() -> None: | ||||||
|  | |||||||
| @ -30,6 +30,7 @@ from synapse.server import HomeServer | |||||||
| from synapse.storage.databases.main.appservice import _make_exclusive_regex | from synapse.storage.databases.main.appservice import _make_exclusive_regex | ||||||
| from synapse.types import JsonDict, create_requester | from synapse.types import JsonDict, create_requester | ||||||
| from synapse.util import Clock | from synapse.util import Clock | ||||||
|  | from synapse.util.task_scheduler import TaskScheduler | ||||||
| 
 | 
 | ||||||
| from tests import unittest | from tests import unittest | ||||||
| from tests.unittest import override_config | from tests.unittest import override_config | ||||||
| @ -49,6 +50,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): | |||||||
|         assert isinstance(handler, DeviceHandler) |         assert isinstance(handler, DeviceHandler) | ||||||
|         self.handler = handler |         self.handler = handler | ||||||
|         self.store = hs.get_datastores().main |         self.store = hs.get_datastores().main | ||||||
|  |         self.device_message_handler = hs.get_device_message_handler() | ||||||
|         return hs |         return hs | ||||||
| 
 | 
 | ||||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: |     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||||
| @ -211,6 +213,51 @@ class DeviceTestCase(unittest.HomeserverTestCase): | |||||||
|         ) |         ) | ||||||
|         self.assertIsNone(res) |         self.assertIsNone(res) | ||||||
| 
 | 
 | ||||||
|  |     def test_delete_device_and_big_device_inbox(self) -> None: | ||||||
|  |         """Check that deleting a big device inbox is staged and batched asynchronously.""" | ||||||
|  |         DEVICE_ID = "abc" | ||||||
|  |         sender = "@sender:" + self.hs.hostname | ||||||
|  |         receiver = "@receiver:" + self.hs.hostname | ||||||
|  |         self._record_user(sender, DEVICE_ID, DEVICE_ID) | ||||||
|  |         self._record_user(receiver, DEVICE_ID, DEVICE_ID) | ||||||
|  | 
 | ||||||
|  |         # queue a bunch of messages in the inbox | ||||||
|  |         requester = create_requester(sender, device_id=DEVICE_ID) | ||||||
|  |         for i in range(0, DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10): | ||||||
|  |             self.get_success( | ||||||
|  |                 self.device_message_handler.send_device_message( | ||||||
|  |                     requester, "message_type", {receiver: {"*": {"val": i}}} | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         # delete the device | ||||||
|  |         self.get_success(self.handler.delete_devices(receiver, [DEVICE_ID])) | ||||||
|  | 
 | ||||||
|  |         # messages should be deleted up to DEVICE_MSGS_DELETE_BATCH_LIMIT straight away | ||||||
|  |         res = self.get_success( | ||||||
|  |             self.store.db_pool.simple_select_list( | ||||||
|  |                 table="device_inbox", | ||||||
|  |                 keyvalues={"user_id": receiver}, | ||||||
|  |                 retcols=("user_id", "device_id", "stream_id"), | ||||||
|  |                 desc="get_device_id_from_device_inbox", | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(10, len(res)) | ||||||
|  | 
 | ||||||
|  |         # wait for the task scheduler to do a second delete pass | ||||||
|  |         self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000) | ||||||
|  | 
 | ||||||
|  |         # remaining messages should now be deleted | ||||||
|  |         res = self.get_success( | ||||||
|  |             self.store.db_pool.simple_select_list( | ||||||
|  |                 table="device_inbox", | ||||||
|  |                 keyvalues={"user_id": receiver}, | ||||||
|  |                 retcols=("user_id", "device_id", "stream_id"), | ||||||
|  |                 desc="get_device_id_from_device_inbox", | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(0, len(res)) | ||||||
|  | 
 | ||||||
|     def test_update_device(self) -> None: |     def test_update_device(self) -> None: | ||||||
|         self._record_users() |         self._record_users() | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user