Speed up `@cachedList` (#13591)
This speeds things up by ~2x. The vast majority of the time is now spent in `LruCache` moving things around the linked lists. We do this via two things: 1. Don't create a deferred per-key during bulk set operations in `DeferredCache`. Instead, only create them if a subsequent caller asks for the key. 2. Add a bulk lookup API to `DeferredCache` rather than use a loop.
This commit is contained in:
parent
05c9c7363b
commit
f7ddfe17a3
|
@ -0,0 +1 @@
|
|||
Improve performance of `@cachedList`.
|
|
@ -14,15 +14,19 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import enum
|
||||
import threading
|
||||
from typing import (
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterable,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Set,
|
||||
Sized,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
|
@ -31,7 +35,6 @@ from typing import (
|
|||
from prometheus_client import Gauge
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.python import failure
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
|
@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
|
|||
|
||||
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
|
||||
self._pending_deferred_cache: Union[
|
||||
TreeCache, "MutableMapping[KT, CacheEntry]"
|
||||
TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
|
||||
] = cache_type()
|
||||
|
||||
def metrics_cb() -> None:
|
||||
|
@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
|
|||
Raises:
|
||||
KeyError if the key is not found in the cache
|
||||
"""
|
||||
callbacks = [callback] if callback else []
|
||||
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
|
||||
if val is not _Sentinel.sentinel:
|
||||
val.callbacks.update(callbacks)
|
||||
val.add_invalidation_callback(key, callback)
|
||||
if update_metrics:
|
||||
m = self.cache.metrics
|
||||
assert m # we always have a name, so should always have metrics
|
||||
m.inc_hits()
|
||||
return val.deferred.observe()
|
||||
return val.deferred(key)
|
||||
|
||||
callbacks = (callback,) if callback else ()
|
||||
|
||||
val2 = self.cache.get(
|
||||
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
|
||||
|
@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
|
|||
else:
|
||||
return defer.succeed(val2)
|
||||
|
||||
def get_bulk(
|
||||
self,
|
||||
keys: Collection[KT],
|
||||
callback: Optional[Callable[[], None]] = None,
|
||||
) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
|
||||
"""Bulk lookup of items in the cache.
|
||||
|
||||
Returns:
|
||||
A 3-tuple of:
|
||||
1. a dict of key/value of items already cached;
|
||||
2. a deferred that resolves to a dict of key/value of items
|
||||
we're already fetching; and
|
||||
3. a collection of keys that don't appear in the previous two.
|
||||
"""
|
||||
|
||||
# The cached results
|
||||
cached = {}
|
||||
|
||||
# List of pending deferreds
|
||||
pending = []
|
||||
|
||||
# Dict that gets filled out when the pending deferreds complete
|
||||
pending_results = {}
|
||||
|
||||
# List of keys that aren't in either cache
|
||||
missing = []
|
||||
|
||||
callbacks = (callback,) if callback else ()
|
||||
|
||||
for key in keys:
|
||||
# Check if its in the main cache.
|
||||
immediate_value = self.cache.get(
|
||||
key,
|
||||
_Sentinel.sentinel,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
if immediate_value is not _Sentinel.sentinel:
|
||||
cached[key] = immediate_value
|
||||
continue
|
||||
|
||||
# Check if its in the pending cache
|
||||
pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
|
||||
if pending_value is not _Sentinel.sentinel:
|
||||
pending_value.add_invalidation_callback(key, callback)
|
||||
|
||||
def completed_cb(value: VT, key: KT) -> VT:
|
||||
pending_results[key] = value
|
||||
return value
|
||||
|
||||
# Add a callback to fill out `pending_results` when that completes
|
||||
d = pending_value.deferred(key).addCallback(completed_cb, key)
|
||||
pending.append(d)
|
||||
continue
|
||||
|
||||
# Not in either cache
|
||||
missing.append(key)
|
||||
|
||||
# If we've got pending deferreds, squash them into a single one that
|
||||
# returns `pending_results`.
|
||||
pending_deferred = None
|
||||
if pending:
|
||||
pending_deferred = defer.gatherResults(
|
||||
pending, consumeErrors=True
|
||||
).addCallback(lambda _: pending_results)
|
||||
|
||||
return (cached, pending_deferred, missing)
|
||||
|
||||
def get_immediate(
|
||||
self, key: KT, default: T, update_metrics: bool = True
|
||||
) -> Union[VT, T]:
|
||||
|
@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
|
|||
value: a deferred which will complete with a result to add to the cache
|
||||
callback: An optional callback to be called when the entry is invalidated
|
||||
"""
|
||||
if not isinstance(value, defer.Deferred):
|
||||
raise TypeError("not a Deferred")
|
||||
|
||||
callbacks = [callback] if callback else []
|
||||
self.check_thread()
|
||||
|
||||
existing_entry = self._pending_deferred_cache.pop(key, None)
|
||||
if existing_entry:
|
||||
existing_entry.invalidate()
|
||||
self._pending_deferred_cache.pop(key, None)
|
||||
|
||||
# XXX: why don't we invalidate the entry in `self.cache` yet?
|
||||
|
||||
# we can save a whole load of effort if the deferred is ready.
|
||||
if value.called:
|
||||
result = value.result
|
||||
if not isinstance(result, failure.Failure):
|
||||
self.cache.set(key, cast(VT, result), callbacks)
|
||||
return value
|
||||
|
||||
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
|
||||
# and add callbacks to add it to the cache properly later.
|
||||
|
||||
observable = ObservableDeferred(value, consumeErrors=True)
|
||||
observer = observable.observe()
|
||||
entry = CacheEntry(deferred=observable, callbacks=callbacks)
|
||||
|
||||
entry = CacheEntrySingle[KT, VT](value)
|
||||
entry.add_invalidation_callback(key, callback)
|
||||
self._pending_deferred_cache[key] = entry
|
||||
|
||||
def compare_and_pop() -> bool:
|
||||
"""Check if our entry is still the one in _pending_deferred_cache, and
|
||||
if so, pop it.
|
||||
|
||||
Returns true if the entries matched.
|
||||
"""
|
||||
existing_entry = self._pending_deferred_cache.pop(key, None)
|
||||
if existing_entry is entry:
|
||||
return True
|
||||
|
||||
# oops, the _pending_deferred_cache has been updated since
|
||||
# we started our query, so we are out of date.
|
||||
#
|
||||
# Better put back whatever we took out. (We do it this way
|
||||
# round, rather than peeking into the _pending_deferred_cache
|
||||
# and then removing on a match, to make the common case faster)
|
||||
if existing_entry is not None:
|
||||
self._pending_deferred_cache[key] = existing_entry
|
||||
|
||||
return False
|
||||
|
||||
def cb(result: VT) -> None:
|
||||
if compare_and_pop():
|
||||
self.cache.set(key, result, entry.callbacks)
|
||||
else:
|
||||
# we're not going to put this entry into the cache, so need
|
||||
# to make sure that the invalidation callbacks are called.
|
||||
# That was probably done when _pending_deferred_cache was
|
||||
# updated, but it's possible that `set` was called without
|
||||
# `invalidate` being previously called, in which case it may
|
||||
# not have been. Either way, let's double-check now.
|
||||
entry.invalidate()
|
||||
|
||||
def eb(_fail: Failure) -> None:
|
||||
compare_and_pop()
|
||||
entry.invalidate()
|
||||
|
||||
# once the deferred completes, we can move the entry from the
|
||||
# _pending_deferred_cache to the real cache.
|
||||
#
|
||||
observer.addCallbacks(cb, eb)
|
||||
deferred = entry.deferred(key).addCallbacks(
|
||||
self._completed_callback,
|
||||
self._error_callback,
|
||||
callbackArgs=(entry, key),
|
||||
errbackArgs=(entry, key),
|
||||
)
|
||||
|
||||
# we return a new Deferred which will be called before any subsequent observers.
|
||||
return observable.observe()
|
||||
return deferred
|
||||
|
||||
def start_bulk_input(
|
||||
self,
|
||||
keys: Collection[KT],
|
||||
callback: Optional[Callable[[], None]] = None,
|
||||
) -> "CacheMultipleEntries[KT, VT]":
|
||||
"""Bulk set API for use when fetching multiple keys at once from the DB.
|
||||
|
||||
Called *before* starting the fetch from the DB, and the caller *must*
|
||||
call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
|
||||
"""
|
||||
|
||||
entry = CacheMultipleEntries[KT, VT]()
|
||||
entry.add_global_invalidation_callback(callback)
|
||||
|
||||
for key in keys:
|
||||
self._pending_deferred_cache[key] = entry
|
||||
|
||||
return entry
|
||||
|
||||
def _completed_callback(
|
||||
self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
|
||||
) -> VT:
|
||||
"""Called when a deferred is completed."""
|
||||
# We check if the current entry matches the entry associated with the
|
||||
# deferred. If they don't match then it got invalidated.
|
||||
current_entry = self._pending_deferred_cache.pop(key, None)
|
||||
if current_entry is not entry:
|
||||
if current_entry:
|
||||
self._pending_deferred_cache[key] = current_entry
|
||||
return value
|
||||
|
||||
self.cache.set(key, value, entry.get_invalidation_callbacks(key))
|
||||
|
||||
return value
|
||||
|
||||
def _error_callback(
|
||||
self,
|
||||
failure: Failure,
|
||||
entry: "CacheEntry[KT, VT]",
|
||||
key: KT,
|
||||
) -> Failure:
|
||||
"""Called when a deferred errors."""
|
||||
|
||||
# We check if the current entry matches the entry associated with the
|
||||
# deferred. If they don't match then it got invalidated.
|
||||
current_entry = self._pending_deferred_cache.pop(key, None)
|
||||
if current_entry is not entry:
|
||||
if current_entry:
|
||||
self._pending_deferred_cache[key] = current_entry
|
||||
return failure
|
||||
|
||||
for cb in entry.get_invalidation_callbacks(key):
|
||||
cb()
|
||||
|
||||
return failure
|
||||
|
||||
def prefill(
|
||||
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
|
||||
) -> None:
|
||||
callbacks = [callback] if callback else []
|
||||
callbacks = (callback,) if callback else ()
|
||||
self.cache.set(key, value, callbacks=callbacks)
|
||||
self._pending_deferred_cache.pop(key, None)
|
||||
|
||||
def invalidate(self, key: KT) -> None:
|
||||
"""Delete a key, or tree of entries
|
||||
|
@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
|
|||
self.cache.del_multi(key)
|
||||
|
||||
# if we have a pending lookup for this key, remove it from the
|
||||
# _pending_deferred_cache, which will (a) stop it being returned
|
||||
# for future queries and (b) stop it being persisted as a proper entry
|
||||
# _pending_deferred_cache, which will (a) stop it being returned for
|
||||
# future queries and (b) stop it being persisted as a proper entry
|
||||
# in self.cache.
|
||||
entry = self._pending_deferred_cache.pop(key, None)
|
||||
|
||||
# run the invalidation callbacks now, rather than waiting for the
|
||||
# deferred to resolve.
|
||||
if entry:
|
||||
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
|
||||
# case of a TreeCache, a dict of keys to cache entries. Either way calling
|
||||
# iterate_tree_cache_entry on it will do the right thing.
|
||||
for entry in iterate_tree_cache_entry(entry):
|
||||
entry.invalidate()
|
||||
for cb in entry.get_invalidation_callbacks(key):
|
||||
cb()
|
||||
|
||||
def invalidate_all(self) -> None:
|
||||
self.check_thread()
|
||||
self.cache.clear()
|
||||
for entry in self._pending_deferred_cache.values():
|
||||
entry.invalidate()
|
||||
for key, entry in self._pending_deferred_cache.items():
|
||||
for cb in entry.get_invalidation_callbacks(key):
|
||||
cb()
|
||||
|
||||
self._pending_deferred_cache.clear()
|
||||
|
||||
|
||||
class CacheEntry:
|
||||
__slots__ = ["deferred", "callbacks", "invalidated"]
|
||||
class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
|
||||
"""Abstract class for entries in `DeferredCache[KT, VT]`"""
|
||||
|
||||
def __init__(
|
||||
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
|
||||
):
|
||||
self.deferred = deferred
|
||||
self.callbacks = set(callbacks)
|
||||
self.invalidated = False
|
||||
@abc.abstractmethod
|
||||
def deferred(self, key: KT) -> "defer.Deferred[VT]":
|
||||
"""Get a deferred that a caller can wait on to get the value at the
|
||||
given key"""
|
||||
...
|
||||
|
||||
def invalidate(self) -> None:
|
||||
if not self.invalidated:
|
||||
self.invalidated = True
|
||||
for callback in self.callbacks:
|
||||
callback()
|
||||
self.callbacks.clear()
|
||||
@abc.abstractmethod
|
||||
def add_invalidation_callback(
|
||||
self, key: KT, callback: Optional[Callable[[], None]]
|
||||
) -> None:
|
||||
"""Add an invalidation callback"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
|
||||
"""Get all invalidation callbacks"""
|
||||
...
|
||||
|
||||
|
||||
class CacheEntrySingle(CacheEntry[KT, VT]):
|
||||
"""An implementation of `CacheEntry` wrapping a deferred that results in a
|
||||
single cache entry.
|
||||
"""
|
||||
|
||||
__slots__ = ["_deferred", "_callbacks"]
|
||||
|
||||
def __init__(self, deferred: "defer.Deferred[VT]") -> None:
|
||||
self._deferred = ObservableDeferred(deferred, consumeErrors=True)
|
||||
self._callbacks: Set[Callable[[], None]] = set()
|
||||
|
||||
def deferred(self, key: KT) -> "defer.Deferred[VT]":
|
||||
return self._deferred.observe()
|
||||
|
||||
def add_invalidation_callback(
|
||||
self, key: KT, callback: Optional[Callable[[], None]]
|
||||
) -> None:
|
||||
if callback is None:
|
||||
return
|
||||
|
||||
self._callbacks.add(callback)
|
||||
|
||||
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
|
||||
return self._callbacks
|
||||
|
||||
|
||||
class CacheMultipleEntries(CacheEntry[KT, VT]):
|
||||
"""Cache entry that is used for bulk lookups and insertions."""
|
||||
|
||||
__slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
|
||||
self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
|
||||
self._global_callbacks: Set[Callable[[], None]] = set()
|
||||
|
||||
def deferred(self, key: KT) -> "defer.Deferred[VT]":
|
||||
if not self._deferred:
|
||||
self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
||||
return self._deferred.observe().addCallback(lambda res: res.get(key))
|
||||
|
||||
def add_invalidation_callback(
|
||||
self, key: KT, callback: Optional[Callable[[], None]]
|
||||
) -> None:
|
||||
if callback is None:
|
||||
return
|
||||
|
||||
self._callbacks.setdefault(key, set()).add(callback)
|
||||
|
||||
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
|
||||
return self._callbacks.get(key, set()) | self._global_callbacks
|
||||
|
||||
def add_global_invalidation_callback(
|
||||
self, callback: Optional[Callable[[], None]]
|
||||
) -> None:
|
||||
"""Add a callback for when any keys get invalidated."""
|
||||
if callback is None:
|
||||
return
|
||||
|
||||
self._global_callbacks.add(callback)
|
||||
|
||||
def complete_bulk(
|
||||
self,
|
||||
cache: DeferredCache[KT, VT],
|
||||
result: Dict[KT, VT],
|
||||
) -> None:
|
||||
"""Called when there is a result"""
|
||||
for key, value in result.items():
|
||||
cache._completed_callback(value, self, key)
|
||||
|
||||
if self._deferred:
|
||||
self._deferred.callback(result)
|
||||
|
||||
def error_bulk(
|
||||
self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
|
||||
) -> None:
|
||||
"""Called when bulk lookup failed."""
|
||||
for key in keys:
|
||||
cache._error_callback(failure, self, key)
|
||||
|
||||
if self._deferred:
|
||||
self._deferred.errback(failure)
|
||||
|
|
|
@ -25,6 +25,7 @@ from typing import (
|
|||
Generic,
|
||||
Hashable,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
|
@ -440,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||
list_args = arg_dict[self.list_name]
|
||||
|
||||
results = {}
|
||||
|
||||
def update_results_dict(res: Any, arg: Hashable) -> None:
|
||||
results[arg] = res
|
||||
|
||||
# list of deferreds to wait for
|
||||
cached_defers = []
|
||||
|
||||
missing = set()
|
||||
|
||||
# If the cache takes a single arg then that is used as the key,
|
||||
# otherwise a tuple is used.
|
||||
if num_args == 1:
|
||||
|
@ -457,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
def arg_to_cache_key(arg: Hashable) -> Hashable:
|
||||
return arg
|
||||
|
||||
def cache_key_to_arg(key: tuple) -> Hashable:
|
||||
return key
|
||||
|
||||
else:
|
||||
keylist = list(keyargs)
|
||||
|
||||
|
@ -464,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
keylist[self.list_pos] = arg
|
||||
return tuple(keylist)
|
||||
|
||||
for arg in list_args:
|
||||
try:
|
||||
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
|
||||
if not res.called:
|
||||
res.addCallback(update_results_dict, arg)
|
||||
cached_defers.append(res)
|
||||
else:
|
||||
results[arg] = res.result
|
||||
except KeyError:
|
||||
missing.add(arg)
|
||||
def cache_key_to_arg(key: tuple) -> Hashable:
|
||||
return key[self.list_pos]
|
||||
|
||||
if missing:
|
||||
# we need a deferred for each entry in the list,
|
||||
# which we put in the cache. Each deferred resolves with the
|
||||
# relevant result for that key.
|
||||
deferreds_map = {}
|
||||
for arg in missing:
|
||||
deferred: "defer.Deferred[Any]" = defer.Deferred()
|
||||
deferreds_map[arg] = deferred
|
||||
key = arg_to_cache_key(arg)
|
||||
cached_defers.append(
|
||||
cache.set(key, deferred, callback=invalidate_callback)
|
||||
cache_keys = [arg_to_cache_key(arg) for arg in list_args]
|
||||
immediate_results, pending_deferred, missing = cache.get_bulk(
|
||||
cache_keys, callback=invalidate_callback
|
||||
)
|
||||
|
||||
results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
|
||||
|
||||
cached_defers: List["defer.Deferred[Any]"] = []
|
||||
if pending_deferred:
|
||||
|
||||
def update_results(r: Dict) -> None:
|
||||
for k, v in r.items():
|
||||
results[cache_key_to_arg(k)] = v
|
||||
|
||||
pending_deferred.addCallback(update_results)
|
||||
cached_defers.append(pending_deferred)
|
||||
|
||||
if missing:
|
||||
cache_entry = cache.start_bulk_input(missing, invalidate_callback)
|
||||
|
||||
def complete_all(res: Dict[Hashable, Any]) -> None:
|
||||
# the wrapped function has completed. It returns a dict.
|
||||
# We can now update our own result map, and then resolve the
|
||||
# observable deferreds in the cache.
|
||||
for e, d1 in deferreds_map.items():
|
||||
val = res.get(e, None)
|
||||
# make sure we update the results map before running the
|
||||
# deferreds, because as soon as we run the last deferred, the
|
||||
# gatherResults() below will complete and return the result
|
||||
# dict to our caller.
|
||||
results[e] = val
|
||||
d1.callback(val)
|
||||
missing_results = {}
|
||||
for key in missing:
|
||||
arg = cache_key_to_arg(key)
|
||||
val = res.get(arg, None)
|
||||
|
||||
results[arg] = val
|
||||
missing_results[key] = val
|
||||
|
||||
cache_entry.complete_bulk(cache, missing_results)
|
||||
|
||||
def errback_all(f: Failure) -> None:
|
||||
# the wrapped function has failed. Propagate the failure into
|
||||
# the cache, which will invalidate the entry, and cause the
|
||||
# relevant cached_deferreds to fail, which will propagate the
|
||||
# failure to our caller.
|
||||
for d1 in deferreds_map.values():
|
||||
d1.errback(f)
|
||||
cache_entry.error_bulk(cache, missing, f)
|
||||
|
||||
args_to_call = dict(arg_dict)
|
||||
args_to_call[self.list_name] = missing
|
||||
args_to_call[self.list_name] = {
|
||||
cache_key_to_arg(key) for key in missing
|
||||
}
|
||||
|
||||
# dispatch the call, and attach the two handlers
|
||||
defer.maybeDeferred(
|
||||
missing_d = defer.maybeDeferred(
|
||||
preserve_fn(self.orig), **args_to_call
|
||||
).addCallbacks(complete_all, errback_all)
|
||||
cached_defers.append(missing_d)
|
||||
|
||||
if cached_defers:
|
||||
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
|
||||
|
|
|
@ -135,6 +135,9 @@ class TreeCache:
|
|||
def values(self):
|
||||
return iterate_tree_cache_entry(self.root)
|
||||
|
||||
def items(self):
|
||||
return iterate_tree_cache_items((), self.root)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.size
|
||||
|
||||
|
|
Loading…
Reference in New Issue