Improve type hints for attrs classes (#16276)

This commit is contained in:
David Robertson 2023-09-08 19:29:38 +01:00 committed by GitHub
parent a0ed55ef12
commit edd83f23b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 37 additions and 39 deletions

1
changelog.d/16276.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -30,7 +30,7 @@ class OEmbedEndpointConfig:
# The API endpoint to fetch. # The API endpoint to fetch.
api_endpoint: str api_endpoint: str
# The patterns to match. # The patterns to match.
url_patterns: List[Pattern] url_patterns: List[Pattern[str]]
# The supported formats. # The supported formats.
formats: Optional[List[str]] formats: Optional[List[str]]

View File

@ -154,12 +154,13 @@ class _UpdateCurrentStateTask:
_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask] _EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
_PersistResult = TypeVar("_PersistResult")
@attr.s(auto_attribs=True, slots=True) @attr.s(auto_attribs=True, slots=True)
class _EventPersistQueueItem: class _EventPersistQueueItem(Generic[_PersistResult]):
task: _EventPersistQueueTask task: _EventPersistQueueTask
deferred: ObservableDeferred deferred: ObservableDeferred[_PersistResult]
parent_opentracing_span_contexts: List = attr.ib(factory=list) parent_opentracing_span_contexts: List = attr.ib(factory=list)
"""A list of opentracing spans waiting for this batch""" """A list of opentracing spans waiting for this batch"""
@ -168,9 +169,6 @@ class _EventPersistQueueItem:
"""The opentracing span under which the persistence actually happened""" """The opentracing span under which the persistence actually happened"""
_PersistResult = TypeVar("_PersistResult")
class _EventPeristenceQueue(Generic[_PersistResult]): class _EventPeristenceQueue(Generic[_PersistResult]):
"""Queues up tasks so that they can be processed with only one concurrent """Queues up tasks so that they can be processed with only one concurrent
transaction per room. transaction per room.

View File

@ -19,6 +19,7 @@ import collections
import inspect import inspect
import itertools import itertools
import logging import logging
import typing
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import ( from typing import (
Any, Any,
@ -29,6 +30,7 @@ from typing import (
Collection, Collection,
Coroutine, Coroutine,
Dict, Dict,
Generator,
Generic, Generic,
Hashable, Hashable,
Iterable, Iterable,
@ -398,7 +400,7 @@ class _LinearizerEntry:
# The number of things executing. # The number of things executing.
count: int count: int
# Deferreds for the things blocked from executing. # Deferreds for the things blocked from executing.
deferreds: collections.OrderedDict deferreds: typing.OrderedDict["defer.Deferred[None]", Literal[1]]
class Linearizer: class Linearizer:
@ -717,30 +719,25 @@ def timeout_deferred(
return new_d return new_d
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class DoneAwaitable: # should be: Generic[R] class DoneAwaitable(Awaitable[R]):
"""Simple awaitable that returns the provided value.""" """Simple awaitable that returns the provided value."""
value: Any # should be: R value: R
def __await__(self) -> Any: def __await__(self) -> Generator[Any, None, R]:
return self yield None
return self.value
def __iter__(self) -> "DoneAwaitable":
return self
def __next__(self) -> None:
raise StopIteration(self.value)
def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable.""" """Convert a value to an awaitable if not already an awaitable."""
if inspect.isawaitable(value): if inspect.isawaitable(value):
assert isinstance(value, Awaitable)
return value return value
# For some reason mypy doesn't deduce that value is not Awaitable here, even though
# inspect.isawaitable returns a TypeGuard.
assert not isinstance(value, Awaitable)
return DoneAwaitable(value) return DoneAwaitable(value)

View File

@ -14,7 +14,7 @@
import enum import enum
import logging import logging
import threading import threading
from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
import attr import attr
from typing_extensions import Literal from typing_extensions import Literal
@ -33,10 +33,8 @@ DKT = TypeVar("DKT")
DV = TypeVar("DV") DV = TypeVar("DV")
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class DictionaryEntry: # should be: Generic[DKT, DV]. class DictionaryEntry(Generic[DKT, DV]):
"""Returned when getting an entry from the cache """Returned when getting an entry from the cache
If `full` is true then `known_absent` will be the empty set. If `full` is true then `known_absent` will be the empty set.
@ -50,8 +48,8 @@ class DictionaryEntry: # should be: Generic[DKT, DV].
""" """
full: bool full: bool
known_absent: Set[Any] # should be: Set[DKT] known_absent: Set[DKT]
value: Dict[Any, Any] # should be: Dict[DKT, DV] value: Dict[DKT, DV]
def __len__(self) -> int: def __len__(self) -> int:
return len(self.value) return len(self.value)

View File

@ -14,7 +14,7 @@
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Generic, Optional, TypeVar, Union, overload from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload
import attr import attr
from typing_extensions import Literal from typing_extensions import Literal
@ -73,7 +73,7 @@ class ExpiringCache(Generic[KT, VT]):
self._expiry_ms = expiry_ms self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get self._reset_expiry_on_get = reset_expiry_on_get
self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict() self._cache: OrderedDict[KT, _CacheEntry[VT]] = OrderedDict()
self.iterable = iterable self.iterable = iterable
@ -100,7 +100,10 @@ class ExpiringCache(Generic[KT, VT]):
while self._max_size and len(self) > self._max_size: while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False) _key, value = self._cache.popitem(last=False)
if self.iterable: if self.iterable:
self.metrics.inc_evictions(EvictionReason.size, len(value.value)) # type-ignore, here and below: if self.iterable is true, then the value
# type VT should be Sized (i.e. have a __len__ method). We don't enforce
# this via the type system at present.
self.metrics.inc_evictions(EvictionReason.size, len(value.value)) # type: ignore[arg-type]
else: else:
self.metrics.inc_evictions(EvictionReason.size) self.metrics.inc_evictions(EvictionReason.size)
@ -134,7 +137,7 @@ class ExpiringCache(Generic[KT, VT]):
return default return default
if self.iterable: if self.iterable:
self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) # type: ignore[arg-type]
else: else:
self.metrics.inc_evictions(EvictionReason.invalidation) self.metrics.inc_evictions(EvictionReason.invalidation)
@ -182,7 +185,7 @@ class ExpiringCache(Generic[KT, VT]):
for k in keys_to_delete: for k in keys_to_delete:
value = self._cache.pop(k) value = self._cache.pop(k)
if self.iterable: if self.iterable:
self.metrics.inc_evictions(EvictionReason.time, len(value.value)) self.metrics.inc_evictions(EvictionReason.time, len(value.value)) # type: ignore[arg-type]
else: else:
self.metrics.inc_evictions(EvictionReason.time) self.metrics.inc_evictions(EvictionReason.time)
@ -195,7 +198,8 @@ class ExpiringCache(Generic[KT, VT]):
def __len__(self) -> int: def __len__(self) -> int:
if self.iterable: if self.iterable:
return sum(len(entry.value) for entry in self._cache.values()) g: Iterable[int] = (len(entry.value) for entry in self._cache.values()) # type: ignore[arg-type]
return sum(g)
else: else:
return len(self._cache) return len(self._cache)
@ -218,6 +222,6 @@ class ExpiringCache(Generic[KT, VT]):
@attr.s(slots=True, auto_attribs=True) @attr.s(slots=True, auto_attribs=True)
class _CacheEntry: class _CacheEntry(Generic[VT]):
time: int time: int
value: Any value: VT

View File

@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry # map from key to _CacheEntry
self._data: Dict[KT, _CacheEntry] = {} self._data: Dict[KT, _CacheEntry[KT, VT]] = {}
# the _CacheEntries, sorted by expiry time # the _CacheEntries, sorted by expiry time
self._expiry_list: SortedList[_CacheEntry] = SortedList() self._expiry_list: SortedList[_CacheEntry[KT, VT]] = SortedList()
self._timer = timer self._timer = timer
@ -160,11 +160,11 @@ class TTLCache(Generic[KT, VT]):
@attr.s(frozen=True, slots=True, auto_attribs=True) @attr.s(frozen=True, slots=True, auto_attribs=True)
class _CacheEntry: # Should be Generic[KT, VT]. See python-attrs/attrs#313 class _CacheEntry(Generic[KT, VT]):
"""TTLCache entry""" """TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry. # expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time: float expiry_time: float
ttl: float ttl: float
key: Any # should be KT key: KT
value: Any # should be VT value: VT