Mirror of https://github.com/matrix-org/synapse.git (synced 2025-02-10 17:25:48 +00:00)

Refactor and convert `Linearizer` to async (#12357)

Refactor and convert `Linearizer` to async. This makes a `Linearizer` cancellation bug easier to fix.

Also refactor to use an async context manager, which eliminates an unlikely footgun where code that doesn't immediately use the context manager could forget to release the lock.

Signed-off-by: Sean Quah <seanq@element.io>

This commit is contained in:
parent ab3fdcf960
commit 800ba87cc8
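To see the footgun the commit message describes, compare toy models of the two API shapes (illustrative only; `old_queue`, `new_queue`, and the `held` counter are invented stand-ins, not Synapse code). The old `queue` acquired the lock as a side effect of the `await`, before any context manager was entered, so a caller that awaited it and then bailed out leaked the lock:

```python
import asyncio
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncIterator, Iterator

held = 0  # number of outstanding lock holders, for demonstration


async def old_queue(key: str):
    """Old shape: the lock is acquired as a side effect of the await."""
    global held
    held += 1  # acquired here, before any context manager is entered

    @contextmanager
    def _ctx() -> Iterator[None]:
        global held
        try:
            yield
        finally:
            held -= 1

    return _ctx()


@asynccontextmanager
async def new_queue(key: str) -> AsyncIterator[None]:
    """New shape: acquisition happens inside __aenter__, release in finally."""
    global held
    held += 1
    try:
        yield
    finally:
        held -= 1


async def main() -> None:
    # The footgun: await the old API, then never enter the context manager.
    cm = await old_queue("key")
    del cm  # oops: the lock stays held forever
    print("holders after misusing the old API:", held)  # -> 1 (leaked)

    # The new API cannot leak this way: nothing is acquired until entry.
    new_queue("key")  # creating but not entering the manager is harmless
    print("holders after not entering the new API:", held)  # -> still 1


asyncio.run(main())
```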
changelog.d/12357.misc (new file)

@@ -0,0 +1 @@
+Refactor `Linearizer`, convert methods to async and use an async context manager.
@@ -188,7 +188,7 @@ class FederationServer(FederationBase):
     async def on_backfill_request(
         self, origin: str, room_id: str, versions: List[str], limit: int
     ) -> Tuple[int, Dict[str, Any]]:
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)
 

@@ -218,7 +218,7 @@ class FederationServer(FederationBase):
             Tuple indicating the response status code and dictionary response
             body including `event_id`.
         """
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)
 

@@ -529,7 +529,7 @@ class FederationServer(FederationBase):
         # in the cache so we could return it without waiting for the linearizer
         # - but that's non-trivial to get right, and anyway somewhat defeats
         # the point of the linearizer.
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
             resp: JsonDict = dict(
                 await self._state_resp_cache.wrap(
                     (room_id, event_id),

@@ -883,7 +883,7 @@ class FederationServer(FederationBase):
     async def on_event_auth(
         self, origin: str, room_id: str, event_id: str
     ) -> Tuple[int, Dict[str, Any]]:
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)
 

@@ -945,7 +945,7 @@ class FederationServer(FederationBase):
         latest_events: List[str],
         limit: int,
     ) -> Dict[str, list]:
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)
 
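All the call-site hunks in this commit follow the same mechanical pattern: `with (await linearizer.queue(key)):` becomes `async with linearizer.queue(key):`. Beyond reading better, the `async with` form pins the release to the context manager's exit path. A small self-contained sketch of that guarantee, using a hypothetical `fake_queue` stand-in rather than the real class:

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Hashable


@asynccontextmanager
async def fake_queue(key: Hashable) -> AsyncIterator[None]:
    # Stand-in for linearizer.queue(key) that prints instead of locking.
    print(f"acquired {key!r}")
    try:
        yield
    finally:
        print(f"released {key!r}")  # runs whether the body returns or raises


async def main() -> None:
    try:
        async with fake_queue(("origin", "room_id")):
            raise RuntimeError("handler blew up mid-request")
    except RuntimeError:
        pass  # "released ..." has already been printed by this point


asyncio.run(main())
```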
@@ -330,10 +330,8 @@ class ApplicationServicesHandler:
                     continue
 
                # Since we read/update the stream position for this AS/stream
-                with (
-                    await self._ephemeral_events_linearizer.queue(
-                        (service.id, stream_key)
-                    )
-                ):
+                async with self._ephemeral_events_linearizer.queue(
+                    (service.id, stream_key)
+                ):
                     if stream_key == "receipt_key":
                         events = await self._handle_receipts(service, new_token)
@@ -833,7 +833,7 @@ class DeviceListUpdater:
     async def _handle_device_updates(self, user_id: str) -> None:
         "Actually handle pending updates."
 
-        with (await self._remote_edu_linearizer.queue(user_id)):
+        async with self._remote_edu_linearizer.queue(user_id):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
                 # This can happen since we batch updates
@@ -118,7 +118,7 @@ class E2eKeysHandler:
             from_device_id: the device making the query. This is used to limit
                 the number of in-flight queries at a time.
         """
-        with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
+        async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
             device_keys_query: Dict[str, Iterable[str]] = query_body.get(
                 "device_keys", {}
             )
@@ -1386,7 +1386,7 @@ class SigningKeyEduUpdater:
         device_handler = self.e2e_keys_handler.device_handler
         device_list_updater = device_handler.device_list_updater
 
-        with (await self._remote_edu_linearizer.queue(user_id)):
+        async with self._remote_edu_linearizer.queue(user_id):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
                 # This can happen since we batch updates
@@ -83,7 +83,7 @@ class E2eRoomKeysHandler:
 
         # we deliberately take the lock to get keys so that changing the version
         # works atomically
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             # make sure the backup version exists
             try:
                 await self.store.get_e2e_room_keys_version_info(user_id, version)

@@ -126,7 +126,7 @@ class E2eRoomKeysHandler:
         """
 
         # lock for consistency with uploading
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             # make sure the backup version exists
             try:
                 version_info = await self.store.get_e2e_room_keys_version_info(

@@ -187,7 +187,7 @@ class E2eRoomKeysHandler:
         # TODO: Validate the JSON to make sure it has the right keys.
 
         # XXX: perhaps we should use a finer grained lock here?
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
 
             # Check that the version we're trying to upload is the current version
             try:

@@ -332,7 +332,7 @@ class E2eRoomKeysHandler:
         # TODO: Validate the JSON to make sure it has the right keys.
 
         # lock everyone out until we've switched version
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             new_version = await self.store.create_e2e_room_keys_version(
                 user_id, version_info
             )

@@ -359,7 +359,7 @@ class E2eRoomKeysHandler:
         }
         """
 
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             try:
                 res = await self.store.get_e2e_room_keys_version_info(user_id, version)
             except StoreError as e:

@@ -383,7 +383,7 @@ class E2eRoomKeysHandler:
             NotFoundError: if this backup version doesn't exist
         """
 
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             try:
                 await self.store.delete_e2e_room_keys_version(user_id, version)
             except StoreError as e:

@@ -413,7 +413,7 @@ class E2eRoomKeysHandler:
             raise SynapseError(
                 400, "Version in body does not match", Codes.INVALID_PARAM
             )
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             try:
                 old_info = await self.store.get_e2e_room_keys_version_info(
                     user_id, version
@@ -151,7 +151,7 @@ class FederationHandler:
                 return. This is used as part of the heuristic to decide if we
                 should back paginate.
         """
-        with (await self._room_backfill.queue(room_id)):
+        async with self._room_backfill.queue(room_id):
             return await self._maybe_backfill_inner(room_id, current_depth, limit)
 
     async def _maybe_backfill_inner(
@@ -224,7 +224,7 @@ class FederationEventHandler:
                 len(missing_prevs),
                 shortstr(missing_prevs),
             )
-            with (await self._room_pdu_linearizer.queue(pdu.room_id)):
+            async with self._room_pdu_linearizer.queue(pdu.room_id):
                 logger.info(
                     "Acquired room lock to fetch %d missing prev_events",
                     len(missing_prevs),
@@ -851,7 +851,7 @@ class EventCreationHandler:
         # a situation where event persistence can't keep up, causing
         # extremities to pile up, which in turn leads to state resolution
         # taking longer.
-        with (await self.limiter.queue(event_dict["room_id"])):
+        async with self.limiter.queue(event_dict["room_id"]):
             if txn_id and requester.access_token_id:
                 existing_event_id = await self.store.get_event_id_from_transaction_id(
                     event_dict["room_id"],
@@ -1030,7 +1030,7 @@ class PresenceHandler(BasePresenceHandler):
             is_syncing: Whether or not the user is now syncing
             sync_time_msec: Time in ms when the user was last syncing
         """
-        with (await self.external_sync_linearizer.queue(process_id)):
+        async with self.external_sync_linearizer.queue(process_id):
             prev_state = await self.current_state_for_user(user_id)
 
             process_presence = self.external_process_to_current_syncs.setdefault(

@@ -1071,7 +1071,7 @@ class PresenceHandler(BasePresenceHandler):
 
         Used when the process has stopped/disappeared.
         """
-        with (await self.external_sync_linearizer.queue(process_id)):
+        async with self.external_sync_linearizer.queue(process_id):
             process_presence = self.external_process_to_current_syncs.pop(
                 process_id, set()
             )
@@ -40,7 +40,7 @@ class ReadMarkerHandler:
             the read marker has changed.
         """
 
-        with await self.read_marker_linearizer.queue((room_id, user_id)):
+        async with self.read_marker_linearizer.queue((room_id, user_id)):
             existing_read_marker = await self.store.get_account_data_for_room_and_type(
                 user_id, room_id, "m.fully_read"
             )
@@ -883,7 +883,7 @@ class RoomCreationHandler:
         #
         # we also don't need to check the requester's shadow-ban here, as we
         # have already done so above (and potentially emptied invite_list).
-        with (await self.room_member_handler.member_linearizer.queue((room_id,))):
+        async with self.room_member_handler.member_linearizer.queue((room_id,)):
             content = {}
             is_direct = config.get("is_direct", None)
             if is_direct:
@@ -515,8 +515,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # We first linearise by the application service (to try to limit concurrent joins
         # by application services), and then by room ID.
-        with (await self.member_as_limiter.queue(as_id)):
-            with (await self.member_linearizer.queue(key)):
+        async with self.member_as_limiter.queue(as_id):
+            async with self.member_linearizer.queue(key):
                 result = await self.update_membership_locked(
                     requester,
                     target,
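This hunk is the one place in the commit where two linearizers nest: the application-service limiter is always taken before the per-room member linearizer. Acquiring them in a single fixed order is what keeps the nesting deadlock-free. A sketch of the ordering discipline with plain `asyncio.Lock` stand-ins (the lock names mirror the diff, but nothing here is Synapse code):

```python
import asyncio

member_as_limiter = asyncio.Lock()   # coarse: per application service
member_linearizer = asyncio.Lock()   # fine: per membership key


async def update_membership(as_id: str, key: tuple) -> None:
    # Every caller takes the coarse lock first and the fine lock second.
    # With one global order, callers may queue behind each other but can
    # never hold each other's lock in a cycle, so no deadlock is possible.
    async with member_as_limiter:
        async with member_linearizer:
            print(f"updating membership for {as_id} with key {key}")


asyncio.run(update_membership("as_1", ("!room:example.org", "@user:example.org")))
```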
@@ -430,7 +430,7 @@ class SsoHandler:
         # grab a lock while we try to find a mapping for this user. This seems...
         # optimistic, especially for implementations that end up redirecting to
         # interstitial pages.
-        with await self._mapping_lock.queue(auth_provider_id):
+        async with self._mapping_lock.queue(auth_provider_id):
             # first of all, check if we already have a mapping for this user
             user_id = await self.get_sso_user_by_remote_user_id(
                 auth_provider_id,
@@ -397,7 +397,7 @@ class RulesForRoom:
             self.room_push_rule_cache_metrics.inc_hits()
             return self.data.rules_by_user
 
-        with (await self.linearizer.queue(self.room_id)):
+        async with self.linearizer.queue(self.room_id):
             if state_group and self.data.state_group == state_group:
                 logger.debug("Using cached rules for %r", self.room_id)
                 self.room_push_rule_cache_metrics.inc_hits()
@@ -451,7 +451,7 @@ class FederationSenderHandler:
         # service for robustness? Or could we replace it with an assertion that
         # we're not being re-entered?
 
-        with (await self._fed_position_linearizer.queue(None)):
+        async with self._fed_position_linearizer.queue(None):
             # We persist and ack the same position, so we take a copy of it
             # here as otherwise it can get modified from underneath us.
             current_position = self.federation_position
@@ -258,7 +258,7 @@ class MediaRepository:
         # We linearize here to ensure that we don't try and download remote
         # media multiple times concurrently
         key = (server_name, media_id)
-        with (await self.remote_media_linearizer.queue(key)):
+        async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
                 server_name, media_id
             )

@@ -294,7 +294,7 @@ class MediaRepository:
         # We linearize here to ensure that we don't try and download remote
         # media multiple times concurrently
         key = (server_name, media_id)
-        with (await self.remote_media_linearizer.queue(key)):
+        async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
                 server_name, media_id
             )

@@ -850,7 +850,7 @@ class MediaRepository:
 
             # TODO: Should we delete from the backup store
 
-            with (await self.remote_media_linearizer.queue(key)):
+            async with self.remote_media_linearizer.queue(key):
                 full_path = self.filepaths.remote_media_filepath(origin, file_id)
                 try:
                     os.remove(full_path)
@@ -573,7 +573,7 @@ class StateResolutionHandler:
         """
         group_names = frozenset(state_groups_ids.keys())
 
-        with (await self.resolve_linearizer.queue(group_names)):
+        async with self.resolve_linearizer.queue(group_names):
             cache = self._state_cache.get(group_names, None)
             if cache:
                 return cache
@@ -888,7 +888,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             return frozenset(cache.hosts_to_joined_users)
 
         # Since we'll mutate the cache we need to lock.
-        with (await self._joined_host_linearizer.queue(room_id)):
+        async with self._joined_host_linearizer.queue(room_id):
             if state_entry.state_group == cache.state_group:
                 # Same state group, so nothing to do. We've already checked for
                 # this above, but the cache may have changed while waiting on
@@ -18,7 +18,7 @@ import collections
 import inspect
 import itertools
 import logging
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import asynccontextmanager
 from typing import (
     Any,
     AsyncIterator,

@@ -29,7 +29,6 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
-    Iterator,
     List,
     Optional,
     Set,
@@ -342,7 +341,7 @@ class Linearizer:
 
     Example:
 
-        with await limiter.queue("test_key"):
+        async with limiter.queue("test_key"):
             # do some work.
 
     """
@@ -383,95 +382,53 @@ class Linearizer:
         # non-empty.
         return bool(entry.deferreds)
 
-    def queue(self, key: Hashable) -> defer.Deferred:
-        # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
-        # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
-        # propagated inside inlineCallbacks until Twisted 18.7)
+    def queue(self, key: Hashable) -> AsyncContextManager[None]:
+        @asynccontextmanager
+        async def _ctx_manager() -> AsyncIterator[None]:
+            entry = await self._acquire_lock(key)
+            try:
+                yield
+            finally:
+                self._release_lock(key, entry)
+
+        return _ctx_manager()
+
+    async def _acquire_lock(self, key: Hashable) -> _LinearizerEntry:
+        """Acquires a linearizer lock, waiting if necessary.
+
+        Returns once we have secured the lock.
+        """
         entry = self.key_to_defer.setdefault(
             key, _LinearizerEntry(0, collections.OrderedDict())
         )
 
-        # If the number of things executing is greater than the maximum
-        # then add a deferred to the list of blocked items
-        # When one of the things currently executing finishes it will callback
-        # this item so that it can continue executing.
-        if entry.count >= self.max_count:
-            res = self._await_lock(key)
-        else:
+        if entry.count < self.max_count:
+            # The number of things executing is less than the maximum.
             logger.debug(
                 "Acquired uncontended linearizer lock %r for key %r", self.name, key
             )
             entry.count += 1
-            res = defer.succeed(None)
-
-        # once we successfully get the lock, we need to return a context manager which
-        # will release the lock.
-
-        @contextmanager
-        def _ctx_manager(_: None) -> Iterator[None]:
-            try:
-                yield
-            finally:
-                logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
-
-                # We've finished executing so check if there are any things
-                # blocked waiting to execute and start one of them
-                entry.count -= 1
-
-                if entry.deferreds:
-                    (next_def, _) = entry.deferreds.popitem(last=False)
-
-                    # we need to run the next thing in the sentinel context.
-                    with PreserveLoggingContext():
-                        next_def.callback(None)
-                elif entry.count == 0:
-                    # We were the last thing for this key: remove it from the
-                    # map.
-                    del self.key_to_defer[key]
-
-        res.addCallback(_ctx_manager)
-        return res
-
-    def _await_lock(self, key: Hashable) -> defer.Deferred:
-        """Helper for queue: adds a deferred to the queue
-
-        Assumes that we've already checked that we've reached the limit of the number
-        of lock-holders we allow. Creates a new deferred which is added to the list, and
-        adds some management around cancellations.
-
-        Returns the deferred, which will callback once we have secured the lock.
-
-        """
-        entry = self.key_to_defer[key]
-
+            return entry
+
+        # Otherwise, the number of things executing is at the maximum and we have to
+        # add a deferred to the list of blocked items.
+        # When one of the things currently executing finishes it will callback
+        # this item so that it can continue executing.
         logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 
        new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
 
-        def cb(_r: None) -> "defer.Deferred[None]":
-            logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
-            entry.count += 1
-
-            # if the code holding the lock completes synchronously, then it
-            # will recursively run the next claimant on the list. That can
-            # relatively rapidly lead to stack exhaustion. This is essentially
-            # the same problem as http://twistedmatrix.com/trac/ticket/9304.
-            #
-            # In order to break the cycle, we add a cheeky sleep(0) here to
-            # ensure that we fall back to the reactor between each iteration.
-            #
-            # (This needs to happen while we hold the lock, and the context manager's exit
-            # code must be synchronous, so this is the only sensible place.)
-            return self._clock.sleep(0)
-
-        def eb(e: Failure) -> Failure:
+        try:
+            await new_defer
+        except Exception as e:
             logger.info("defer %r got err %r", new_defer, e)
             if isinstance(e, CancelledError):
                 logger.debug(
-                    "Cancelling wait for linearizer lock %r for key %r", self.name, key
+                    "Cancelling wait for linearizer lock %r for key %r",
+                    self.name,
+                    key,
                 )
             else:
                 logger.warning(
                     "Unexpected exception waiting for linearizer lock %r for key %r",

@@ -481,10 +438,43 @@ class Linearizer:
 
             # we just have to take ourselves back out of the queue.
             del entry.deferreds[new_defer]
-            return e
+            raise
 
-        new_defer.addCallbacks(cb, eb)
-        return new_defer
+        logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
+        entry.count += 1
+
+        # if the code holding the lock completes synchronously, then it
+        # will recursively run the next claimant on the list. That can
+        # relatively rapidly lead to stack exhaustion. This is essentially
+        # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+        #
+        # In order to break the cycle, we add a cheeky sleep(0) here to
+        # ensure that we fall back to the reactor between each iteration.
+        #
+        # This needs to happen while we hold the lock. We could put it on the
+        # exit path, but that would slow down the uncontended case.
+        await self._clock.sleep(0)
+
+        return entry
+
+    def _release_lock(self, key: Hashable, entry: _LinearizerEntry) -> None:
+        """Releases a held linearizer lock."""
+        logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
+
+        # We've finished executing so check if there are any things
+        # blocked waiting to execute and start one of them
+        entry.count -= 1
+
+        if entry.deferreds:
+            (next_def, _) = entry.deferreds.popitem(last=False)
+
+            # we need to run the next thing in the sentinel context.
+            with PreserveLoggingContext():
+                next_def.callback(None)
+        elif entry.count == 0:
+            # We were the last thing for this key: remove it from the
+            # map.
+            del self.key_to_defer[key]
 
 
 class ReadWriteLock:
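To make the shape of the refactor concrete, here is a minimal, hedged sketch of the same design on top of asyncio rather than Twisted: `queue(key)` returns an async context manager, `_acquire` waits FIFO per key, and `_release` hands the slot to the oldest waiter. All names (`MiniLinearizer`, `_Entry`) are invented for illustration; this is not Synapse's implementation, and it deliberately elides the cancellation corner cases that the commit message says this refactor makes easier to fix.

```python
import asyncio
from collections import OrderedDict
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict, Hashable


class _Entry:
    """Per-key lock state: holder count plus a FIFO queue of waiters."""

    def __init__(self) -> None:
        self.count = 0
        self.waiters: "OrderedDict[asyncio.Future, None]" = OrderedDict()


class MiniLinearizer:
    """Keyed lock with the post-refactor API shape: queue(key) returns an
    async context manager, and waiters are served in FIFO order."""

    def __init__(self, max_count: int = 1) -> None:
        self.max_count = max_count
        self._key_to_entry: Dict[Hashable, _Entry] = {}

    @asynccontextmanager
    async def queue(self, key: Hashable) -> AsyncIterator[None]:
        entry = await self._acquire(key)
        try:
            yield
        finally:
            self._release(key, entry)

    async def _acquire(self, key: Hashable) -> _Entry:
        entry = self._key_to_entry.setdefault(key, _Entry())
        if entry.count < self.max_count:
            entry.count += 1  # uncontended: take a slot immediately
        else:
            # At capacity: park a future in the FIFO queue until a releasing
            # holder hands its slot over to us.
            waiter: "asyncio.Future[None]" = asyncio.get_running_loop().create_future()
            entry.waiters[waiter] = None
            try:
                await waiter
            except asyncio.CancelledError:
                entry.waiters.pop(waiter, None)  # step back out of the queue
                raise
        # Bounce through the event loop between holders so that a chain of
        # synchronous critical sections cannot recurse into stack exhaustion
        # (the same job as self._clock.sleep(0) in the real code).
        await asyncio.sleep(0)
        return entry

    def _release(self, key: Hashable, entry: _Entry) -> None:
        if entry.waiters:
            # Hand our slot straight to the oldest waiter; the holder count
            # is unchanged, so a newcomer cannot jump the queue.
            next_waiter, _ = entry.waiters.popitem(last=False)
            next_waiter.set_result(None)
            return
        entry.count -= 1
        if entry.count == 0:
            del self._key_to_entry[key]  # last holder for this key


async def demo() -> None:
    linearizer = MiniLinearizer()
    order = []

    async def job(n: int) -> None:
        async with linearizer.queue("some_key"):
            order.append(n)
            await asyncio.sleep(0.01)

    await asyncio.gather(*(job(n) for n in range(3)))
    print(order)  # jobs ran one at a time, in FIFO order: [0, 1, 2]


asyncio.run(demo())
```

Note the `asyncio.sleep(0)` in `_acquire`: like `self._clock.sleep(0)` in the real code, it forces a trip through the event loop between successive holders, and it can now live on the acquisition path because acquisition is async, whereas the old context manager's exit code had to stay synchronous.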
@@ -46,7 +46,7 @@ class LinearizerTestCase(unittest.TestCase):
         unblock_d: "Deferred[None]" = Deferred()
 
         async def task() -> None:
-            with await linearizer.queue(key):
+            async with linearizer.queue(key):
                 acquired_d.callback(None)
                 await unblock_d
 
@@ -125,7 +125,7 @@ class LinearizerTestCase(unittest.TestCase):
 
         async def func(i: int) -> None:
             with LoggingContext("func(%s)" % i) as lc:
-                with (await linearizer.queue(key)):
+                async with linearizer.queue(key):
                     self.assertEqual(current_context(), lc)
 
             self.assertEqual(current_context(), lc)
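Finally, since the commit message calls out a cancellation bug, here is the kind of asyncio-level check one could run against the `MiniLinearizer` sketch above (hedged: Synapse's real tests are Twisted-based and this scenario is simplified). A waiter that is cancelled while queued should step back out and leave the lock usable:

```python
import asyncio


async def cancellation_check() -> None:
    linearizer = MiniLinearizer()  # from the sketch above

    async def hold_briefly() -> None:
        async with linearizer.queue("key"):
            await asyncio.sleep(0.05)

    async def queued() -> None:
        async with linearizer.queue("key"):
            raise AssertionError("a cancelled waiter must never acquire")

    holder = asyncio.create_task(hold_briefly())
    await asyncio.sleep(0)  # let the holder take the lock
    waiter = asyncio.create_task(queued())
    await asyncio.sleep(0)  # let the waiter join the FIFO queue

    waiter.cancel()
    try:
        await waiter
    except asyncio.CancelledError:
        pass  # the waiter stepped back out of the queue

    await holder
    async with linearizer.queue("key"):
        print("lock still usable after a queued waiter was cancelled")


asyncio.run(cancellation_check())
```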