Add ability to wait for locks and add locks to purge history / room deletion (#15791)

c.f. #13476
Erik Johnston 2023-07-31 10:58:03 +01:00 committed by GitHub
parent 0c6142c4a1
commit ae55cc1e6b
16 changed files with 784 additions and 109 deletions

changelog.d/15791.bugfix Normal file

@@ -0,0 +1 @@
Fix bug where purging history and paginating simultaneously could lead to database corruption when using workers.
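For orientation, a minimal sketch of the locking pattern this commit introduces (constant and handler names are taken from the diff below; `hs.get_worker_locks_handler()` is added to `HomeServer` in this commit, everything else here is illustrative only): readers of a room take the per-room lock with `write=False`, while purging or deleting the room takes it with `write=True`, so the two can no longer interleave across workers.

    # Illustrative sketch only, not part of the commit; `hs` is a HomeServer.
    PURGE_HISTORY_LOCK_NAME = "purge_history_lock"  # constant added in this commit

    async def paginate_messages(hs, room_id: str) -> None:
        # Many readers may hold the lock at once, across workers.
        async with hs.get_worker_locks_handler().acquire_read_write_lock(
            PURGE_HISTORY_LOCK_NAME, room_id, write=False
        ):
            ...  # safe to paginate: no purge can run concurrently

    async def purge_room_history(hs, room_id: str) -> None:
        # The purge waits for all readers (on any worker) and blocks new ones.
        async with hs.get_worker_locks_handler().acquire_read_write_lock(
            PURGE_HISTORY_LOCK_NAME, room_id, write=True
        ):
            ...  # safe to purge: no paginator holds the lock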


@@ -63,6 +63,7 @@ from synapse.federation.federation_base import (
 )
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
     make_deferred_yieldable,
@@ -137,6 +138,7 @@ class FederationServer(FederationBase):
         self._event_auth_handler = hs.get_event_auth_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._e2e_keys_handler = hs.get_e2e_keys_handler()
+        self._worker_lock_handler = hs.get_worker_locks_handler()

         self._state_storage_controller = hs.get_storage_controllers().state
@@ -1236,9 +1238,18 @@ class FederationServer(FederationBase):
         logger.info("handling received PDU in room %s: %s", room_id, event)
         try:
             with nested_logging_context(event.event_id):
-                await self._federation_event_handler.on_receive_pdu(
-                    origin, event
-                )
+                # We're taking out a lock within a lock, which could
+                # lead to deadlocks if we're not careful. However, it is
+                # safe on this occasion as we only ever take a write
+                # lock when deleting a room, which we would never do
+                # while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
+                # lock.
+                async with self._worker_lock_handler.acquire_read_write_lock(
+                    DELETE_ROOM_LOCK_NAME, room_id, write=False
+                ):
+                    await self._federation_event_handler.on_receive_pdu(
+                        origin, event
+                    )
         except FederationError as e:
             # XXX: Ideally we'd inform the remote we failed to process
             # the event, but we can't return an error in the transaction


@@ -53,6 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -485,6 +486,7 @@ class EventCreationHandler:
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
         self._notifier = hs.get_notifier()
+        self._worker_lock_handler = hs.get_worker_locks_handler()

         self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
@@ -1010,6 +1012,37 @@ class EventCreationHandler:
             event.internal_metadata.stream_ordering,
         )

+        async with self._worker_lock_handler.acquire_read_write_lock(
+            DELETE_ROOM_LOCK_NAME, room_id, write=False
+        ):
+            return await self._create_and_send_nonmember_event_locked(
+                requester=requester,
+                event_dict=event_dict,
+                allow_no_prev_events=allow_no_prev_events,
+                prev_event_ids=prev_event_ids,
+                state_event_ids=state_event_ids,
+                ratelimit=ratelimit,
+                txn_id=txn_id,
+                ignore_shadow_ban=ignore_shadow_ban,
+                outlier=outlier,
+                depth=depth,
+            )
+
+    async def _create_and_send_nonmember_event_locked(
+        self,
+        requester: Requester,
+        event_dict: dict,
+        allow_no_prev_events: bool = False,
+        prev_event_ids: Optional[List[str]] = None,
+        state_event_ids: Optional[List[str]] = None,
+        ratelimit: bool = True,
+        txn_id: Optional[str] = None,
+        ignore_shadow_ban: bool = False,
+        outlier: bool = False,
+        depth: Optional[int] = None,
+    ) -> Tuple[EventBase, int]:
+        room_id = event_dict["room_id"]
+
         # If we don't have any prev event IDs specified then we need to
         # check that the host is in the room (as otherwise populating the
         # prev events will fail), at which point we may as well check the
@@ -1923,7 +1956,10 @@ class EventCreationHandler:
         )

         for room_id in room_ids:
-            dummy_event_sent = await self._send_dummy_event_for_room(room_id)
+            async with self._worker_lock_handler.acquire_read_write_lock(
+                DELETE_ROOM_LOCK_NAME, room_id, write=False
+            ):
+                dummy_event_sent = await self._send_dummy_event_for_room(room_id)

             if not dummy_event_sent:
                 # Did not find a valid user in the room, so remove from future attempts


@@ -46,6 +46,11 @@ logger = logging.getLogger(__name__)

 BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3

+PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
+
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+

 @attr.s(slots=True, auto_attribs=True)
 class PurgeStatus:
     """Object tracking the status of a purge request
@@ -142,6 +147,7 @@ class PaginationHandler:
         self._server_name = hs.hostname
         self._room_shutdown_handler = hs.get_room_shutdown_handler()
         self._relations_handler = hs.get_relations_handler()
+        self._worker_locks = hs.get_worker_locks_handler()

         self.pagination_lock = ReadWriteLock()
         # IDs of rooms in which there currently an active purge *or delete* operation.
@@ -356,7 +362,9 @@ class PaginationHandler:
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
-            async with self.pagination_lock.write(room_id):
+            async with self._worker_locks.acquire_read_write_lock(
+                PURGE_HISTORY_LOCK_NAME, room_id, write=True
+            ):
                 await self._storage_controllers.purge_events.purge_history(
                     room_id, token, delete_local_events
                 )
@@ -412,7 +420,10 @@ class PaginationHandler:
             room_id: room to be purged
             force: set true to skip checking for joined users.
         """
-        async with self.pagination_lock.write(room_id):
+        async with self._worker_locks.acquire_multi_read_write_lock(
+            [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
+            write=True,
+        ):
             # first check that we have no users in this room
             if not force:
                 joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -471,7 +482,9 @@ class PaginationHandler:

         room_token = from_token.room_key

-        async with self.pagination_lock.read(room_id):
+        async with self._worker_locks.acquire_read_write_lock(
+            PURGE_HISTORY_LOCK_NAME, room_id, write=False
+        ):
             (membership, member_event_id) = (None, None)
             if not use_admin_priviledge:
                 (
@@ -747,7 +760,9 @@ class PaginationHandler:
         self._purges_in_progress_by_room.add(room_id)
         try:
-            async with self.pagination_lock.write(room_id):
+            async with self._worker_locks.acquire_read_write_lock(
+                PURGE_HISTORY_LOCK_NAME, room_id, write=True
+            ):
                 self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
                 self._delete_by_id[
                     delete_id


@@ -39,6 +39,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
 from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -94,6 +95,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.event_creation_handler = hs.get_event_creation_handler()
         self.account_data_handler = hs.get_account_data_handler()
         self.event_auth_handler = hs.get_event_auth_handler()
+        self._worker_lock_handler = hs.get_worker_locks_handler()

         self.member_linearizer: Linearizer = Linearizer(name="member")
         self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -638,26 +640,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # by application services), and then by room ID.
         async with self.member_as_limiter.queue(as_id):
             async with self.member_linearizer.queue(key):
-                with opentracing.start_active_span("update_membership_locked"):
-                    result = await self.update_membership_locked(
-                        requester,
-                        target,
-                        room_id,
-                        action,
-                        txn_id=txn_id,
-                        remote_room_hosts=remote_room_hosts,
-                        third_party_signed=third_party_signed,
-                        ratelimit=ratelimit,
-                        content=content,
-                        new_room=new_room,
-                        require_consent=require_consent,
-                        outlier=outlier,
-                        allow_no_prev_events=allow_no_prev_events,
-                        prev_event_ids=prev_event_ids,
-                        state_event_ids=state_event_ids,
-                        depth=depth,
-                        origin_server_ts=origin_server_ts,
-                    )
+                async with self._worker_lock_handler.acquire_read_write_lock(
+                    DELETE_ROOM_LOCK_NAME, room_id, write=False
+                ):
+                    with opentracing.start_active_span("update_membership_locked"):
+                        result = await self.update_membership_locked(
+                            requester,
+                            target,
+                            room_id,
+                            action,
+                            txn_id=txn_id,
+                            remote_room_hosts=remote_room_hosts,
+                            third_party_signed=third_party_signed,
+                            ratelimit=ratelimit,
+                            content=content,
+                            new_room=new_room,
+                            require_consent=require_consent,
+                            outlier=outlier,
+                            allow_no_prev_events=allow_no_prev_events,
+                            prev_event_ids=prev_event_ids,
+                            state_event_ids=state_event_ids,
+                            depth=depth,
+                            origin_server_ts=origin_server_ts,
+                        )

         return result


@@ -0,0 +1,333 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from types import TracebackType
from typing import (
TYPE_CHECKING,
AsyncContextManager,
Collection,
Dict,
Optional,
Tuple,
Type,
Union,
)
from weakref import WeakSet
import attr
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import start_active_span
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.databases.main.lock import Lock, LockStore
from synapse.util.async_helpers import timeout_deferred
if TYPE_CHECKING:
from synapse.logging.opentracing import opentracing
from synapse.server import HomeServer
DELETE_ROOM_LOCK_NAME = "delete_room_lock"
class WorkerLocksHandler:
"""A class for waiting on taking out locks, rather than using the storage
functions directly (which don't support awaiting).
"""
def __init__(self, hs: "HomeServer") -> None:
self._reactor = hs.get_reactor()
self._store = hs.get_datastores().main
self._clock = hs.get_clock()
self._notifier = hs.get_notifier()
self._instance_name = hs.get_instance_name()
# Map from lock name/key to set of `WaitingLock` that are active for
# that lock.
self._locks: Dict[
Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
] = {}
self._clock.looping_call(self._cleanup_locks, 30_000)
self._notifier.add_lock_released_callback(self._on_lock_released)
def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
"""Acquire a standard lock, returns a context manager that will block
until the lock is acquired.
Note: Care must be taken to avoid deadlocks. In particular, this
function does *not* timeout.
Usage:
async with handler.acquire_lock(name, key):
# Do work while holding the lock...
"""
lock = WaitingLock(
reactor=self._reactor,
store=self._store,
handler=self,
lock_name=lock_name,
lock_key=lock_key,
write=None,
)
self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
return lock
def acquire_read_write_lock(
self,
lock_name: str,
lock_key: str,
*,
write: bool,
) -> "WaitingLock":
"""Acquire a read/write lock, returns a context manager that will block
until the lock is acquired.
Note: Care must be taken to avoid deadlocks. In particular, this
function does *not* timeout.
Usage:
async with handler.acquire_read_write_lock(name, key, write=True):
# Do work while holding the lock...
"""
lock = WaitingLock(
reactor=self._reactor,
store=self._store,
handler=self,
lock_name=lock_name,
lock_key=lock_key,
write=write,
)
self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
return lock
def acquire_multi_read_write_lock(
self,
lock_names: Collection[Tuple[str, str]],
*,
write: bool,
) -> "WaitingMultiLock":
"""Acquires multi read/write locks at once, returns a context manager
that will block until all the locks are acquired.
This will try and acquire all locks at once, and will never hold on to a
subset of the locks. (This avoids accidentally creating deadlocks).
Note: Care must be taken to avoid deadlocks. In particular, this
function does *not* timeout.
"""
lock = WaitingMultiLock(
lock_names=lock_names,
write=write,
reactor=self._reactor,
store=self._store,
handler=self,
)
for lock_name, lock_key in lock_names:
self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
return lock
def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
"""Notify that a lock has been released.
Pokes both the notifier and replication.
"""
self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key)
def _on_lock_released(
self, instance_name: str, lock_name: str, lock_key: str
) -> None:
"""Called when a lock has been released.
Wakes up any locks that might be waiting on this.
"""
locks = self._locks.get((lock_name, lock_key))
if not locks:
return
def _wake_deferred(deferred: defer.Deferred) -> None:
if not deferred.called:
deferred.callback(None)
for lock in locks:
self._clock.call_later(0, _wake_deferred, lock.deferred)
@wrap_as_background_process("_cleanup_locks")
async def _cleanup_locks(self) -> None:
"""Periodically cleans out stale entries in the locks map"""
self._locks = {key: value for key, value in self._locks.items() if value}
@attr.s(auto_attribs=True, eq=False)
class WaitingLock:
reactor: IReactorTime
store: LockStore
handler: WorkerLocksHandler
lock_name: str
lock_key: str
write: Optional[bool]
deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
_inner_lock: Optional[Lock] = None
_retry_interval: float = 0.1
_lock_span: "opentracing.Scope" = attr.Factory(
lambda: start_active_span("WaitingLock.lock")
)
async def __aenter__(self) -> None:
self._lock_span.__enter__()
with start_active_span("WaitingLock.waiting_for_lock"):
while self._inner_lock is None:
self.deferred = defer.Deferred()
if self.write is not None:
lock = await self.store.try_acquire_read_write_lock(
self.lock_name, self.lock_key, write=self.write
)
else:
lock = await self.store.try_acquire_lock(
self.lock_name, self.lock_key
)
if lock:
self._inner_lock = lock
break
try:
# Wait until we get notified the lock might have been
# released (by the deferred being resolved). We also
# periodically wake up in case the lock was released but we
# weren't notified.
with PreserveLoggingContext():
await timeout_deferred(
deferred=self.deferred,
timeout=self._get_next_retry_interval(),
reactor=self.reactor,
)
except Exception:
pass
return await self._inner_lock.__aenter__()
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> Optional[bool]:
assert self._inner_lock
self.handler.notify_lock_released(self.lock_name, self.lock_key)
try:
r = await self._inner_lock.__aexit__(exc_type, exc, tb)
finally:
self._lock_span.__exit__(exc_type, exc, tb)
return r
def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
return next * random.uniform(0.9, 1.1)
@attr.s(auto_attribs=True, eq=False)
class WaitingMultiLock:
lock_names: Collection[Tuple[str, str]]
write: bool
reactor: IReactorTime
store: LockStore
handler: WorkerLocksHandler
deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
_inner_lock_cm: Optional[AsyncContextManager] = None
_retry_interval: float = 0.1
_lock_span: "opentracing.Scope" = attr.Factory(
lambda: start_active_span("WaitingLock.lock")
)
async def __aenter__(self) -> None:
self._lock_span.__enter__()
with start_active_span("WaitingLock.waiting_for_lock"):
while self._inner_lock_cm is None:
self.deferred = defer.Deferred()
lock_cm = await self.store.try_acquire_multi_read_write_lock(
self.lock_names, write=self.write
)
if lock_cm:
self._inner_lock_cm = lock_cm
break
try:
# Wait until we get notified the lock might have been
# released (by the deferred being resolved). We also
# periodically wake up in case the lock was released but we
# weren't notified.
with PreserveLoggingContext():
await timeout_deferred(
deferred=self.deferred,
timeout=self._get_next_retry_interval(),
reactor=self.reactor,
)
except Exception:
pass
assert self._inner_lock_cm
await self._inner_lock_cm.__aenter__()
return
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> Optional[bool]:
assert self._inner_lock_cm
for lock_name, lock_key in self.lock_names:
self.handler.notify_lock_released(lock_name, lock_key)
try:
r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
finally:
self._lock_span.__exit__(exc_type, exc, tb)
return r
def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
return next * random.uniform(0.9, 1.1)
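A worked note on the polling fallback above (arithmetic sketch only, not part of the commit; jitter omitted): `_get_next_retry_interval` returns the current interval and then sets the next one to `max(5, interval * 2)`, so `max` acts as a floor rather than a cap. After the first quick retry at roughly 0.1 s, the wait grows to 5 s, 10 s, 20 s, and so on, unless a lock-released notification resolves the deferred sooner.

    # Sketch of the retry schedule (the +/-10% jitter is omitted):
    interval = 0.1
    waits = []
    for _ in range(5):
        waits.append(interval)
        interval = max(5, interval * 2)
    print(waits)  # [0.1, 5, 10, 20, 40]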


@@ -234,6 +234,9 @@ class Notifier:
         self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules

+        # List of callbacks to be notified when a lock is released
+        self._lock_released_callback: List[Callable[[str, str, str], None]] = []
+
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
         self._pusher_pool = hs.get_pusherpool()
@@ -785,6 +788,19 @@ class Notifier:
         # that any in flight requests can be immediately retried.
         self._federation_client.wake_destination(server)

+    def add_lock_released_callback(
+        self, callback: Callable[[str, str, str], None]
+    ) -> None:
+        """Add a function to be called whenever we are notified about a released lock."""
+        self._lock_released_callback.append(callback)
+
+    def notify_lock_released(
+        self, instance_name: str, lock_name: str, lock_key: str
+    ) -> None:
+        """Notify the callbacks that a lock has been released."""
+        for cb in self._lock_released_callback:
+            cb(instance_name, lock_name, lock_key)
+

 @attr.s(auto_attribs=True)
 class ReplicationNotifier:


@@ -422,6 +422,36 @@ class RemoteServerUpCommand(_SimpleCommand):
     NAME = "REMOTE_SERVER_UP"

+
+class LockReleasedCommand(Command):
+    """Sent to inform other instances that a given lock has been dropped.
+
+    Format::
+
+        LOCK_RELEASED ["<instance_name>", "<lock_name>", "<lock_key>"]
+    """
+
+    NAME = "LOCK_RELEASED"
+
+    def __init__(
+        self,
+        instance_name: str,
+        lock_name: str,
+        lock_key: str,
+    ):
+        self.instance_name = instance_name
+        self.lock_name = lock_name
+        self.lock_key = lock_key
+
+    @classmethod
+    def from_line(cls: Type["LockReleasedCommand"], line: str) -> "LockReleasedCommand":
+        instance_name, lock_name, lock_key = json_decoder.decode(line)
+
+        return cls(instance_name, lock_name, lock_key)
+
+    def to_line(self) -> str:
+        return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
+

 _COMMANDS: Tuple[Type[Command], ...] = (
     ServerCommand,
     RdataCommand,
@@ -435,6 +465,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
+    LockReleasedCommand,
 )

 # Map of command name to command type.
@@ -448,6 +479,7 @@ VALID_SERVER_COMMANDS = (
     ErrorCommand.NAME,
     PingCommand.NAME,
     RemoteServerUpCommand.NAME,
+    LockReleasedCommand.NAME,
 )

 # The commands the client is allowed to send
@@ -461,6 +493,7 @@ VALID_CLIENT_COMMANDS = (
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
+    LockReleasedCommand.NAME,
 )
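As a quick illustration of the wire format described in the docstring above (sketch only; the literal instance, lock, and room values are made up):

    from synapse.replication.tcp.commands import LockReleasedCommand

    cmd = LockReleasedCommand("worker1", "purge_history_lock", "!room:example.com")
    line = cmd.to_line()  # the JSON array, without the command name
    parsed = LockReleasedCommand.from_line(line)
    assert parsed.lock_name == "purge_history_lock"
    # On the wire the command name is prefixed, e.g.:
    #   LOCK_RELEASED ["worker1", "purge_history_lock", "!room:example.com"]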


@@ -39,6 +39,7 @@ from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
     FederationAckCommand,
+    LockReleasedCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -248,6 +249,9 @@ class ReplicationCommandHandler:
         if self._is_master or self._should_insert_client_ips:
             self.subscribe_to_channel("USER_IP")

+        if hs.config.redis.redis_enabled:
+            self._notifier.add_lock_released_callback(self.on_lock_released)
+
     def subscribe_to_channel(self, channel_name: str) -> None:
         """
         Indicates that we wish to subscribe to a Redis channel by name.
@@ -648,6 +652,17 @@ class ReplicationCommandHandler:

         self._notifier.notify_remote_server_up(cmd.data)

+    def on_LOCK_RELEASED(
+        self, conn: IReplicationConnection, cmd: LockReleasedCommand
+    ) -> None:
+        """Called when we get a new LOCK_RELEASED command."""
+        if cmd.instance_name == self._instance_name:
+            return
+
+        self._notifier.notify_lock_released(
+            cmd.instance_name, cmd.lock_name, cmd.lock_key
+        )
+
     def new_connection(self, connection: IReplicationConnection) -> None:
         """Called when we have a new connection."""
         self._connections.append(connection)
@@ -754,6 +769,13 @@ class ReplicationCommandHandler:
         """
         self.send_command(RdataCommand(stream_name, self._instance_name, token, data))

+    def on_lock_released(
+        self, instance_name: str, lock_name: str, lock_key: str
+    ) -> None:
+        """Called when we released a lock and should notify other instances."""
+        if instance_name == self._instance_name:
+            self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
+

 UpdateToken = TypeVar("UpdateToken")
 UpdateRow = TypeVar("UpdateRow")


@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Tuple

 from synapse.api.errors import Codes, ShadowBanError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -60,6 +61,7 @@ class RoomUpgradeRestServlet(RestServlet):
         self._hs = hs
         self._room_creation_handler = hs.get_room_creation_handler()
         self._auth = hs.get_auth()
+        self._worker_lock_handler = hs.get_worker_locks_handler()

     async def on_POST(
         self, request: SynapseRequest, room_id: str
@@ -78,9 +80,12 @@ class RoomUpgradeRestServlet(RestServlet):
         )

         try:
-            new_room_id = await self._room_creation_handler.upgrade_room(
-                requester, room_id, new_version
-            )
+            async with self._worker_lock_handler.acquire_read_write_lock(
+                DELETE_ROOM_LOCK_NAME, room_id, write=False
+            ):
+                new_room_id = await self._room_creation_handler.upgrade_room(
+                    requester, room_id, new_version
+                )
         except ShadowBanError:
             # Generate a random room ID.
             new_room_id = stringutils.random_string(18)


@@ -107,6 +107,7 @@ from synapse.handlers.stats import StatsHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
 from synapse.handlers.user_directory import UserDirectoryHandler
+from synapse.handlers.worker_lock import WorkerLocksHandler
 from synapse.http.client import (
     InsecureInterceptableContextFactory,
     ReplicationClient,
@@ -912,3 +913,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
         """Usage metrics shared between phone home stats and the prometheus exporter."""
         return CommonUsageMetricsManager(self)
+
+    @cache_in_self
+    def get_worker_locks_handler(self) -> WorkerLocksHandler:
+        return WorkerLocksHandler(self)


@@ -45,6 +45,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.logging.opentracing import (
     SynapseTags,
@@ -338,6 +339,7 @@ class EventsPersistenceStorageController:
         )
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._state_controller = state_controller
+        self.hs = hs

     async def _process_event_persist_queue_task(
         self,
@@ -350,15 +352,22 @@ class EventsPersistenceStorageController:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
         """
-        if isinstance(task, _PersistEventsTask):
-            return await self._persist_event_batch(room_id, task)
-        elif isinstance(task, _UpdateCurrentStateTask):
-            await self._update_current_state(room_id, task)
-            return {}
-        else:
-            raise AssertionError(
-                f"Found an unexpected task type in event persistence queue: {task}"
-            )
+
+        # Ensure that the room can't be deleted while we're persisting events to
+        # it. We might already have taken out the lock, but since this is just a
+        # "read" lock it's inherently reentrant.
+        async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
+            DELETE_ROOM_LOCK_NAME, room_id, write=False
+        ):
+            if isinstance(task, _PersistEventsTask):
+                return await self._persist_event_batch(room_id, task)
+            elif isinstance(task, _UpdateCurrentStateTask):
+                await self._update_current_state(room_id, task)
+                return {}
+            else:
+                raise AssertionError(
+                    f"Found an unexpected task type in event persistence queue: {task}"
+                )

     @trace
     async def persist_events(


@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from contextlib import AsyncExitStack
 from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
 from weakref import WeakValueDictionary

 from twisted.internet.interfaces import IReactorCore
@@ -208,77 +209,86 @@ class LockStore(SQLBaseStore):
         used (otherwise the lock will leak).
         """

-        now = self._clock.time_msec()
-        token = random_string(6)
-
-        def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
-            # We attempt to acquire the lock by inserting into
-            # `worker_read_write_locks` and seeing if that fails any
-            # constraints. If it doesn't then we have acquired the lock,
-            # otherwise we haven't.
-            #
-            # Before that though we clear the table of any stale locks.
-
-            delete_sql = """
-                DELETE FROM worker_read_write_locks
-                    WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
-            """
-
-            insert_sql = """
-                INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """
-
-            if isinstance(self.database_engine, PostgresEngine):
-                # For Postgres we can send these queries at the same time.
-                txn.execute(
-                    delete_sql + ";" + insert_sql,
-                    (
-                        # DELETE args
-                        now - _LOCK_TIMEOUT_MS,
-                        lock_name,
-                        lock_key,
-                        # UPSERT args
-                        lock_name,
-                        lock_key,
-                        write,
-                        self._instance_name,
-                        token,
-                        now,
-                    ),
-                )
-            else:
-                # For SQLite these need to be two queries.
-                txn.execute(
-                    delete_sql,
-                    (
-                        now - _LOCK_TIMEOUT_MS,
-                        lock_name,
-                        lock_key,
-                    ),
-                )
-                txn.execute(
-                    insert_sql,
-                    (
-                        lock_name,
-                        lock_key,
-                        write,
-                        self._instance_name,
-                        token,
-                        now,
-                    ),
-                )
-
-            return
-
         try:
-            await self.db_pool.runInteraction(
+            lock = await self.db_pool.runInteraction(
                 "try_acquire_read_write_lock",
-                _try_acquire_read_write_lock_txn,
+                self._try_acquire_read_write_lock_txn,
+                lock_name,
+                lock_key,
+                write,
             )
         except self.database_engine.module.IntegrityError:
             return None

+        return lock
+
+    def _try_acquire_read_write_lock_txn(
+        self,
+        txn: LoggingTransaction,
+        lock_name: str,
+        lock_key: str,
+        write: bool,
+    ) -> "Lock":
+        # We attempt to acquire the lock by inserting into
+        # `worker_read_write_locks` and seeing if that fails any
+        # constraints. If it doesn't then we have acquired the lock,
+        # otherwise we haven't.
+        #
+        # Before that though we clear the table of any stale locks.
+
+        now = self._clock.time_msec()
+        token = random_string(6)
+
+        delete_sql = """
+            DELETE FROM worker_read_write_locks
+                WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
+        """
+
+        insert_sql = """
+            INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
+            VALUES (?, ?, ?, ?, ?, ?)
+        """
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # For Postgres we can send these queries at the same time.
+            txn.execute(
+                delete_sql + ";" + insert_sql,
+                (
+                    # DELETE args
+                    now - _LOCK_TIMEOUT_MS,
+                    lock_name,
+                    lock_key,
+                    # UPSERT args
+                    lock_name,
+                    lock_key,
+                    write,
+                    self._instance_name,
+                    token,
+                    now,
+                ),
+            )
+        else:
+            # For SQLite these need to be two queries.
+            txn.execute(
+                delete_sql,
+                (
+                    now - _LOCK_TIMEOUT_MS,
+                    lock_name,
+                    lock_key,
+                ),
+            )
+            txn.execute(
+                insert_sql,
+                (
+                    lock_name,
+                    lock_key,
+                    write,
+                    self._instance_name,
+                    token,
+                    now,
+                ),
+            )
+
         lock = Lock(
             self._reactor,
             self._clock,
@@ -289,10 +299,58 @@ class LockStore(SQLBaseStore):
             token=token,
         )

-        self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+        def set_lock() -> None:
+            self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+
+        txn.call_after(set_lock)

         return lock

+    async def try_acquire_multi_read_write_lock(
+        self,
+        lock_names: Collection[Tuple[str, str]],
+        write: bool,
+    ) -> Optional[AsyncExitStack]:
+        """Try to acquire multiple locks for the given names/keys. Will return
+        an async context manager if the locks are successfully acquired, which
+        *must* be used (otherwise the lock will leak).
+
+        If only a subset of the locks can be acquired then it will immediately
+        drop them and return `None`.
+        """
+        try:
+            locks = await self.db_pool.runInteraction(
+                "try_acquire_multi_read_write_lock",
+                self._try_acquire_multi_read_write_lock_txn,
+                lock_names,
+                write,
+            )
+        except self.database_engine.module.IntegrityError:
+            return None
+
+        stack = AsyncExitStack()
+
+        for lock in locks:
+            await stack.enter_async_context(lock)
+
+        return stack
+
+    def _try_acquire_multi_read_write_lock_txn(
+        self,
+        txn: LoggingTransaction,
+        lock_names: Collection[Tuple[str, str]],
+        write: bool,
+    ) -> Collection["Lock"]:
+        locks = []
+
+        for lock_name, lock_key in lock_names:
+            lock = self._try_acquire_read_write_lock_txn(
+                txn, lock_name, lock_key, write
+            )
+            locks.append(lock)
+
+        return locks
+

 class Lock:
     """An async context manager that manages an acquired lock, ensuring it is


@@ -0,0 +1,74 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
class WorkerLockTestCase(unittest.HomeserverTestCase):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.worker_lock_handler = self.hs.get_worker_locks_handler()
def test_wait_for_lock_locally(self) -> None:
"""Test waiting for a lock on a single worker"""
lock1 = self.worker_lock_handler.acquire_lock("name", "key")
self.get_success(lock1.__aenter__())
lock2 = self.worker_lock_handler.acquire_lock("name", "key")
d2 = defer.ensureDeferred(lock2.__aenter__())
self.assertNoResult(d2)
self.get_success(lock1.__aexit__(None, None, None))
self.get_success(d2)
self.get_success(lock2.__aexit__(None, None, None))
class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.main_worker_lock_handler = self.hs.get_worker_locks_handler()
def test_wait_for_lock_worker(self) -> None:
"""Test waiting for a lock on another worker"""
worker = self.make_worker_hs(
"synapse.app.generic_worker",
extra_config={
"redis": {"enabled": True},
},
)
worker_lock_handler = worker.get_worker_locks_handler()
lock1 = self.main_worker_lock_handler.acquire_lock("name", "key")
self.get_success(lock1.__aenter__())
lock2 = worker_lock_handler.acquire_lock("name", "key")
d2 = defer.ensureDeferred(lock2.__aenter__())
self.assertNoResult(d2)
self.get_success(lock1.__aexit__(None, None, None))
self.get_success(d2)
self.get_success(lock2.__aexit__(None, None, None))


@@ -711,7 +711,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(30, channel.resource_usage.db_txn_count)
+        self.assertEqual(32, channel.resource_usage.db_txn_count)

     def test_post_room_initial_state(self) -> None:
         # POST with initial_state config key, expect new room id
@@ -724,7 +724,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(32, channel.resource_usage.db_txn_count)
+        self.assertEqual(34, channel.resource_usage.db_txn_count)

     def test_post_room_visibility_key(self) -> None:
         # POST with visibility config key, expect new room id


@@ -448,3 +448,55 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store._on_shutdown())
         self.assertEqual(self.store._live_read_write_lock_tokens, {})
def test_acquire_multiple_locks(self) -> None:
"""Tests that acquiring multiple locks at once works."""
# Take out multiple locks and ensure that we can't get those locks out
# again.
lock = self.get_success(
self.store.try_acquire_multi_read_write_lock(
[("name1", "key1"), ("name2", "key2")], write=True
)
)
self.assertIsNotNone(lock)
assert lock is not None
self.get_success(lock.__aenter__())
lock2 = self.get_success(
self.store.try_acquire_read_write_lock("name1", "key1", write=True)
)
self.assertIsNone(lock2)
lock3 = self.get_success(
self.store.try_acquire_read_write_lock("name2", "key2", write=False)
)
self.assertIsNone(lock3)
# Overlapping locks attempts will fail, and won't lock any locks.
lock4 = self.get_success(
self.store.try_acquire_multi_read_write_lock(
[("name1", "key1"), ("name3", "key3")], write=True
)
)
self.assertIsNone(lock4)
lock5 = self.get_success(
self.store.try_acquire_read_write_lock("name3", "key3", write=True)
)
self.assertIsNotNone(lock5)
assert lock5 is not None
self.get_success(lock5.__aenter__())
self.get_success(lock5.__aexit__(None, None, None))
# Once we release the lock we can take out the locks again.
self.get_success(lock.__aexit__(None, None, None))
lock6 = self.get_success(
self.store.try_acquire_read_write_lock("name1", "key1", write=True)
)
self.assertIsNotNone(lock6)
assert lock6 is not None
self.get_success(lock6.__aenter__())
self.get_success(lock6.__aexit__(None, None, None))