mirror of
https://github.com/matrix-org/synapse.git
synced 2025-02-23 15:45:47 +00:00
Better return type for get_all_entities_changed
(#14604)
Prevent callers from using the return value incorrectly by ensuring that they explicitly check whether there was a cache hit.
This commit is contained in:
parent
6a8310f3df
commit
cee9445884
1
changelog.d/14604.bugfix
Normal file
1
changelog.d/14604.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances.
|
@ -615,8 +615,8 @@ class ApplicationServicesHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Fetch the users who have modified their device list since then.
|
# Fetch the users who have modified their device list since then.
|
||||||
users_with_changed_device_lists = (
|
users_with_changed_device_lists = await self.store.get_all_devices_changed(
|
||||||
await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
|
from_key, to_key=new_key
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filter out any users the application service is not interested in
|
# Filter out any users the application service is not interested in
|
||||||
|
@ -1692,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
|
|||||||
|
|
||||||
if from_key is not None:
|
if from_key is not None:
|
||||||
# First get all users that have had a presence update
|
# First get all users that have had a presence update
|
||||||
updated_users = stream_change_cache.get_all_entities_changed(from_key)
|
result = stream_change_cache.get_all_entities_changed(from_key)
|
||||||
|
|
||||||
# Cross-reference users we're interested in with those that have had updates.
|
# Cross-reference users we're interested in with those that have had updates.
|
||||||
if updated_users is not None:
|
if result.hit:
|
||||||
|
updated_users = result.entities
|
||||||
|
|
||||||
# If we have the full list of changes for presence we can
|
# If we have the full list of changes for presence we can
|
||||||
# simply check which ones share a room with the user.
|
# simply check which ones share a room with the user.
|
||||||
get_updates_counter.labels("stream").inc()
|
get_updates_counter.labels("stream").inc()
|
||||||
@ -1767,9 +1769,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
|
|||||||
updated_users = None
|
updated_users = None
|
||||||
if from_key:
|
if from_key:
|
||||||
# Only return updates since the last sync
|
# Only return updates since the last sync
|
||||||
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
|
result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
|
||||||
from_key
|
if result.hit:
|
||||||
)
|
updated_users = result.entities
|
||||||
|
|
||||||
if updated_users is not None:
|
if updated_users is not None:
|
||||||
# Get the actual presence update for each change
|
# Get the actual presence update for each change
|
||||||
|
@ -1528,10 +1528,12 @@ class SyncHandler:
|
|||||||
#
|
#
|
||||||
# If we don't have that info cached then we get all the users that
|
# If we don't have that info cached then we get all the users that
|
||||||
# share a room with our user and check if those users have changed.
|
# share a room with our user and check if those users have changed.
|
||||||
changed_users = self.store.get_cached_device_list_changes(
|
cache_result = self.store.get_cached_device_list_changes(
|
||||||
since_token.device_list_key
|
since_token.device_list_key
|
||||||
)
|
)
|
||||||
if changed_users is not None:
|
if cache_result.hit:
|
||||||
|
changed_users = cache_result.entities
|
||||||
|
|
||||||
result = await self.store.get_rooms_for_users(changed_users)
|
result = await self.store.get_rooms_for_users(changed_users)
|
||||||
|
|
||||||
for changed_user_id, entries in result.items():
|
for changed_user_id, entries in result.items():
|
||||||
|
@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
changed_rooms: Optional[
|
result = self._typing_stream_change_cache.get_all_entities_changed(last_id)
|
||||||
Iterable[str]
|
|
||||||
] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
|
|
||||||
|
|
||||||
if changed_rooms is None:
|
if result.hit:
|
||||||
|
changed_rooms: Iterable[str] = result.entities
|
||||||
|
else:
|
||||||
changed_rooms = self._room_serials
|
changed_rooms = self._room_serials
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
|
@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
|
|||||||
from synapse.util import json_decoder, json_encoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import (
|
||||||
|
AllEntitiesChangedResult,
|
||||||
|
StreamChangeCache,
|
||||||
|
)
|
||||||
from synapse.util.cancellation import cancellable
|
from synapse.util.cancellation import cancellable
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
from synapse.util.stringutils import shortstr
|
from synapse.util.stringutils import shortstr
|
||||||
@ -799,18 +802,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||||||
def get_cached_device_list_changes(
|
def get_cached_device_list_changes(
|
||||||
self,
|
self,
|
||||||
from_key: int,
|
from_key: int,
|
||||||
) -> Optional[List[str]]:
|
) -> AllEntitiesChangedResult:
|
||||||
"""Get set of users whose devices have changed since `from_key`, or None
|
"""Get set of users whose devices have changed since `from_key`, or None
|
||||||
if that information is not in our cache.
|
if that information is not in our cache.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._device_list_stream_cache.get_all_entities_changed(from_key)
|
return self._device_list_stream_cache.get_all_entities_changed(from_key)
|
||||||
|
|
||||||
|
@cancellable
|
||||||
|
async def get_all_devices_changed(
|
||||||
|
self,
|
||||||
|
from_key: int,
|
||||||
|
to_key: int,
|
||||||
|
) -> Set[str]:
|
||||||
|
"""Get all users whose devices have changed in the given range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_key: The minimum device lists stream token to query device list
|
||||||
|
changes for, exclusive.
|
||||||
|
to_key: The maximum device lists stream token to query device list
|
||||||
|
changes for, inclusive.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The set of user_ids whose devices have changed since `from_key`
|
||||||
|
(exclusive) until `to_key` (inclusive).
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = self._device_list_stream_cache.get_all_entities_changed(from_key)
|
||||||
|
|
||||||
|
if result.hit:
|
||||||
|
# We know which users might have changed devices.
|
||||||
|
if not result.entities:
|
||||||
|
# If no users then we can return early.
|
||||||
|
return set()
|
||||||
|
|
||||||
|
# Otherwise we need to filter down the list
|
||||||
|
return await self.get_users_whose_devices_changed(
|
||||||
|
from_key, result.entities, to_key
|
||||||
|
)
|
||||||
|
|
||||||
|
# If the cache didn't tell us anything, we just need to query the full
|
||||||
|
# range.
|
||||||
|
sql = """
|
||||||
|
SELECT DISTINCT user_id FROM device_lists_stream
|
||||||
|
WHERE ? < stream_id AND stream_id <= ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
rows = await self.db_pool.execute(
|
||||||
|
"get_all_devices_changed",
|
||||||
|
None,
|
||||||
|
sql,
|
||||||
|
from_key,
|
||||||
|
to_key,
|
||||||
|
)
|
||||||
|
return {u for u, in rows}
|
||||||
|
|
||||||
@cancellable
|
@cancellable
|
||||||
async def get_users_whose_devices_changed(
|
async def get_users_whose_devices_changed(
|
||||||
self,
|
self,
|
||||||
from_key: int,
|
from_key: int,
|
||||||
user_ids: Optional[Collection[str]] = None,
|
user_ids: Collection[str],
|
||||||
to_key: Optional[int] = None,
|
to_key: Optional[int] = None,
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Get set of users whose devices have changed since `from_key` that
|
"""Get set of users whose devices have changed since `from_key` that
|
||||||
@ -830,43 +881,23 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||||||
"""
|
"""
|
||||||
# Get set of users who *may* have changed. Users not in the returned
|
# Get set of users who *may* have changed. Users not in the returned
|
||||||
# list have definitely not changed.
|
# list have definitely not changed.
|
||||||
user_ids_to_check: Optional[Collection[str]]
|
|
||||||
if user_ids is None:
|
|
||||||
# Get set of all users that have had device list changes since 'from_key'
|
|
||||||
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
|
|
||||||
from_key
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# The same as above, but filter results to only those users in 'user_ids'
|
|
||||||
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
|
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
|
||||||
user_ids, from_key
|
user_ids, from_key
|
||||||
)
|
)
|
||||||
|
|
||||||
# If an empty set was returned, there's nothing to do.
|
# If an empty set was returned, there's nothing to do.
|
||||||
if user_ids_to_check is not None and not user_ids_to_check:
|
if not user_ids_to_check:
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
if to_key is None:
|
||||||
|
to_key = self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
|
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
|
||||||
stream_id_where_clause = "stream_id > ?"
|
sql = """
|
||||||
sql_args = [from_key]
|
|
||||||
|
|
||||||
if to_key:
|
|
||||||
stream_id_where_clause += " AND stream_id <= ?"
|
|
||||||
sql_args.append(to_key)
|
|
||||||
|
|
||||||
sql = f"""
|
|
||||||
SELECT DISTINCT user_id FROM device_lists_stream
|
SELECT DISTINCT user_id FROM device_lists_stream
|
||||||
WHERE {stream_id_where_clause}
|
WHERE ? < stream_id AND stream_id <= ? AND %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If the stream change cache gave us no information, fetch *all*
|
|
||||||
# users between the stream IDs.
|
|
||||||
if user_ids_to_check is None:
|
|
||||||
txn.execute(sql, sql_args)
|
|
||||||
return {user_id for user_id, in txn}
|
|
||||||
|
|
||||||
# Otherwise, fetch changes for the given users.
|
|
||||||
else:
|
|
||||||
changes: Set[str] = set()
|
changes: Set[str] = set()
|
||||||
|
|
||||||
# Query device changes with a batch of users at a time
|
# Query device changes with a batch of users at a time
|
||||||
@ -874,7 +905,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||||||
clause, args = make_in_list_sql_clause(
|
clause, args = make_in_list_sql_clause(
|
||||||
txn.database_engine, "user_id", chunk
|
txn.database_engine, "user_id", chunk
|
||||||
)
|
)
|
||||||
txn.execute(sql + " AND " + clause, sql_args + args)
|
txn.execute(sql % (clause,), [from_key, to_key] + args)
|
||||||
changes.update(user_id for user_id, in txn)
|
changes.update(user_id for user_id, in txn)
|
||||||
|
|
||||||
return changes
|
return changes
|
||||||
|
@ -16,6 +16,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
|
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
|
||||||
|
|
||||||
|
import attr
|
||||||
from sortedcontainers import SortedDict
|
from sortedcontainers import SortedDict
|
||||||
|
|
||||||
from synapse.util import caches
|
from synapse.util import caches
|
||||||
@ -26,6 +27,29 @@ logger = logging.getLogger(__name__)
|
|||||||
EntityType = str
|
EntityType = str
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||||
|
class AllEntitiesChangedResult:
|
||||||
|
"""Return type of `get_all_entities_changed`.
|
||||||
|
|
||||||
|
Callers must check that there was a cache hit, via `result.hit`, before
|
||||||
|
using the entities in `result.entities`.
|
||||||
|
|
||||||
|
This specifically does *not* implement helpers such as `__bool__` to ensure
|
||||||
|
that callers do the correct checks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_entities: Optional[List[EntityType]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hit(self) -> bool:
|
||||||
|
return self._entities is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entities(self) -> List[EntityType]:
|
||||||
|
assert self._entities is not None
|
||||||
|
return self._entities
|
||||||
|
|
||||||
|
|
||||||
class StreamChangeCache:
|
class StreamChangeCache:
|
||||||
"""
|
"""
|
||||||
Keeps track of the stream positions of the latest change in a set of entities.
|
Keeps track of the stream positions of the latest change in a set of entities.
|
||||||
@ -153,19 +177,19 @@ class StreamChangeCache:
|
|||||||
This will be all entities if the given stream position is at or earlier
|
This will be all entities if the given stream position is at or earlier
|
||||||
than the earliest known stream position.
|
than the earliest known stream position.
|
||||||
"""
|
"""
|
||||||
changed_entities = self.get_all_entities_changed(stream_pos)
|
cache_result = self.get_all_entities_changed(stream_pos)
|
||||||
if changed_entities is not None:
|
if cache_result.hit:
|
||||||
# We now do an intersection, trying to do so in the most efficient
|
# We now do an intersection, trying to do so in the most efficient
|
||||||
# way possible (some of these sets are *large*). First check in the
|
# way possible (some of these sets are *large*). First check in the
|
||||||
# given iterable is already a set that we can reuse, otherwise we
|
# given iterable is already a set that we can reuse, otherwise we
|
||||||
# create a set of the *smallest* of the two iterables and call
|
# create a set of the *smallest* of the two iterables and call
|
||||||
# `intersection(..)` on it (this can be twice as fast as the reverse).
|
# `intersection(..)` on it (this can be twice as fast as the reverse).
|
||||||
if isinstance(entities, (set, frozenset)):
|
if isinstance(entities, (set, frozenset)):
|
||||||
result = entities.intersection(changed_entities)
|
result = entities.intersection(cache_result.entities)
|
||||||
elif len(changed_entities) < len(entities):
|
elif len(cache_result.entities) < len(entities):
|
||||||
result = set(changed_entities).intersection(entities)
|
result = set(cache_result.entities).intersection(entities)
|
||||||
else:
|
else:
|
||||||
result = set(entities).intersection(changed_entities)
|
result = set(entities).intersection(cache_result.entities)
|
||||||
self.metrics.inc_hits()
|
self.metrics.inc_hits()
|
||||||
else:
|
else:
|
||||||
result = set(entities)
|
result = set(entities)
|
||||||
@ -202,12 +226,12 @@ class StreamChangeCache:
|
|||||||
self.metrics.inc_hits()
|
self.metrics.inc_hits()
|
||||||
return stream_pos < self._cache.peekitem()[0]
|
return stream_pos < self._cache.peekitem()[0]
|
||||||
|
|
||||||
def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
|
def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
|
||||||
"""
|
"""
|
||||||
Returns all entities that have had changes after the given position.
|
Returns all entities that have had changes after the given position.
|
||||||
|
|
||||||
If the stream change cache does not go far enough back, i.e. the position
|
If the stream change cache does not go far enough back, i.e. the
|
||||||
is too old, it will return None.
|
position is too old, the returned result's `hit` will be False.
|
||||||
|
|
||||||
Returns the entities in the order that they were changed.
|
Returns the entities in the order that they were changed.
|
||||||
|
|
||||||
@ -215,23 +239,21 @@ class StreamChangeCache:
|
|||||||
stream_pos: The stream position to check for changes after.
|
stream_pos: The stream position to check for changes after.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Entities which have changed after the given stream position.
|
A class indicating if we have the requested data cached, and if so
|
||||||
|
includes the entities in the order they were changed.
|
||||||
None if the given stream position is at or earlier than the earliest
|
|
||||||
known stream position.
|
|
||||||
"""
|
"""
|
||||||
assert isinstance(stream_pos, int)
|
assert isinstance(stream_pos, int)
|
||||||
|
|
||||||
# _cache is not valid at or before the earliest known stream position, so
|
# _cache is not valid at or before the earliest known stream position, so
|
||||||
# return None to mark that it is unknown if an entity has changed.
|
# return None to mark that it is unknown if an entity has changed.
|
||||||
if stream_pos <= self._earliest_known_stream_pos:
|
if stream_pos <= self._earliest_known_stream_pos:
|
||||||
return None
|
return AllEntitiesChangedResult(None)
|
||||||
|
|
||||||
changed_entities: List[EntityType] = []
|
changed_entities: List[EntityType] = []
|
||||||
|
|
||||||
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
|
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
|
||||||
changed_entities.extend(self._cache[k])
|
changed_entities.extend(self._cache[k])
|
||||||
return changed_entities
|
return AllEntitiesChangedResult(changed_entities)
|
||||||
|
|
||||||
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
|
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -73,8 +73,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
|||||||
# The oldest item has been popped off
|
# The oldest item has been popped off
|
||||||
self.assertTrue("user@foo.com" not in cache._entity_to_key)
|
self.assertTrue("user@foo.com" not in cache._entity_to_key)
|
||||||
|
|
||||||
self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
|
self.assertEqual(
|
||||||
self.assertIsNone(cache.get_all_entities_changed(2))
|
cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
|
||||||
|
)
|
||||||
|
self.assertFalse(cache.get_all_entities_changed(2).hit)
|
||||||
|
|
||||||
# If we update an existing entity, it keeps the two existing entities
|
# If we update an existing entity, it keeps the two existing entities
|
||||||
cache.entity_has_changed("bar@baz.net", 5)
|
cache.entity_has_changed("bar@baz.net", 5)
|
||||||
@ -82,10 +84,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
|||||||
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
|
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
cache.get_all_entities_changed(3),
|
cache.get_all_entities_changed(3).entities,
|
||||||
["user@elsewhere.org", "bar@baz.net"],
|
["user@elsewhere.org", "bar@baz.net"],
|
||||||
)
|
)
|
||||||
self.assertIsNone(cache.get_all_entities_changed(2))
|
self.assertFalse(cache.get_all_entities_changed(2).hit)
|
||||||
|
|
||||||
def test_get_all_entities_changed(self) -> None:
|
def test_get_all_entities_changed(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -105,10 +107,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
|||||||
# Results are ordered so either of these are valid.
|
# Results are ordered so either of these are valid.
|
||||||
ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"]
|
ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"]
|
||||||
ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"]
|
ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"]
|
||||||
self.assertTrue(r == ok1 or r == ok2)
|
self.assertTrue(r.entities == ok1 or r.entities == ok2)
|
||||||
|
|
||||||
self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
|
self.assertEqual(
|
||||||
self.assertEqual(cache.get_all_entities_changed(1), None)
|
cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
|
||||||
|
)
|
||||||
|
self.assertFalse(cache.get_all_entities_changed(1).hit)
|
||||||
|
|
||||||
# ... later, things get more updates
|
# ... later, things get more updates
|
||||||
cache.entity_has_changed("user@foo.com", 5)
|
cache.entity_has_changed("user@foo.com", 5)
|
||||||
@ -128,7 +132,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
|||||||
"anotheruser@foo.com",
|
"anotheruser@foo.com",
|
||||||
]
|
]
|
||||||
r = cache.get_all_entities_changed(3)
|
r = cache.get_all_entities_changed(3)
|
||||||
self.assertTrue(r == ok1 or r == ok2)
|
self.assertTrue(r.entities == ok1 or r.entities == ok2)
|
||||||
|
|
||||||
def test_has_any_entity_changed(self) -> None:
|
def test_has_any_entity_changed(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user