Convert more cached return values to immutable types (#16356)
This commit is contained in:
parent
d7c89c5908
commit
7ec0a141b4
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints.
|
|
@ -37,7 +37,7 @@ from synapse.api.constants import EduTypes, EventContentFields
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
from synapse.events import EventBase, relation_from_event
|
from synapse.events import EventBase, relation_from_event
|
||||||
from synapse.types import JsonDict, RoomID, UserID
|
from synapse.types import JsonDict, JsonMapping, RoomID, UserID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -191,7 +191,7 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
|
||||||
|
|
||||||
|
|
||||||
class FilterCollection:
|
class FilterCollection:
|
||||||
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
|
def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
|
||||||
self._filter_json = filter_json
|
self._filter_json = filter_json
|
||||||
|
|
||||||
room_filter_json = self._filter_json.get("room", {})
|
room_filter_json = self._filter_json.get("room", {})
|
||||||
|
@ -219,7 +219,7 @@ class FilterCollection:
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
||||||
|
|
||||||
def get_filter_json(self) -> JsonDict:
|
def get_filter_json(self) -> JsonMapping:
|
||||||
return self._filter_json
|
return self._filter_json
|
||||||
|
|
||||||
def timeline_limit(self) -> int:
|
def timeline_limit(self) -> int:
|
||||||
|
@ -313,7 +313,7 @@ class FilterCollection:
|
||||||
|
|
||||||
|
|
||||||
class Filter:
|
class Filter:
|
||||||
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
|
def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
|
||||||
self._hs = hs
|
self._hs = hs
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
self.filter_json = filter_json
|
self.filter_json = filter_json
|
||||||
|
|
|
@ -64,7 +64,7 @@ from synapse.federation.transport.client import SendJoinResponse
|
||||||
from synapse.http.client import is_unknown_endpoint
|
from synapse.http.client import is_unknown_endpoint
|
||||||
from synapse.http.types import QueryParams
|
from synapse.http.types import QueryParams
|
||||||
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
|
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
|
||||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id
|
||||||
from synapse.util.async_helpers import concurrently_execute
|
from synapse.util.async_helpers import concurrently_execute
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
@ -1704,7 +1704,7 @@ class FederationClient(FederationBase):
|
||||||
async def timestamp_to_event(
|
async def timestamp_to_event(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
destinations: List[str],
|
destinations: StrCollection,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
timestamp: int,
|
timestamp: int,
|
||||||
direction: Direction,
|
direction: Direction,
|
||||||
|
|
|
@ -1538,7 +1538,7 @@ class FederationEventHandler:
|
||||||
logger.exception("Failed to resync device for %s", sender)
|
logger.exception("Failed to resync device for %s", sender)
|
||||||
|
|
||||||
async def backfill_event_id(
|
async def backfill_event_id(
|
||||||
self, destinations: List[str], room_id: str, event_id: str
|
self, destinations: StrCollection, room_id: str, event_id: str
|
||||||
) -> PulledPduInfo:
|
) -> PulledPduInfo:
|
||||||
"""Backfill a single event and persist it as a non-outlier which means
|
"""Backfill a single event and persist it as a non-outlier which means
|
||||||
we also pull in all of the state and auth events necessary for it.
|
we also pull in all of the state and auth events necessary for it.
|
||||||
|
|
|
@ -13,7 +13,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
FrozenSet,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -245,7 +255,7 @@ class RelationsHandler:
|
||||||
|
|
||||||
async def get_references_for_events(
|
async def get_references_for_events(
|
||||||
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
|
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
|
||||||
) -> Dict[str, List[_RelatedEvent]]:
|
) -> Mapping[str, Sequence[_RelatedEvent]]:
|
||||||
"""Get a list of references to the given events.
|
"""Get a list of references to the given events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseErro
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, JsonMapping, UserID
|
||||||
|
|
||||||
from ._base import client_patterns, set_timeline_upper_limit
|
from ._base import client_patterns, set_timeline_upper_limit
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ class GetFilterRestServlet(RestServlet):
|
||||||
|
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
self, request: SynapseRequest, user_id: str, filter_id: str
|
self, request: SynapseRequest, user_id: str, filter_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonMapping]:
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
|
|
@ -582,7 +582,7 @@ class StateStorageController:
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@tag_args
|
@tag_args
|
||||||
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
|
async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
|
||||||
"""Get current hosts in room based on current state.
|
"""Get current hosts in room based on current state.
|
||||||
|
|
||||||
Blocks until we have full state for the given room. This only happens for rooms
|
Blocks until we have full state for the given room. This only happens for rooms
|
||||||
|
|
|
@ -25,7 +25,7 @@ from synapse.storage.database import (
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, JsonMapping, UserID
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
async def get_user_filter(
|
async def get_user_filter(
|
||||||
self, user_id: UserID, filter_id: Union[int, str]
|
self, user_id: UserID, filter_id: Union[int, str]
|
||||||
) -> JsonDict:
|
) -> JsonMapping:
|
||||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -465,7 +465,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
@cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
|
@cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
|
||||||
async def get_references_for_events(
|
async def get_references_for_events(
|
||||||
self, event_ids: Collection[str]
|
self, event_ids: Collection[str]
|
||||||
) -> Mapping[str, Optional[List[_RelatedEvent]]]:
|
) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]:
|
||||||
"""Get a list of references to the given events.
|
"""Get a list of references to the given events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -931,7 +931,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
room_id: str,
|
room_id: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
from_token: Optional[ThreadsNextBatch] = None,
|
from_token: Optional[ThreadsNextBatch] = None,
|
||||||
) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
|
) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]:
|
||||||
"""Get a list of thread IDs, ordered by topological ordering of their
|
"""Get a list of thread IDs, ordered by topological ordering of their
|
||||||
latest reply.
|
latest reply.
|
||||||
|
|
||||||
|
|
|
@ -984,7 +984,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(iterable=True, max_entries=10000)
|
@cached(iterable=True, max_entries=10000)
|
||||||
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
|
async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
|
||||||
"""
|
"""
|
||||||
Get current hosts in room based on current state.
|
Get current hosts in room based on current state.
|
||||||
|
|
||||||
|
@ -1013,12 +1013,14 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
# `get_users_in_room` rather than funky SQL.
|
# `get_users_in_room` rather than funky SQL.
|
||||||
|
|
||||||
domains = await self.get_current_hosts_in_room(room_id)
|
domains = await self.get_current_hosts_in_room(room_id)
|
||||||
return list(domains)
|
return tuple(domains)
|
||||||
|
|
||||||
# For PostgreSQL we can use a regex to pull out the domains from the
|
# For PostgreSQL we can use a regex to pull out the domains from the
|
||||||
# joined users in `current_state_events` via regex.
|
# joined users in `current_state_events` via regex.
|
||||||
|
|
||||||
def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
|
def get_current_hosts_in_room_ordered_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[str, ...]:
|
||||||
# Returns a list of servers currently joined in the room sorted by
|
# Returns a list of servers currently joined in the room sorted by
|
||||||
# longest in the room first (aka. with the lowest depth). The
|
# longest in the room first (aka. with the lowest depth). The
|
||||||
# heuristic of sorting by servers who have been in the room the
|
# heuristic of sorting by servers who have been in the room the
|
||||||
|
@ -1043,7 +1045,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (room_id,))
|
txn.execute(sql, (room_id,))
|
||||||
# `server_domain` will be `NULL` for malformed MXIDs with no colons.
|
# `server_domain` will be `NULL` for malformed MXIDs with no colons.
|
||||||
return [d for d, in txn if d is not None]
|
return tuple(d for d, in txn if d is not None)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
|
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
|
||||||
|
|
|
@ -15,10 +15,10 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
|
||||||
Generator,
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
Mapping,
|
||||||
NoReturn,
|
NoReturn,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
|
@ -96,7 +96,7 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
self.mock = mock.Mock()
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
@descriptors.cached(num_args=1)
|
@descriptors.cached(num_args=1)
|
||||||
def fn(self, arg1: int, arg2: int) -> mock.Mock:
|
def fn(self, arg1: int, arg2: int) -> str:
|
||||||
return self.mock(arg1, arg2)
|
return self.mock(arg1, arg2)
|
||||||
|
|
||||||
obj = Cls()
|
obj = Cls()
|
||||||
|
@ -228,8 +228,9 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def fn(self, arg1: int) -> Optional[Deferred]:
|
def fn(self, arg1: int) -> Deferred:
|
||||||
self.call_count += 1
|
self.call_count += 1
|
||||||
|
assert self.result is not None
|
||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
obj = Cls()
|
obj = Cls()
|
||||||
|
@ -401,21 +402,21 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
self.mock = mock.Mock()
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
@descriptors.cached(iterable=True)
|
@descriptors.cached(iterable=True)
|
||||||
def fn(self, arg1: int, arg2: int) -> List[str]:
|
def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]:
|
||||||
return self.mock(arg1, arg2)
|
return self.mock(arg1, arg2)
|
||||||
|
|
||||||
obj = Cls()
|
obj = Cls()
|
||||||
|
|
||||||
obj.mock.return_value = ["spam", "eggs"]
|
obj.mock.return_value = ("spam", "eggs")
|
||||||
r = obj.fn(1, 2)
|
r = obj.fn(1, 2)
|
||||||
self.assertEqual(r.result, ["spam", "eggs"])
|
self.assertEqual(r.result, ("spam", "eggs"))
|
||||||
obj.mock.assert_called_once_with(1, 2)
|
obj.mock.assert_called_once_with(1, 2)
|
||||||
obj.mock.reset_mock()
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
# a call with different params should call the mock again
|
# a call with different params should call the mock again
|
||||||
obj.mock.return_value = ["chips"]
|
obj.mock.return_value = ("chips",)
|
||||||
r = obj.fn(1, 3)
|
r = obj.fn(1, 3)
|
||||||
self.assertEqual(r.result, ["chips"])
|
self.assertEqual(r.result, ("chips",))
|
||||||
obj.mock.assert_called_once_with(1, 3)
|
obj.mock.assert_called_once_with(1, 3)
|
||||||
obj.mock.reset_mock()
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
@ -423,9 +424,9 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
self.assertEqual(len(obj.fn.cache.cache), 3)
|
self.assertEqual(len(obj.fn.cache.cache), 3)
|
||||||
|
|
||||||
r = obj.fn(1, 2)
|
r = obj.fn(1, 2)
|
||||||
self.assertEqual(r.result, ["spam", "eggs"])
|
self.assertEqual(r.result, ("spam", "eggs"))
|
||||||
r = obj.fn(1, 3)
|
r = obj.fn(1, 3)
|
||||||
self.assertEqual(r.result, ["chips"])
|
self.assertEqual(r.result, ("chips",))
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
def test_cache_iterable_with_sync_exception(self) -> None:
|
def test_cache_iterable_with_sync_exception(self) -> None:
|
||||||
|
@ -784,7 +785,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
||||||
async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]:
|
async def list_fn(
|
||||||
|
self, args1: Iterable[int], arg2: int
|
||||||
|
) -> Mapping[int, str]:
|
||||||
context = current_context()
|
context = current_context()
|
||||||
assert isinstance(context, LoggingContext)
|
assert isinstance(context, LoggingContext)
|
||||||
assert context.name == "c1"
|
assert context.name == "c1"
|
||||||
|
@ -847,11 +850,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
||||||
def list_fn(self, args1: List[int]) -> "Deferred[dict]":
|
def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]":
|
||||||
return self.mock(args1)
|
return self.mock(args1)
|
||||||
|
|
||||||
obj = Cls()
|
obj = Cls()
|
||||||
deferred_result: "Deferred[dict]" = Deferred()
|
deferred_result: "Deferred[Mapping[int, str]]" = Deferred()
|
||||||
obj.mock.return_value = deferred_result
|
obj.mock.return_value = deferred_result
|
||||||
|
|
||||||
# start off several concurrent lookups of the same key
|
# start off several concurrent lookups of the same key
|
||||||
|
@ -890,7 +893,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
||||||
async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]:
|
async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]:
|
||||||
# we want this to behave like an asynchronous function
|
# we want this to behave like an asynchronous function
|
||||||
await run_on_reactor()
|
await run_on_reactor()
|
||||||
return self.mock(args1, arg2)
|
return self.mock(args1, arg2)
|
||||||
|
@ -929,7 +932,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@cachedList(cached_method_name="fn", list_name="args")
|
@cachedList(cached_method_name="fn", list_name="args")
|
||||||
async def list_fn(self, args: List[int]) -> Dict[int, str]:
|
async def list_fn(self, args: List[int]) -> Mapping[int, str]:
|
||||||
await complete_lookup
|
await complete_lookup
|
||||||
return {arg: str(arg) for arg in args}
|
return {arg: str(arg) for arg in args}
|
||||||
|
|
||||||
|
@ -964,7 +967,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@cachedList(cached_method_name="fn", list_name="args")
|
@cachedList(cached_method_name="fn", list_name="args")
|
||||||
async def list_fn(self, args: List[int]) -> Dict[int, str]:
|
async def list_fn(self, args: List[int]) -> Mapping[int, str]:
|
||||||
await make_deferred_yieldable(complete_lookup)
|
await make_deferred_yieldable(complete_lookup)
|
||||||
self.inner_context_was_finished = current_context().finished
|
self.inner_context_was_finished = current_context().finished
|
||||||
return {arg: str(arg) for arg in args}
|
return {arg: str(arg) for arg in args}
|
||||||
|
|
Loading…
Reference in New Issue