mypy plugin to check `@cached` return types (#14911)

Co-authored-by: David Robertson <davidr@element.io>
Co-authored-by: Patrick Cloke <patrickc@matrix.org>
Co-authored-by: Erik Johnston <erik@matrix.org>

Assert that the return type of callables wrapped in @cached
and @cachedList are cachable (aka immutable).
This commit is contained in:
David Robertson 2023-10-02 15:22:36 +01:00 committed by GitHub
parent 5725712d47
commit 1026776380
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 323 additions and 58 deletions

1
changelog.d/14911.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -16,13 +16,24 @@
can crop up, e.g the cache descriptors. can crop up, e.g the cache descriptors.
""" """
from typing import Callable, Optional, Type from typing import Callable, Optional, Tuple, Type, Union
import mypy.types
from mypy.erasetype import remove_instance_last_known_values from mypy.erasetype import remove_instance_last_known_values
from mypy.nodes import ARG_NAMED_OPT from mypy.errorcodes import ErrorCode
from mypy.plugin import MethodSigContext, Plugin from mypy.nodes import ARG_NAMED_OPT, TempNode, Var
from mypy.plugin import FunctionSigContext, MethodSigContext, Plugin
from mypy.typeops import bind_self from mypy.typeops import bind_self
from mypy.types import CallableType, Instance, NoneType, UnionType from mypy.types import (
AnyType,
CallableType,
Instance,
NoneType,
TupleType,
TypeAliasType,
UninhabitedType,
UnionType,
)
class SynapsePlugin(Plugin): class SynapsePlugin(Plugin):
@ -36,9 +47,37 @@ class SynapsePlugin(Plugin):
) )
): ):
return cached_function_method_signature return cached_function_method_signature
if fullname in (
"synapse.util.caches.descriptors._CachedFunctionDescriptor.__call__",
"synapse.util.caches.descriptors._CachedListFunctionDescriptor.__call__",
):
return check_is_cacheable_wrapper
return None return None
def _get_true_return_type(signature: CallableType) -> mypy.types.Type:
"""
Get the "final" return type of a callable which might return an Awaitable/Deferred.
"""
if isinstance(signature.ret_type, Instance):
# If a coroutine, unwrap the coroutine's return type.
if signature.ret_type.type.fullname == "typing.Coroutine":
return signature.ret_type.args[2]
# If an awaitable, unwrap the awaitable's final value.
elif signature.ret_type.type.fullname == "typing.Awaitable":
return signature.ret_type.args[0]
# If a Deferred, unwrap the Deferred's final value.
elif signature.ret_type.type.fullname == "twisted.internet.defer.Deferred":
return signature.ret_type.args[0]
# Otherwise, return the raw value of the function.
return signature.ret_type
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
"""Fixes the `CachedFunction.__call__` signature to be correct. """Fixes the `CachedFunction.__call__` signature to be correct.
@ -47,16 +86,17 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
1. the `self` argument needs to be marked as "bound"; 1. the `self` argument needs to be marked as "bound";
2. any `cache_context` argument should be removed; 2. any `cache_context` argument should be removed;
3. an optional keyword argument `on_invalidated` should be added. 3. an optional keyword argument `on_invalidated` should be added.
4. Wrap the return type to always be a Deferred.
""" """
# First we mark this as a bound function signature. # 1. Mark this as a bound function signature.
signature = bind_self(ctx.default_signature) signature: CallableType = bind_self(ctx.default_signature)
# Secondly, we remove any "cache_context" args. # 2. Remove any "cache_context" args.
# #
# Note: We should be only doing this if `cache_context=True` is set, but if # Note: We should be only doing this if `cache_context=True` is set, but if
# it isn't then the code will raise an exception when its called anyway, so # it isn't then the code will raise an exception when its called anyway, so
# its not the end of the world. # it's not the end of the world.
context_arg_index = None context_arg_index = None
for idx, name in enumerate(signature.arg_names): for idx, name in enumerate(signature.arg_names):
if name == "cache_context": if name == "cache_context":
@ -72,7 +112,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_names.pop(context_arg_index) arg_names.pop(context_arg_index)
arg_kinds.pop(context_arg_index) arg_kinds.pop(context_arg_index)
# Third, we add an optional "on_invalidate" argument. # 3. Add an optional "on_invalidate" argument.
# #
# This is a either # This is a either
# - a callable which accepts no input and returns nothing, or # - a callable which accepts no input and returns nothing, or
@ -94,35 +134,16 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_names.append("on_invalidate") arg_names.append("on_invalidate")
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
# Finally we ensure the return type is a Deferred. # 4. Ensure the return type is a Deferred.
if ( ret_arg = _get_true_return_type(signature)
isinstance(signature.ret_type, Instance)
and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
):
# If it is already a Deferred, nothing to do.
ret_type = signature.ret_type
else:
ret_arg = None
if isinstance(signature.ret_type, Instance):
# If a coroutine, wrap the coroutine's return type in a Deferred.
if signature.ret_type.type.fullname == "typing.Coroutine":
ret_arg = signature.ret_type.args[2]
# If an awaitable, wrap the awaitable's final value in a Deferred. # This should be able to use ctx.api.named_generic_type, but that doesn't seem
elif signature.ret_type.type.fullname == "typing.Awaitable": # to find the correct symbol for anything more than 1 module deep.
ret_arg = signature.ret_type.args[0] #
# modules is not part of CheckerPluginInterface. The following is a combination
# Otherwise, wrap the return value in a Deferred. # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
if ret_arg is None: sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined]
ret_arg = signature.ret_type ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
# This should be able to use ctx.api.named_generic_type, but that doesn't seem
# to find the correct symbol for anything more than 1 module deep.
#
# modules is not part of CheckerPluginInterface. The following is a combination
# of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined]
ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
signature = signature.copy_modified( signature = signature.copy_modified(
arg_types=arg_types, arg_types=arg_types,
@ -134,6 +155,198 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
return signature return signature
def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType:
"""Asserts that the signature of a method returns a value which can be cached.
Makes no changes to the provided method signature.
"""
# The true signature, this isn't being modified so this is what will be returned.
signature: CallableType = ctx.default_signature
if not isinstance(ctx.args[0][0], TempNode):
ctx.api.note("Cached function is not a TempNode?!", ctx.context) # type: ignore[attr-defined]
return signature
orig_sig = ctx.args[0][0].type
if not isinstance(orig_sig, CallableType):
ctx.api.fail("Cached 'function' is not a callable", ctx.context)
return signature
check_is_cacheable(orig_sig, ctx)
return signature
def check_is_cacheable(
signature: CallableType,
ctx: Union[MethodSigContext, FunctionSigContext],
) -> None:
"""
Check if a callable returns a type which can be cached.
Args:
signature: The callable to check.
ctx: The signature context, used for error reporting.
"""
# Unwrap the true return type from the cached function.
return_type = _get_true_return_type(signature)
verbose = ctx.api.options.verbosity >= 1
# TODO Technically a cachedList only needs immutable values, but forcing them
# to return Mapping instead of Dict is fine.
ok, note = is_cacheable(return_type, signature, verbose)
if ok:
message = f"function {signature.name} is @cached, returning {return_type}"
else:
message = f"function {signature.name} is @cached, but has mutable return value {return_type}"
if note:
message += f" ({note})"
message = message.replace("builtins.", "").replace("typing.", "")
if ok and note:
ctx.api.note(message, ctx.context) # type: ignore[attr-defined]
elif not ok:
ctx.api.fail(message, ctx.context, code=AT_CACHED_MUTABLE_RETURN)
# Immutable simple values.
IMMUTABLE_VALUE_TYPES = {
"builtins.bool",
"builtins.int",
"builtins.float",
"builtins.str",
"builtins.bytes",
}
# Types defined in Synapse which are known to be immutable.
IMMUTABLE_CUSTOM_TYPES = {
"synapse.synapse_rust.acl.ServerAclEvaluator",
"synapse.synapse_rust.push.FilteredPushRules",
# This is technically not immutable, but close enough.
"signedjson.types.VerifyKey",
}
# Immutable containers only if the values are also immutable.
IMMUTABLE_CONTAINER_TYPES_REQUIRING_IMMUTABLE_ELEMENTS = {
"builtins.frozenset",
"builtins.tuple",
"typing.AbstractSet",
"typing.Sequence",
"immutabledict.immutabledict",
}
MUTABLE_CONTAINER_TYPES = {
"builtins.set",
"builtins.list",
"builtins.dict",
}
AT_CACHED_MUTABLE_RETURN = ErrorCode(
"synapse-@cached-mutable",
"@cached() should have an immutable return type",
"General",
)
def is_cacheable(
rt: mypy.types.Type, signature: CallableType, verbose: bool
) -> Tuple[bool, Optional[str]]:
"""
Check if a particular type is cachable.
A type is cachable if it is immutable; for complex types this recurses to
check each type parameter.
Returns: a 2-tuple (cacheable, message).
- cachable: False means the type is definitely not cacheable;
true means anything else.
- Optional message.
"""
# This should probably be done via a TypeVisitor. Apologies to the reader!
if isinstance(rt, AnyType):
return True, ("may be mutable" if verbose else None)
elif isinstance(rt, Instance):
if (
rt.type.fullname in IMMUTABLE_VALUE_TYPES
or rt.type.fullname in IMMUTABLE_CUSTOM_TYPES
):
# "Simple" types are generally immutable.
return True, None
elif rt.type.fullname == "typing.Mapping":
# Generally mapping keys are immutable, but they only *have* to be
# hashable, which doesn't imply immutability. E.g. Mapping[K, V]
# is cachable iff K and V are cachable.
return is_cacheable(rt.args[0], signature, verbose) and is_cacheable(
rt.args[1], signature, verbose
)
elif rt.type.fullname in IMMUTABLE_CONTAINER_TYPES_REQUIRING_IMMUTABLE_ELEMENTS:
# E.g. Collection[T] is cachable iff T is cachable.
return is_cacheable(rt.args[0], signature, verbose)
elif rt.type.fullname in MUTABLE_CONTAINER_TYPES:
# Mutable containers are mutable regardless of their underlying type.
return False, None
elif "attrs" in rt.type.metadata:
# attrs classes are only cachable iff it is frozen (immutable itself)
# and all attributes are cachable.
frozen = rt.type.metadata["attrs"]["frozen"]
if frozen:
for attribute in rt.type.metadata["attrs"]["attributes"]:
attribute_name = attribute["name"]
symbol_node = rt.type.names[attribute_name].node
assert isinstance(symbol_node, Var)
assert symbol_node.type is not None
ok, note = is_cacheable(symbol_node.type, signature, verbose)
if not ok:
return False, f"non-frozen attrs property: {attribute_name}"
# All attributes were frozen.
return True, None
else:
return False, "non-frozen attrs class"
else:
# Ensure we fail for unknown types, these generally means that the
# above code is not complete.
return (
False,
f"Don't know how to handle {rt.type.fullname} return type instance",
)
elif isinstance(rt, NoneType):
# None is cachable.
return True, None
elif isinstance(rt, (TupleType, UnionType)):
# Tuples and unions are cachable iff all their items are cachable.
for item in rt.items:
ok, note = is_cacheable(item, signature, verbose)
if not ok:
return False, note
# This discards notes but that's probably fine
return True, None
elif isinstance(rt, TypeAliasType):
# For a type alias, check if the underlying real type is cachable.
return is_cacheable(mypy.types.get_proper_type(rt), signature, verbose)
elif isinstance(rt, UninhabitedType) and rt.is_noreturn:
# There is no return value, just consider it cachable. This is only used
# in tests.
return True, None
else:
# Ensure we fail for unknown types, these generally means that the
# above code is not complete.
return False, f"Don't know how to handle {type(rt).__qualname__} return type"
def plugin(version: str) -> Type[SynapsePlugin]: def plugin(version: str) -> Type[SynapsePlugin]:
# This is the entry point of the plugin, and lets us deal with the fact # This is the entry point of the plugin, and lets us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version # that the mypy plugin interface is *not* stable by looking at the version

View File

@ -33,7 +33,7 @@ from synapse.api.errors import (
RequestSendFailed, RequestSendFailed,
SynapseError, SynapseError,
) )
from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -256,7 +256,7 @@ class RoomListHandler:
cache_context: _CacheContext, cache_context: _CacheContext,
with_alias: bool = True, with_alias: bool = True,
allow_private: bool = False, allow_private: bool = False,
) -> Optional[JsonDict]: ) -> Optional[JsonMapping]:
"""Returns the entry for a room """Returns the entry for a room
Args: Args:

View File

@ -182,6 +182,7 @@ class UserPushAction(EmailPushAction):
profile_tag: str profile_tag: str
# TODO This is used as a cached value and is mutable.
@attr.s(slots=True, auto_attribs=True) @attr.s(slots=True, auto_attribs=True)
class NotifCounts: class NotifCounts:
""" """
@ -193,7 +194,7 @@ class NotifCounts:
highlight_count: int = 0 highlight_count: int = 0
@attr.s(slots=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomNotifCounts: class RoomNotifCounts:
""" """
The per-user, per-room count of notifications. Used by sync and push. The per-user, per-room count of notifications. Used by sync and push.
@ -201,7 +202,7 @@ class RoomNotifCounts:
main_timeline: NotifCounts main_timeline: NotifCounts
# Map of thread ID to the notification counts. # Map of thread ID to the notification counts.
threads: Dict[str, NotifCounts] threads: Mapping[str, NotifCounts]
@staticmethod @staticmethod
def empty() -> "RoomNotifCounts": def empty() -> "RoomNotifCounts":
@ -483,7 +484,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return room_to_count return room_to_count
@cached(tree=True, max_entries=5000, iterable=True) @cached(tree=True, max_entries=5000, iterable=True) # type: ignore[synapse-@cached-mutable]
async def get_unread_event_push_actions_by_room_for_user( async def get_unread_event_push_actions_by_room_for_user(
self, self,
room_id: str, room_id: str,

View File

@ -458,7 +458,7 @@ class RelationsWorkerStore(SQLBaseStore):
) )
return result is not None return result is not None
@cached() @cached() # type: ignore[synapse-@cached-mutable]
async def get_references_for_event(self, event_id: str) -> List[JsonDict]: async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
raise NotImplementedError() raise NotImplementedError()
@ -512,11 +512,12 @@ class RelationsWorkerStore(SQLBaseStore):
"_get_references_for_events_txn", _get_references_for_events_txn "_get_references_for_events_txn", _get_references_for_events_txn
) )
@cached() @cached() # type: ignore[synapse-@cached-mutable]
def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") # TODO: This returns a mutable object, which is generally bad.
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") # type: ignore[synapse-@cached-mutable]
async def get_applicable_edits( async def get_applicable_edits(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Mapping[str, Optional[EventBase]]: ) -> Mapping[str, Optional[EventBase]]:
@ -598,11 +599,12 @@ class RelationsWorkerStore(SQLBaseStore):
for original_event_id in event_ids for original_event_id in event_ids
} }
@cached() @cached() # type: ignore[synapse-@cached-mutable]
def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids") # TODO: This returns a mutable object, which is generally bad.
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids") # type: ignore[synapse-@cached-mutable]
async def get_thread_summaries( async def get_thread_summaries(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Mapping[str, Optional[Tuple[int, EventBase]]]: ) -> Mapping[str, Optional[Tuple[int, EventBase]]]:

View File

@ -275,7 +275,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
_get_users_in_room_with_profiles, _get_users_in_room_with_profiles,
) )
@cached(max_entries=100000) @cached(max_entries=100000) # type: ignore[synapse-@cached-mutable]
async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]: async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room """Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members. summary extension to /sync. Useful when lazy loading room members.
@ -1071,7 +1071,8 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
) )
return {row["event_id"]: row["membership"] for row in rows} return {row["event_id"]: row["membership"] for row in rows}
@cached(max_entries=10000) # TODO This returns a mutable object, which is generally confusing when using a cache.
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache() return _JoinedHostsCache()

View File

@ -45,6 +45,7 @@ class ProfileInfo:
display_name: Optional[str] display_name: Optional[str]
# TODO This is used as a cached value and is mutable.
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) @attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class MemberSummary: class MemberSummary:
# A truncated list of (user_id, event_id) tuples for users of a given # A truncated list of (user_id, event_id) tuples for users of a given

View File

@ -36,6 +36,8 @@ from typing import (
) )
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
import attr
from twisted.internet import defer from twisted.internet import defer
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -466,6 +468,35 @@ class _CacheContext:
) )
@attr.s(auto_attribs=True, slots=True, frozen=True)
class _CachedFunctionDescriptor:
"""Helper for `@cached`, we name it so that we can hook into it with mypy
plugin."""
max_entries: int
num_args: Optional[int]
uncached_args: Optional[Collection[str]]
tree: bool
cache_context: bool
iterable: bool
prune_unread_entries: bool
name: Optional[str]
def __call__(self, orig: F) -> CachedFunction[F]:
d = DeferredCacheDescriptor(
orig,
max_entries=self.max_entries,
num_args=self.num_args,
uncached_args=self.uncached_args,
tree=self.tree,
cache_context=self.cache_context,
iterable=self.iterable,
prune_unread_entries=self.prune_unread_entries,
name=self.name,
)
return cast(CachedFunction[F], d)
def cached( def cached(
*, *,
max_entries: int = 1000, max_entries: int = 1000,
@ -476,9 +507,8 @@ def cached(
iterable: bool = False, iterable: bool = False,
prune_unread_entries: bool = True, prune_unread_entries: bool = True,
name: Optional[str] = None, name: Optional[str] = None,
) -> Callable[[F], CachedFunction[F]]: ) -> _CachedFunctionDescriptor:
func = lambda orig: DeferredCacheDescriptor( return _CachedFunctionDescriptor(
orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
uncached_args=uncached_args, uncached_args=uncached_args,
@ -489,7 +519,26 @@ def cached(
name=name, name=name,
) )
return cast(Callable[[F], CachedFunction[F]], func)
@attr.s(auto_attribs=True, slots=True, frozen=True)
class _CachedListFunctionDescriptor:
"""Helper for `@cachedList`, we name it so that we can hook into it with mypy
plugin."""
cached_method_name: str
list_name: str
num_args: Optional[int] = None
name: Optional[str] = None
def __call__(self, orig: F) -> CachedFunction[F]:
d = DeferredCacheListDescriptor(
orig,
cached_method_name=self.cached_method_name,
list_name=self.list_name,
num_args=self.num_args,
name=self.name,
)
return cast(CachedFunction[F], d)
def cachedList( def cachedList(
@ -498,7 +547,7 @@ def cachedList(
list_name: str, list_name: str,
num_args: Optional[int] = None, num_args: Optional[int] = None,
name: Optional[str] = None, name: Optional[str] = None,
) -> Callable[[F], CachedFunction[F]]: ) -> _CachedListFunctionDescriptor:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments Used to do batch lookups for an already created cache. One of the arguments
@ -527,16 +576,13 @@ def cachedList(
def batch_do_something(self, first_arg, second_args): def batch_do_something(self, first_arg, second_args):
... ...
""" """
func = lambda orig: DeferredCacheListDescriptor( return _CachedListFunctionDescriptor(
orig,
cached_method_name=cached_method_name, cached_method_name=cached_method_name,
list_name=list_name, list_name=list_name,
num_args=num_args, num_args=num_args,
name=name, name=name,
) )
return cast(Callable[[F], CachedFunction[F]], func)
def _get_cache_key_builder( def _get_cache_key_builder(
param_names: Sequence[str], param_names: Sequence[str],