Add ability to wait for locks and add locks to purge history / room deletion (#15791)

c.f. #13476

This commit is contained in:
parent 0c6142c4a1
commit ae55cc1e6b

Changelog (new file):

@@ -0,0 +1 @@
+Fix bug where purging history and paginating simultaneously could lead to database corruption when using workers.
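The changelog entry above names the race this commit closes. As orientation for the per-file hunks below, here is a minimal, hedged sketch of the locking discipline they introduce, assuming `hs` is a configured synapse.server.HomeServer; the handler, method and constant names are the ones added in this diff, while the surrounding functions are illustrative only.

from synapse.handlers.pagination import PURGE_HISTORY_LOCK_NAME


async def purge(hs, room_id: str, token, delete_local_events: bool) -> None:
    worker_locks = hs.get_worker_locks_handler()
    # Purging takes the *write* lock: it waits for in-flight readers
    # (paginators) on any worker to finish and blocks new ones.
    async with worker_locks.acquire_read_write_lock(
        PURGE_HISTORY_LOCK_NAME, room_id, write=True
    ):
        ...  # run the purge storage functions here


async def paginate(hs, room_id: str, from_token) -> None:
    worker_locks = hs.get_worker_locks_handler()
    # Pagination takes the *read* lock: readers run concurrently with each
    # other, but never overlap with a purge holding the write lock.
    async with worker_locks.acquire_read_write_lock(
        PURGE_HISTORY_LOCK_NAME, room_id, write=False
    ):
        ...  # read the messages to return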
FederationServer:

@@ -63,6 +63,7 @@ from synapse.federation.federation_base import (
 )
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
     make_deferred_yieldable,
@@ -137,6 +138,7 @@ class FederationServer(FederationBase):
         self._event_auth_handler = hs.get_event_auth_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._e2e_keys_handler = hs.get_e2e_keys_handler()
+        self._worker_lock_handler = hs.get_worker_locks_handler()

         self._state_storage_controller = hs.get_storage_controllers().state

@@ -1236,9 +1238,18 @@ class FederationServer(FederationBase):
         logger.info("handling received PDU in room %s: %s", room_id, event)
         try:
             with nested_logging_context(event.event_id):
-                await self._federation_event_handler.on_receive_pdu(
-                    origin, event
-                )
+                # We're taking out a lock within a lock, which could
+                # lead to deadlocks if we're not careful. However, it is
+                # safe on this occasion as we only ever take a write
+                # lock when deleting a room, which we would never do
+                # while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
+                # lock.
+                async with self._worker_lock_handler.acquire_read_write_lock(
+                    DELETE_ROOM_LOCK_NAME, room_id, write=False
+                ):
+                    await self._federation_event_handler.on_receive_pdu(
+                        origin, event
+                    )
         except FederationError as e:
             # XXX: Ideally we'd inform the remote we failed to process
             # the event, but we can't return an error in the transaction
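The comment in the hunk above argues that nesting a *read* acquisition of the delete-room lock inside the inbound-event lock is deadlock-free, because the only writer is room deletion and it never holds the inbound-event lock. A standalone asyncio sketch (illustration only, not Synapse code) of the reader/writer semantics being relied on:

import asyncio


class ToyRWLock:
    """Minimal reader/writer lock, for illustration only."""

    def __init__(self) -> None:
        self._readers = 0
        self._no_readers = asyncio.Event()
        self._no_readers.set()
        self._writer = asyncio.Lock()

    async def acquire_read(self) -> None:
        async with self._writer:  # only blocks while a writer is active
            self._readers += 1
            self._no_readers.clear()

    def release_read(self) -> None:
        self._readers -= 1
        if self._readers == 0:
            self._no_readers.set()

    async def acquire_write(self) -> None:
        await self._writer.acquire()  # exclude new readers and writers
        await self._no_readers.wait()  # wait for in-flight readers to finish

    def release_write(self) -> None:
        self._writer.release()


# Readers never block each other, so "hold lock A, then take a read lock on B"
# can only wait on a writer of B. Here the sole writer of B (room deletion)
# never takes A, so the circular wait needed for a deadlock cannot form.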
EventCreationHandler:

@@ -53,6 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -485,6 +486,7 @@ class EventCreationHandler:
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
         self._notifier = hs.get_notifier()
+        self._worker_lock_handler = hs.get_worker_locks_handler()

         self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state

@@ -1010,6 +1012,37 @@ class EventCreationHandler:
                 event.internal_metadata.stream_ordering,
             )

+        async with self._worker_lock_handler.acquire_read_write_lock(
+            DELETE_ROOM_LOCK_NAME, room_id, write=False
+        ):
+            return await self._create_and_send_nonmember_event_locked(
+                requester=requester,
+                event_dict=event_dict,
+                allow_no_prev_events=allow_no_prev_events,
+                prev_event_ids=prev_event_ids,
+                state_event_ids=state_event_ids,
+                ratelimit=ratelimit,
+                txn_id=txn_id,
+                ignore_shadow_ban=ignore_shadow_ban,
+                outlier=outlier,
+                depth=depth,
+            )
+
+    async def _create_and_send_nonmember_event_locked(
+        self,
+        requester: Requester,
+        event_dict: dict,
+        allow_no_prev_events: bool = False,
+        prev_event_ids: Optional[List[str]] = None,
+        state_event_ids: Optional[List[str]] = None,
+        ratelimit: bool = True,
+        txn_id: Optional[str] = None,
+        ignore_shadow_ban: bool = False,
+        outlier: bool = False,
+        depth: Optional[int] = None,
+    ) -> Tuple[EventBase, int]:
+        room_id = event_dict["room_id"]
+
         # If we don't have any prev event IDs specified then we need to
         # check that the host is in the room (as otherwise populating the
         # prev events will fail), at which point we may as well check the
@@ -1923,7 +1956,10 @@ class EventCreationHandler:
         )

         for room_id in room_ids:
-            dummy_event_sent = await self._send_dummy_event_for_room(room_id)
+            async with self._worker_lock_handler.acquire_read_write_lock(
+                DELETE_ROOM_LOCK_NAME, room_id, write=False
+            ):
+                dummy_event_sent = await self._send_dummy_event_for_room(room_id)

             if not dummy_event_sent:
                 # Did not find a valid user in the room, so remove from future attempts
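The hunk at line 1010 splits the event-sending entry point into a thin public wrapper that takes the read lock and a `_locked` variant that assumes the lock is already held. A generic, standalone sketch of that idiom (illustration only, not Synapse code; the class and method names are made up):

import asyncio


class Sender:
    def __init__(self) -> None:
        # Stand-in for the worker read lock taken in the real handler.
        self._lock = asyncio.Lock()

    async def send(self, room_id: str, content: dict) -> str:
        # Public entry point: acquire the lock, then delegate.
        async with self._lock:
            return await self._send_locked(room_id, content)

    async def _send_locked(self, room_id: str, content: dict) -> str:
        # Callers must already hold self._lock; keeping the locked body in a
        # separate method makes that contract explicit at every call site.
        return f"$fake-event-id-for-{room_id}"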
PaginationHandler:

@@ -46,6 +46,11 @@ logger = logging.getLogger(__name__)
 BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3


+PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
+
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+
+
 @attr.s(slots=True, auto_attribs=True)
 class PurgeStatus:
     """Object tracking the status of a purge request
@@ -142,6 +147,7 @@ class PaginationHandler:
         self._server_name = hs.hostname
         self._room_shutdown_handler = hs.get_room_shutdown_handler()
         self._relations_handler = hs.get_relations_handler()
+        self._worker_locks = hs.get_worker_locks_handler()

         self.pagination_lock = ReadWriteLock()
         # IDs of rooms in which there is currently an active purge *or delete* operation.
@@ -356,7 +362,9 @@ class PaginationHandler:
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
-            async with self.pagination_lock.write(room_id):
+            async with self._worker_locks.acquire_read_write_lock(
+                PURGE_HISTORY_LOCK_NAME, room_id, write=True
+            ):
                 await self._storage_controllers.purge_events.purge_history(
                     room_id, token, delete_local_events
                 )
@@ -412,7 +420,10 @@ class PaginationHandler:
             room_id: room to be purged
             force: set true to skip checking for joined users.
         """
-        async with self.pagination_lock.write(room_id):
+        async with self._worker_locks.acquire_multi_read_write_lock(
+            [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
+            write=True,
+        ):
             # first check that we have no users in this room
             if not force:
                 joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -471,7 +482,9 @@ class PaginationHandler:

         room_token = from_token.room_key

-        async with self.pagination_lock.read(room_id):
+        async with self._worker_locks.acquire_read_write_lock(
+            PURGE_HISTORY_LOCK_NAME, room_id, write=False
+            ):
             (membership, member_event_id) = (None, None)
             if not use_admin_priviledge:
                 (
@@ -747,7 +760,9 @@ class PaginationHandler:

         self._purges_in_progress_by_room.add(room_id)
         try:
-            async with self.pagination_lock.write(room_id):
+            async with self._worker_locks.acquire_read_write_lock(
+                PURGE_HISTORY_LOCK_NAME, room_id, write=True
+            ):
                 self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
                 self._delete_by_id[
                     delete_id
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
|
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
|
||||||
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
|
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
|
||||||
|
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
|
||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
from synapse.metrics import event_processing_positions
|
from synapse.metrics import event_processing_positions
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
@ -94,6 +95,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self.account_data_handler = hs.get_account_data_handler()
|
self.account_data_handler = hs.get_account_data_handler()
|
||||||
self.event_auth_handler = hs.get_event_auth_handler()
|
self.event_auth_handler = hs.get_event_auth_handler()
|
||||||
|
self._worker_lock_handler = hs.get_worker_locks_handler()
|
||||||
|
|
||||||
self.member_linearizer: Linearizer = Linearizer(name="member")
|
self.member_linearizer: Linearizer = Linearizer(name="member")
|
||||||
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
|
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
|
||||||
|
@ -638,26 +640,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
# by application services), and then by room ID.
|
# by application services), and then by room ID.
|
||||||
async with self.member_as_limiter.queue(as_id):
|
async with self.member_as_limiter.queue(as_id):
|
||||||
async with self.member_linearizer.queue(key):
|
async with self.member_linearizer.queue(key):
|
||||||
with opentracing.start_active_span("update_membership_locked"):
|
async with self._worker_lock_handler.acquire_read_write_lock(
|
||||||
result = await self.update_membership_locked(
|
DELETE_ROOM_LOCK_NAME, room_id, write=False
|
||||||
requester,
|
):
|
||||||
target,
|
with opentracing.start_active_span("update_membership_locked"):
|
||||||
room_id,
|
result = await self.update_membership_locked(
|
||||||
action,
|
requester,
|
||||||
txn_id=txn_id,
|
target,
|
||||||
remote_room_hosts=remote_room_hosts,
|
room_id,
|
||||||
third_party_signed=third_party_signed,
|
action,
|
||||||
ratelimit=ratelimit,
|
txn_id=txn_id,
|
||||||
content=content,
|
remote_room_hosts=remote_room_hosts,
|
||||||
new_room=new_room,
|
third_party_signed=third_party_signed,
|
||||||
require_consent=require_consent,
|
ratelimit=ratelimit,
|
||||||
outlier=outlier,
|
content=content,
|
||||||
allow_no_prev_events=allow_no_prev_events,
|
new_room=new_room,
|
||||||
prev_event_ids=prev_event_ids,
|
require_consent=require_consent,
|
||||||
state_event_ids=state_event_ids,
|
outlier=outlier,
|
||||||
depth=depth,
|
allow_no_prev_events=allow_no_prev_events,
|
||||||
origin_server_ts=origin_server_ts,
|
prev_event_ids=prev_event_ids,
|
||||||
)
|
state_event_ids=state_event_ids,
|
||||||
|
depth=depth,
|
||||||
|
origin_server_ts=origin_server_ts,
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
New file: synapse/handlers/worker_lock.py

@@ -0,0 +1,333 @@ (all lines added)

# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from types import TracebackType
from typing import (
    TYPE_CHECKING,
    AsyncContextManager,
    Collection,
    Dict,
    Optional,
    Tuple,
    Type,
    Union,
)
from weakref import WeakSet

import attr

from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime

from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import start_active_span
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.databases.main.lock import Lock, LockStore
from synapse.util.async_helpers import timeout_deferred

if TYPE_CHECKING:
    from synapse.logging.opentracing import opentracing
    from synapse.server import HomeServer


DELETE_ROOM_LOCK_NAME = "delete_room_lock"


class WorkerLocksHandler:
    """A class for waiting on taking out locks, rather than using the storage
    functions directly (which don't support awaiting).
    """

    def __init__(self, hs: "HomeServer") -> None:
        self._reactor = hs.get_reactor()
        self._store = hs.get_datastores().main
        self._clock = hs.get_clock()
        self._notifier = hs.get_notifier()
        self._instance_name = hs.get_instance_name()

        # Map from lock name/key to set of `WaitingLock` that are active for
        # that lock.
        self._locks: Dict[
            Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
        ] = {}

        self._clock.looping_call(self._cleanup_locks, 30_000)

        self._notifier.add_lock_released_callback(self._on_lock_released)

    def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
        """Acquire a standard lock, returns a context manager that will block
        until the lock is acquired.

        Note: Care must be taken to avoid deadlocks. In particular, this
        function does *not* timeout.

        Usage:
            async with handler.acquire_lock(name, key):
                # Do work while holding the lock...
        """

        lock = WaitingLock(
            reactor=self._reactor,
            store=self._store,
            handler=self,
            lock_name=lock_name,
            lock_key=lock_key,
            write=None,
        )

        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

        return lock

    def acquire_read_write_lock(
        self,
        lock_name: str,
        lock_key: str,
        *,
        write: bool,
    ) -> "WaitingLock":
        """Acquire a read/write lock, returns a context manager that will block
        until the lock is acquired.

        Note: Care must be taken to avoid deadlocks. In particular, this
        function does *not* timeout.

        Usage:
            async with handler.acquire_read_write_lock(name, key, write=True):
                # Do work while holding the lock...
        """

        lock = WaitingLock(
            reactor=self._reactor,
            store=self._store,
            handler=self,
            lock_name=lock_name,
            lock_key=lock_key,
            write=write,
        )

        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

        return lock

    def acquire_multi_read_write_lock(
        self,
        lock_names: Collection[Tuple[str, str]],
        *,
        write: bool,
    ) -> "WaitingMultiLock":
        """Acquires multiple read/write locks at once, returns a context manager
        that will block until all the locks are acquired.

        This will try and acquire all locks at once, and will never hold on to a
        subset of the locks. (This avoids accidentally creating deadlocks.)

        Note: Care must be taken to avoid deadlocks. In particular, this
        function does *not* timeout.
        """

        lock = WaitingMultiLock(
            lock_names=lock_names,
            write=write,
            reactor=self._reactor,
            store=self._store,
            handler=self,
        )

        for lock_name, lock_key in lock_names:
            self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

        return lock

    def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
        """Notify that a lock has been released.

        Pokes both the notifier and replication.
        """

        self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key)

    def _on_lock_released(
        self, instance_name: str, lock_name: str, lock_key: str
    ) -> None:
        """Called when a lock has been released.

        Wakes up any locks that might be waiting on this.
        """
        locks = self._locks.get((lock_name, lock_key))
        if not locks:
            return

        def _wake_deferred(deferred: defer.Deferred) -> None:
            if not deferred.called:
                deferred.callback(None)

        for lock in locks:
            self._clock.call_later(0, _wake_deferred, lock.deferred)

    @wrap_as_background_process("_cleanup_locks")
    async def _cleanup_locks(self) -> None:
        """Periodically cleans out stale entries in the locks map"""
        self._locks = {key: value for key, value in self._locks.items() if value}


@attr.s(auto_attribs=True, eq=False)
class WaitingLock:
    reactor: IReactorTime
    store: LockStore
    handler: WorkerLocksHandler
    lock_name: str
    lock_key: str
    write: Optional[bool]
    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
    _inner_lock: Optional[Lock] = None
    _retry_interval: float = 0.1
    _lock_span: "opentracing.Scope" = attr.Factory(
        lambda: start_active_span("WaitingLock.lock")
    )

    async def __aenter__(self) -> None:
        self._lock_span.__enter__()

        with start_active_span("WaitingLock.waiting_for_lock"):
            while self._inner_lock is None:
                self.deferred = defer.Deferred()

                if self.write is not None:
                    lock = await self.store.try_acquire_read_write_lock(
                        self.lock_name, self.lock_key, write=self.write
                    )
                else:
                    lock = await self.store.try_acquire_lock(
                        self.lock_name, self.lock_key
                    )

                if lock:
                    self._inner_lock = lock
                    break

                try:
                    # Wait until we get notified the lock might have been
                    # released (by the deferred being resolved). We also
                    # periodically wake up in case the lock was released but we
                    # weren't notified.
                    with PreserveLoggingContext():
                        await timeout_deferred(
                            deferred=self.deferred,
                            timeout=self._get_next_retry_interval(),
                            reactor=self.reactor,
                        )
                except Exception:
                    pass

        return await self._inner_lock.__aenter__()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> Optional[bool]:
        assert self._inner_lock

        self.handler.notify_lock_released(self.lock_name, self.lock_key)

        try:
            r = await self._inner_lock.__aexit__(exc_type, exc, tb)
        finally:
            self._lock_span.__exit__(exc_type, exc, tb)

        return r

    def _get_next_retry_interval(self) -> float:
        next = self._retry_interval
        self._retry_interval = max(5, next * 2)
        return next * random.uniform(0.9, 1.1)


@attr.s(auto_attribs=True, eq=False)
class WaitingMultiLock:
    lock_names: Collection[Tuple[str, str]]

    write: bool

    reactor: IReactorTime
    store: LockStore
    handler: WorkerLocksHandler

    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)

    _inner_lock_cm: Optional[AsyncContextManager] = None
    _retry_interval: float = 0.1
    _lock_span: "opentracing.Scope" = attr.Factory(
        lambda: start_active_span("WaitingLock.lock")
    )

    async def __aenter__(self) -> None:
        self._lock_span.__enter__()

        with start_active_span("WaitingLock.waiting_for_lock"):
            while self._inner_lock_cm is None:
                self.deferred = defer.Deferred()

                lock_cm = await self.store.try_acquire_multi_read_write_lock(
                    self.lock_names, write=self.write
                )

                if lock_cm:
                    self._inner_lock_cm = lock_cm
                    break

                try:
                    # Wait until we get notified the lock might have been
                    # released (by the deferred being resolved). We also
                    # periodically wake up in case the lock was released but we
                    # weren't notified.
                    with PreserveLoggingContext():
                        await timeout_deferred(
                            deferred=self.deferred,
                            timeout=self._get_next_retry_interval(),
                            reactor=self.reactor,
                        )
                except Exception:
                    pass

        assert self._inner_lock_cm
        await self._inner_lock_cm.__aenter__()
        return

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> Optional[bool]:
        assert self._inner_lock_cm

        for lock_name, lock_key in self.lock_names:
            self.handler.notify_lock_released(lock_name, lock_key)

        try:
            r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
        finally:
            self._lock_span.__exit__(exc_type, exc, tb)

        return r

    def _get_next_retry_interval(self) -> float:
        next = self._retry_interval
        self._retry_interval = max(5, next * 2)
        return next * random.uniform(0.9, 1.1)
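The docstrings above already give the usage shape; here it is as a hedged, self-contained sketch (assumes a configured HomeServer `hs`; "my_lock_name" and the function are made up for illustration). Note that acquisition never times out: while waiting, WaitingLock retries whenever a lock-released notification arrives or its timer fires, with intervals from _get_next_retry_interval() of roughly 0.1s, then 5s, 10s, 20s, ... each with about ten percent jitter.

from synapse.server import HomeServer


async def do_work_in_room(hs: HomeServer, room_id: str) -> None:
    handler = hs.get_worker_locks_handler()

    # Plain exclusive lock: one holder per (name, key) across all workers.
    async with handler.acquire_lock("my_lock_name", room_id):
        ...  # do the exclusive work here

    # Read/write lock: readers share the lock, a writer excludes everyone.
    async with handler.acquire_read_write_lock(
        "my_lock_name", room_id, write=False
    ):
        ...  # do work that must not overlap with a writer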
Notifier:

@@ -234,6 +234,9 @@ class Notifier:
         self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules

+        # List of callbacks to be notified when a lock is released
+        self._lock_released_callback: List[Callable[[str, str, str], None]] = []
+
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
         self._pusher_pool = hs.get_pusherpool()
@@ -785,6 +788,19 @@ class Notifier:
         # that any in flight requests can be immediately retried.
         self._federation_client.wake_destination(server)

+    def add_lock_released_callback(
+        self, callback: Callable[[str, str, str], None]
+    ) -> None:
+        """Add a function to be called whenever we are notified about a released lock."""
+        self._lock_released_callback.append(callback)
+
+    def notify_lock_released(
+        self, instance_name: str, lock_name: str, lock_key: str
+    ) -> None:
+        """Notify the callbacks that a lock has been released."""
+        for cb in self._lock_released_callback:
+            cb(instance_name, lock_name, lock_key)
+

 @attr.s(auto_attribs=True)
 class ReplicationNotifier:
Replication commands (synapse.replication.tcp.commands):

@@ -422,6 +422,36 @@ class RemoteServerUpCommand(_SimpleCommand):
     NAME = "REMOTE_SERVER_UP"


+class LockReleasedCommand(Command):
+    """Sent to inform other instances that a given lock has been dropped.
+
+    Format::
+
+        LOCK_RELEASED ["<instance_name>", "<lock_name>", "<lock_key>"]
+    """
+
+    NAME = "LOCK_RELEASED"
+
+    def __init__(
+        self,
+        instance_name: str,
+        lock_name: str,
+        lock_key: str,
+    ):
+        self.instance_name = instance_name
+        self.lock_name = lock_name
+        self.lock_key = lock_key
+
+    @classmethod
+    def from_line(cls: Type["LockReleasedCommand"], line: str) -> "LockReleasedCommand":
+        instance_name, lock_name, lock_key = json_decoder.decode(line)
+
+        return cls(instance_name, lock_name, lock_key)
+
+    def to_line(self) -> str:
+        return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
+
+
 _COMMANDS: Tuple[Type[Command], ...] = (
     ServerCommand,
     RdataCommand,
@@ -435,6 +465,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
+    LockReleasedCommand,
 )

 # Map of command name to command type.
@@ -448,6 +479,7 @@ VALID_SERVER_COMMANDS = (
     ErrorCommand.NAME,
     PingCommand.NAME,
     RemoteServerUpCommand.NAME,
+    LockReleasedCommand.NAME,
 )

 # The commands the client is allowed to send
@@ -461,6 +493,7 @@ VALID_CLIENT_COMMANDS = (
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
+    LockReleasedCommand.NAME,
 )
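The docstring above fixes the wire format of the new command. A small round-trip sketch (hedged: the instance name, lock name and room ID are made-up illustrative values; `json_encoder`/`json_decoder` are the JSON helpers the commands module already uses):

from synapse.replication.tcp.commands import LockReleasedCommand

cmd = LockReleasedCommand("worker1", "delete_room_lock", "!room:example.com")

line = cmd.to_line()
# `line` is the JSON array of the three fields; on the wire it is sent with
# the command name prefixed, exactly as the docstring shows:
#   LOCK_RELEASED ["worker1", "delete_room_lock", "!room:example.com"]

parsed = LockReleasedCommand.from_line(line)
assert (parsed.instance_name, parsed.lock_name, parsed.lock_key) == (
    "worker1",
    "delete_room_lock",
    "!room:example.com",
)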
ReplicationCommandHandler:

@@ -39,6 +39,7 @@ from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
     FederationAckCommand,
+    LockReleasedCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -248,6 +249,9 @@ class ReplicationCommandHandler:
         if self._is_master or self._should_insert_client_ips:
             self.subscribe_to_channel("USER_IP")

+        if hs.config.redis.redis_enabled:
+            self._notifier.add_lock_released_callback(self.on_lock_released)
+
     def subscribe_to_channel(self, channel_name: str) -> None:
         """
         Indicates that we wish to subscribe to a Redis channel by name.
@@ -648,6 +652,17 @@ class ReplicationCommandHandler:

         self._notifier.notify_remote_server_up(cmd.data)

+    def on_LOCK_RELEASED(
+        self, conn: IReplicationConnection, cmd: LockReleasedCommand
+    ) -> None:
+        """Called when we get a new LOCK_RELEASED command."""
+        if cmd.instance_name == self._instance_name:
+            return
+
+        self._notifier.notify_lock_released(
+            cmd.instance_name, cmd.lock_name, cmd.lock_key
+        )
+
     def new_connection(self, connection: IReplicationConnection) -> None:
         """Called when we have a new connection."""
         self._connections.append(connection)
@@ -754,6 +769,13 @@ class ReplicationCommandHandler:
         """
         self.send_command(RdataCommand(stream_name, self._instance_name, token, data))

+    def on_lock_released(
+        self, instance_name: str, lock_name: str, lock_key: str
+    ) -> None:
+        """Called when we released a lock and should notify other instances."""
+        if instance_name == self._instance_name:
+            self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
+

 UpdateToken = TypeVar("UpdateToken")
 UpdateRow = TypeVar("UpdateRow")
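Taken together with the Notifier and worker-lock hunks, the methods above form the cross-worker wake-up path. A commented trace of that flow (names from this diff; simplified):

# Worker A releases a lock:
#   WaitingLock.__aexit__()
#     -> WorkerLocksHandler.notify_lock_released(name, key)
#       -> Notifier.notify_lock_released(instance_a, name, key)
#         -> ReplicationCommandHandler.on_lock_released(...)   # registered when Redis is enabled
#           -> send_command(LockReleasedCommand(instance_a, name, key))
#
# Worker B receives LOCK_RELEASED:
#   ReplicationCommandHandler.on_LOCK_RELEASED(conn, cmd)      # ignores its own commands
#     -> Notifier.notify_lock_released(instance_a, name, key)
#       -> WorkerLocksHandler._on_lock_released(name, key)
#         -> resolves each waiter's deferred, so WaitingLock.__aenter__ retries
#            the database acquisition immediately instead of waiting for its
#            next timer tick.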
RoomUpgradeRestServlet:

@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Tuple

 from synapse.api.errors import Codes, ShadowBanError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -60,6 +61,7 @@ class RoomUpgradeRestServlet(RestServlet):
         self._hs = hs
         self._room_creation_handler = hs.get_room_creation_handler()
         self._auth = hs.get_auth()
+        self._worker_lock_handler = hs.get_worker_locks_handler()

     async def on_POST(
         self, request: SynapseRequest, room_id: str
@@ -78,9 +80,12 @@ class RoomUpgradeRestServlet(RestServlet):
         )

         try:
-            new_room_id = await self._room_creation_handler.upgrade_room(
-                requester, room_id, new_version
-            )
+            async with self._worker_lock_handler.acquire_read_write_lock(
+                DELETE_ROOM_LOCK_NAME, room_id, write=False
+            ):
+                new_room_id = await self._room_creation_handler.upgrade_room(
+                    requester, room_id, new_version
+                )
         except ShadowBanError:
             # Generate a random room ID.
             new_room_id = stringutils.random_string(18)
HomeServer (synapse.server):

@@ -107,6 +107,7 @@ from synapse.handlers.stats import StatsHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
 from synapse.handlers.user_directory import UserDirectoryHandler
+from synapse.handlers.worker_lock import WorkerLocksHandler
 from synapse.http.client import (
     InsecureInterceptableContextFactory,
     ReplicationClient,
@@ -912,3 +913,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
         """Usage metrics shared between phone home stats and the prometheus exporter."""
         return CommonUsageMetricsManager(self)
+
+    @cache_in_self
+    def get_worker_locks_handler(self) -> WorkerLocksHandler:
+        return WorkerLocksHandler(self)
EventsPersistenceStorageController:

@@ -45,6 +45,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.logging.opentracing import (
     SynapseTags,
@@ -338,6 +339,7 @@ class EventsPersistenceStorageController:
         )
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._state_controller = state_controller
+        self.hs = hs

     async def _process_event_persist_queue_task(
         self,
@@ -350,15 +352,22 @@ class EventsPersistenceStorageController:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
         """
-        if isinstance(task, _PersistEventsTask):
-            return await self._persist_event_batch(room_id, task)
-        elif isinstance(task, _UpdateCurrentStateTask):
-            await self._update_current_state(room_id, task)
-            return {}
-        else:
-            raise AssertionError(
-                f"Found an unexpected task type in event persistence queue: {task}"
-            )
+
+        # Ensure that the room can't be deleted while we're persisting events to
+        # it. We might already have taken out the lock, but since this is just a
+        # "read" lock it's inherently reentrant.
+        async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
+            DELETE_ROOM_LOCK_NAME, room_id, write=False
+        ):
+            if isinstance(task, _PersistEventsTask):
+                return await self._persist_event_batch(room_id, task)
+            elif isinstance(task, _UpdateCurrentStateTask):
+                await self._update_current_state(room_id, task)
+                return {}
+            else:
+                raise AssertionError(
+                    f"Found an unexpected task type in event persistence queue: {task}"
+                )

     @trace
     async def persist_events(
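The comment in the last hunk relies on read acquisitions of the same lock being safe to nest: readers share the lock, so only a writer (room deletion) is excluded. A hedged sketch of that reentrancy (assumes a configured HomeServer `hs`; the surrounding function is made up):

from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME


async def persist_with_lock(hs, room_id: str) -> None:
    locks = hs.get_worker_locks_handler()
    async with locks.acquire_read_write_lock(
        DELETE_ROOM_LOCK_NAME, room_id, write=False
    ):
        # A nested read acquisition of the same name/key does not block:
        # multiple read locks can coexist, only a write lock would conflict.
        async with locks.acquire_read_write_lock(
            DELETE_ROOM_LOCK_NAME, room_id, write=False
        ):
            ...  # persist the events here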
LockStore (synapse.storage.databases.main.lock):

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from contextlib import AsyncExitStack
 from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
 from weakref import WeakValueDictionary

 from twisted.internet.interfaces import IReactorCore
@@ -208,77 +209,86 @@ class LockStore(SQLBaseStore):
         used (otherwise the lock will leak).
         """
-
-        now = self._clock.time_msec()
-        token = random_string(6)
-
-        def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
-            # We attempt to acquire the lock by inserting into
-            # `worker_read_write_locks` and seeing if that fails any
-            # constraints. If it doesn't then we have acquired the lock,
-            # otherwise we haven't.
-            #
-            # Before that though we clear the table of any stale locks.
-
-            delete_sql = """
-                DELETE FROM worker_read_write_locks
-                    WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
-            """
-
-            insert_sql = """
-                INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """
-
-            if isinstance(self.database_engine, PostgresEngine):
-                # For Postgres we can send these queries at the same time.
-                txn.execute(
-                    delete_sql + ";" + insert_sql,
-                    (
-                        # DELETE args
-                        now - _LOCK_TIMEOUT_MS,
-                        lock_name,
-                        lock_key,
-                        # UPSERT args
-                        lock_name,
-                        lock_key,
-                        write,
-                        self._instance_name,
-                        token,
-                        now,
-                    ),
-                )
-            else:
-                # For SQLite these need to be two queries.
-                txn.execute(
-                    delete_sql,
-                    (
-                        now - _LOCK_TIMEOUT_MS,
-                        lock_name,
-                        lock_key,
-                    ),
-                )
-                txn.execute(
-                    insert_sql,
-                    (
-                        lock_name,
-                        lock_key,
-                        write,
-                        self._instance_name,
-                        token,
-                        now,
-                    ),
-                )
-
-            return
-
         try:
-            await self.db_pool.runInteraction(
+            lock = await self.db_pool.runInteraction(
                 "try_acquire_read_write_lock",
-                _try_acquire_read_write_lock_txn,
+                self._try_acquire_read_write_lock_txn,
+                lock_name,
+                lock_key,
+                write,
             )
         except self.database_engine.module.IntegrityError:
             return None

+        return lock
+
+    def _try_acquire_read_write_lock_txn(
+        self,
+        txn: LoggingTransaction,
+        lock_name: str,
+        lock_key: str,
+        write: bool,
+    ) -> "Lock":
+        # We attempt to acquire the lock by inserting into
+        # `worker_read_write_locks` and seeing if that fails any
+        # constraints. If it doesn't then we have acquired the lock,
+        # otherwise we haven't.
+        #
+        # Before that though we clear the table of any stale locks.
+
+        now = self._clock.time_msec()
+        token = random_string(6)
+
+        delete_sql = """
+            DELETE FROM worker_read_write_locks
+                WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
+        """
+
+        insert_sql = """
+            INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
+            VALUES (?, ?, ?, ?, ?, ?)
+        """
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # For Postgres we can send these queries at the same time.
+            txn.execute(
+                delete_sql + ";" + insert_sql,
+                (
+                    # DELETE args
+                    now - _LOCK_TIMEOUT_MS,
+                    lock_name,
+                    lock_key,
+                    # UPSERT args
+                    lock_name,
+                    lock_key,
+                    write,
+                    self._instance_name,
+                    token,
+                    now,
+                ),
+            )
+        else:
+            # For SQLite these need to be two queries.
+            txn.execute(
+                delete_sql,
+                (
+                    now - _LOCK_TIMEOUT_MS,
+                    lock_name,
+                    lock_key,
+                ),
+            )
+            txn.execute(
+                insert_sql,
+                (
+                    lock_name,
+                    lock_key,
+                    write,
+                    self._instance_name,
+                    token,
+                    now,
+                ),
+            )
+
         lock = Lock(
             self._reactor,
             self._clock,
@@ -289,10 +299,58 @@ class LockStore(SQLBaseStore):
             token=token,
         )

-        self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+        def set_lock() -> None:
+            self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+
+        txn.call_after(set_lock)

         return lock

+    async def try_acquire_multi_read_write_lock(
+        self,
+        lock_names: Collection[Tuple[str, str]],
+        write: bool,
+    ) -> Optional[AsyncExitStack]:
+        """Try to acquire multiple locks for the given names/keys. Will return
+        an async context manager if the locks are successfully acquired, which
+        *must* be used (otherwise the lock will leak).
+
+        If only a subset of the locks can be acquired then it will immediately
+        drop them and return `None`.
+        """
+        try:
+            locks = await self.db_pool.runInteraction(
+                "try_acquire_multi_read_write_lock",
+                self._try_acquire_multi_read_write_lock_txn,
+                lock_names,
+                write,
+            )
+        except self.database_engine.module.IntegrityError:
+            return None
+
+        stack = AsyncExitStack()
+
+        for lock in locks:
+            await stack.enter_async_context(lock)
+
+        return stack
+
+    def _try_acquire_multi_read_write_lock_txn(
+        self,
+        txn: LoggingTransaction,
+        lock_names: Collection[Tuple[str, str]],
+        write: bool,
+    ) -> Collection["Lock"]:
+        locks = []
+
+        for lock_name, lock_key in lock_names:
+            lock = self._try_acquire_read_write_lock_txn(
+                txn, lock_name, lock_key, write
+            )
+            locks.append(lock)
+
+        return locks
+
+
 class Lock:
     """An async context manager that manages an acquired lock, ensuring it is
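try_acquire_multi_read_write_lock hands the acquired locks to an AsyncExitStack so callers either hold all of them or none. A standalone sketch of that pattern (illustration only; `fake_lock` is a made-up stand-in for the real Lock objects):

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def fake_lock(name: str):
    print("acquired", name)
    try:
        yield
    finally:
        print("released", name)


async def main() -> None:
    stack = AsyncExitStack()
    # Enter each lock onto the stack; if any acquisition failed we could bail
    # out and the stack would still release whatever was already entered.
    for name in ("name1", "name2"):
        await stack.enter_async_context(fake_lock(name))

    async with stack:
        print("holding both locks")
    # Leaving the `async with` unwinds the stack: name2 is released, then name1.


asyncio.run(main())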
Tests — new file (WorkerLockTestCase / WorkerLockWorkersTestCase):

@@ -0,0 +1,74 @@ (all lines added)

# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase


class WorkerLockTestCase(unittest.HomeserverTestCase):
    def prepare(
        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
    ) -> None:
        self.worker_lock_handler = self.hs.get_worker_locks_handler()

    def test_wait_for_lock_locally(self) -> None:
        """Test waiting for a lock on a single worker"""

        lock1 = self.worker_lock_handler.acquire_lock("name", "key")
        self.get_success(lock1.__aenter__())

        lock2 = self.worker_lock_handler.acquire_lock("name", "key")
        d2 = defer.ensureDeferred(lock2.__aenter__())
        self.assertNoResult(d2)

        self.get_success(lock1.__aexit__(None, None, None))

        self.get_success(d2)
        self.get_success(lock2.__aexit__(None, None, None))


class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
    def prepare(
        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
    ) -> None:
        self.main_worker_lock_handler = self.hs.get_worker_locks_handler()

    def test_wait_for_lock_worker(self) -> None:
        """Test waiting for a lock on another worker"""

        worker = self.make_worker_hs(
            "synapse.app.generic_worker",
            extra_config={
                "redis": {"enabled": True},
            },
        )
        worker_lock_handler = worker.get_worker_locks_handler()

        lock1 = self.main_worker_lock_handler.acquire_lock("name", "key")
        self.get_success(lock1.__aenter__())

        lock2 = worker_lock_handler.acquire_lock("name", "key")
        d2 = defer.ensureDeferred(lock2.__aenter__())
        self.assertNoResult(d2)

        self.get_success(lock1.__aexit__(None, None, None))

        self.get_success(d2)
        self.get_success(lock2.__aexit__(None, None, None))
Tests — RoomsCreateTestCase (updated transaction counts):

@@ -711,7 +711,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(30, channel.resource_usage.db_txn_count)
+        self.assertEqual(32, channel.resource_usage.db_txn_count)

     def test_post_room_initial_state(self) -> None:
         # POST with initial_state config key, expect new room id
@@ -724,7 +724,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(32, channel.resource_usage.db_txn_count)
+        self.assertEqual(34, channel.resource_usage.db_txn_count)

     def test_post_room_visibility_key(self) -> None:
         # POST with visibility config key, expect new room id
Tests — ReadWriteLockTestCase:

@@ -448,3 +448,55 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store._on_shutdown())

         self.assertEqual(self.store._live_read_write_lock_tokens, {})
+
+    def test_acquire_multiple_locks(self) -> None:
+        """Tests that acquiring multiple locks at once works."""
+
+        # Take out multiple locks and ensure that we can't get those locks out
+        # again.
+        lock = self.get_success(
+            self.store.try_acquire_multi_read_write_lock(
+                [("name1", "key1"), ("name2", "key2")], write=True
+            )
+        )
+        self.assertIsNotNone(lock)
+
+        assert lock is not None
+        self.get_success(lock.__aenter__())
+
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name1", "key1", write=True)
+        )
+        self.assertIsNone(lock2)
+
+        lock3 = self.get_success(
+            self.store.try_acquire_read_write_lock("name2", "key2", write=False)
+        )
+        self.assertIsNone(lock3)
+
+        # Overlapping lock attempts will fail, and won't take any of the locks.
+        lock4 = self.get_success(
+            self.store.try_acquire_multi_read_write_lock(
+                [("name1", "key1"), ("name3", "key3")], write=True
+            )
+        )
+        self.assertIsNone(lock4)
+
+        lock5 = self.get_success(
+            self.store.try_acquire_read_write_lock("name3", "key3", write=True)
+        )
+        self.assertIsNotNone(lock5)
+        assert lock5 is not None
+        self.get_success(lock5.__aenter__())
+        self.get_success(lock5.__aexit__(None, None, None))
+
+        # Once we release the lock we can take out the locks again.
+        self.get_success(lock.__aexit__(None, None, None))
+
+        lock6 = self.get_success(
+            self.store.try_acquire_read_write_lock("name1", "key1", write=True)
+        )
+        self.assertIsNotNone(lock6)
+        assert lock6 is not None
+        self.get_success(lock6.__aenter__())
+        self.get_success(lock6.__aexit__(None, None, None))