mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-11-03 21:57:26 +00:00 
			
		
		
		
	Better return type for get_all_entities_changed (#14604)
				
					
				
			Help callers from using the return value incorrectly by ensuring that callers explicitly check if there was a cache hit or not.
This commit is contained in:
		
							parent
							
								
									6a8310f3df
								
							
						
					
					
						commit
						cee9445884
					
				
							
								
								
									
										1
									
								
								changelog.d/14604.bugfix
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/14604.bugfix
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances.
 | 
			
		||||
@ -615,8 +615,8 @@ class ApplicationServicesHandler:
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Fetch the users who have modified their device list since then.
 | 
			
		||||
        users_with_changed_device_lists = (
 | 
			
		||||
            await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
 | 
			
		||||
        users_with_changed_device_lists = await self.store.get_all_devices_changed(
 | 
			
		||||
            from_key, to_key=new_key
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Filter out any users the application service is not interested in
 | 
			
		||||
 | 
			
		||||
@ -1692,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
 | 
			
		||||
 | 
			
		||||
            if from_key is not None:
 | 
			
		||||
                # First get all users that have had a presence update
 | 
			
		||||
                updated_users = stream_change_cache.get_all_entities_changed(from_key)
 | 
			
		||||
                result = stream_change_cache.get_all_entities_changed(from_key)
 | 
			
		||||
 | 
			
		||||
                # Cross-reference users we're interested in with those that have had updates.
 | 
			
		||||
                if updated_users is not None:
 | 
			
		||||
                if result.hit:
 | 
			
		||||
                    updated_users = result.entities
 | 
			
		||||
 | 
			
		||||
                    # If we have the full list of changes for presence we can
 | 
			
		||||
                    # simply check which ones share a room with the user.
 | 
			
		||||
                    get_updates_counter.labels("stream").inc()
 | 
			
		||||
@ -1767,9 +1769,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
 | 
			
		||||
        updated_users = None
 | 
			
		||||
        if from_key:
 | 
			
		||||
            # Only return updates since the last sync
 | 
			
		||||
            updated_users = self.store.presence_stream_cache.get_all_entities_changed(
 | 
			
		||||
                from_key
 | 
			
		||||
            )
 | 
			
		||||
            result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
 | 
			
		||||
            if result.hit:
 | 
			
		||||
                updated_users = result.entities
 | 
			
		||||
 | 
			
		||||
        if updated_users is not None:
 | 
			
		||||
            # Get the actual presence update for each change
 | 
			
		||||
 | 
			
		||||
@ -1528,10 +1528,12 @@ class SyncHandler:
 | 
			
		||||
            #
 | 
			
		||||
            # If we don't have that info cached then we get all the users that
 | 
			
		||||
            # share a room with our user and check if those users have changed.
 | 
			
		||||
            changed_users = self.store.get_cached_device_list_changes(
 | 
			
		||||
            cache_result = self.store.get_cached_device_list_changes(
 | 
			
		||||
                since_token.device_list_key
 | 
			
		||||
            )
 | 
			
		||||
            if changed_users is not None:
 | 
			
		||||
            if cache_result.hit:
 | 
			
		||||
                changed_users = cache_result.entities
 | 
			
		||||
 | 
			
		||||
                result = await self.store.get_rooms_for_users(changed_users)
 | 
			
		||||
 | 
			
		||||
                for changed_user_id, entries in result.items():
 | 
			
		||||
 | 
			
		||||
@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler):
 | 
			
		||||
        if last_id == current_id:
 | 
			
		||||
            return [], current_id, False
 | 
			
		||||
 | 
			
		||||
        changed_rooms: Optional[
 | 
			
		||||
            Iterable[str]
 | 
			
		||||
        ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
 | 
			
		||||
        result = self._typing_stream_change_cache.get_all_entities_changed(last_id)
 | 
			
		||||
 | 
			
		||||
        if changed_rooms is None:
 | 
			
		||||
        if result.hit:
 | 
			
		||||
            changed_rooms: Iterable[str] = result.entities
 | 
			
		||||
        else:
 | 
			
		||||
            changed_rooms = self._room_serials
 | 
			
		||||
 | 
			
		||||
        rows = []
 | 
			
		||||
 | 
			
		||||
@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 | 
			
		||||
from synapse.util import json_decoder, json_encoder
 | 
			
		||||
from synapse.util.caches.descriptors import cached, cachedList
 | 
			
		||||
from synapse.util.caches.lrucache import LruCache
 | 
			
		||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
 | 
			
		||||
from synapse.util.caches.stream_change_cache import (
 | 
			
		||||
    AllEntitiesChangedResult,
 | 
			
		||||
    StreamChangeCache,
 | 
			
		||||
)
 | 
			
		||||
from synapse.util.cancellation import cancellable
 | 
			
		||||
from synapse.util.iterutils import batch_iter
 | 
			
		||||
from synapse.util.stringutils import shortstr
 | 
			
		||||
@ -799,18 +802,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 | 
			
		||||
    def get_cached_device_list_changes(
 | 
			
		||||
        self,
 | 
			
		||||
        from_key: int,
 | 
			
		||||
    ) -> Optional[List[str]]:
 | 
			
		||||
    ) -> AllEntitiesChangedResult:
 | 
			
		||||
        """Get set of users whose devices have changed since `from_key`, or None
 | 
			
		||||
        if that information is not in our cache.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        return self._device_list_stream_cache.get_all_entities_changed(from_key)
 | 
			
		||||
 | 
			
		||||
    @cancellable
 | 
			
		||||
    async def get_all_devices_changed(
 | 
			
		||||
        self,
 | 
			
		||||
        from_key: int,
 | 
			
		||||
        to_key: int,
 | 
			
		||||
    ) -> Set[str]:
 | 
			
		||||
        """Get all users whose devices have changed in the given range.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            from_key: The minimum device lists stream token to query device list
 | 
			
		||||
                changes for, exclusive.
 | 
			
		||||
            to_key: The maximum device lists stream token to query device list
 | 
			
		||||
                changes for, inclusive.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            The set of user_ids whose devices have changed since `from_key`
 | 
			
		||||
            (exclusive) until `to_key` (inclusive).
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        result = self._device_list_stream_cache.get_all_entities_changed(from_key)
 | 
			
		||||
 | 
			
		||||
        if result.hit:
 | 
			
		||||
            # We know which users might have changed devices.
 | 
			
		||||
            if not result.entities:
 | 
			
		||||
                # If no users then we can return early.
 | 
			
		||||
                return set()
 | 
			
		||||
 | 
			
		||||
            # Otherwise we need to filter down the list
 | 
			
		||||
            return await self.get_users_whose_devices_changed(
 | 
			
		||||
                from_key, result.entities, to_key
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # If the cache didn't tell us anything, we just need to query the full
 | 
			
		||||
        # range.
 | 
			
		||||
        sql = """
 | 
			
		||||
            SELECT DISTINCT user_id FROM device_lists_stream
 | 
			
		||||
            WHERE ? < stream_id AND stream_id <= ?
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        rows = await self.db_pool.execute(
 | 
			
		||||
            "get_all_devices_changed",
 | 
			
		||||
            None,
 | 
			
		||||
            sql,
 | 
			
		||||
            from_key,
 | 
			
		||||
            to_key,
 | 
			
		||||
        )
 | 
			
		||||
        return {u for u, in rows}
 | 
			
		||||
 | 
			
		||||
    @cancellable
 | 
			
		||||
    async def get_users_whose_devices_changed(
 | 
			
		||||
        self,
 | 
			
		||||
        from_key: int,
 | 
			
		||||
        user_ids: Optional[Collection[str]] = None,
 | 
			
		||||
        user_ids: Collection[str],
 | 
			
		||||
        to_key: Optional[int] = None,
 | 
			
		||||
    ) -> Set[str]:
 | 
			
		||||
        """Get set of users whose devices have changed since `from_key` that
 | 
			
		||||
@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 | 
			
		||||
        """
 | 
			
		||||
        # Get set of users who *may* have changed. Users not in the returned
 | 
			
		||||
        # list have definitely not changed.
 | 
			
		||||
        user_ids_to_check: Optional[Collection[str]]
 | 
			
		||||
        if user_ids is None:
 | 
			
		||||
            # Get set of all users that have had device list changes since 'from_key'
 | 
			
		||||
            user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
 | 
			
		||||
                from_key
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            # The same as above, but filter results to only those users in 'user_ids'
 | 
			
		||||
            user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
 | 
			
		||||
                user_ids, from_key
 | 
			
		||||
            )
 | 
			
		||||
        user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
 | 
			
		||||
            user_ids, from_key
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # If an empty set was returned, there's nothing to do.
 | 
			
		||||
        if user_ids_to_check is not None and not user_ids_to_check:
 | 
			
		||||
        if not user_ids_to_check:
 | 
			
		||||
            return set()
 | 
			
		||||
 | 
			
		||||
        if to_key is None:
 | 
			
		||||
            to_key = self._device_list_id_gen.get_current_token()
 | 
			
		||||
 | 
			
		||||
        def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
 | 
			
		||||
            stream_id_where_clause = "stream_id > ?"
 | 
			
		||||
            sql_args = [from_key]
 | 
			
		||||
 | 
			
		||||
            if to_key:
 | 
			
		||||
                stream_id_where_clause += " AND stream_id <= ?"
 | 
			
		||||
                sql_args.append(to_key)
 | 
			
		||||
 | 
			
		||||
            sql = f"""
 | 
			
		||||
            sql = """
 | 
			
		||||
                SELECT DISTINCT user_id FROM device_lists_stream
 | 
			
		||||
                WHERE {stream_id_where_clause}
 | 
			
		||||
                WHERE  ? < stream_id AND stream_id <= ? AND %s
 | 
			
		||||
            """
 | 
			
		||||
 | 
			
		||||
            # If the stream change cache gave us no information, fetch *all*
 | 
			
		||||
            # users between the stream IDs.
 | 
			
		||||
            if user_ids_to_check is None:
 | 
			
		||||
                txn.execute(sql, sql_args)
 | 
			
		||||
                return {user_id for user_id, in txn}
 | 
			
		||||
            changes: Set[str] = set()
 | 
			
		||||
 | 
			
		||||
            # Otherwise, fetch changes for the given users.
 | 
			
		||||
            else:
 | 
			
		||||
                changes: Set[str] = set()
 | 
			
		||||
 | 
			
		||||
                # Query device changes with a batch of users at a time
 | 
			
		||||
                for chunk in batch_iter(user_ids_to_check, 100):
 | 
			
		||||
                    clause, args = make_in_list_sql_clause(
 | 
			
		||||
                        txn.database_engine, "user_id", chunk
 | 
			
		||||
                    )
 | 
			
		||||
                    txn.execute(sql + " AND " + clause, sql_args + args)
 | 
			
		||||
                    changes.update(user_id for user_id, in txn)
 | 
			
		||||
            # Query device changes with a batch of users at a time
 | 
			
		||||
            for chunk in batch_iter(user_ids_to_check, 100):
 | 
			
		||||
                clause, args = make_in_list_sql_clause(
 | 
			
		||||
                    txn.database_engine, "user_id", chunk
 | 
			
		||||
                )
 | 
			
		||||
                txn.execute(sql % (clause,), [from_key, to_key] + args)
 | 
			
		||||
                changes.update(user_id for user_id, in txn)
 | 
			
		||||
 | 
			
		||||
            return changes
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,7 @@ import logging
 | 
			
		||||
import math
 | 
			
		||||
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
 | 
			
		||||
 | 
			
		||||
import attr
 | 
			
		||||
from sortedcontainers import SortedDict
 | 
			
		||||
 | 
			
		||||
from synapse.util import caches
 | 
			
		||||
@ -26,6 +27,29 @@ logger = logging.getLogger(__name__)
 | 
			
		||||
EntityType = str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
 | 
			
		||||
class AllEntitiesChangedResult:
 | 
			
		||||
    """Return type of `get_all_entities_changed`.
 | 
			
		||||
 | 
			
		||||
    Callers must check that there was a cache hit, via `result.hit`, before
 | 
			
		||||
    using the entities in `result.entities`.
 | 
			
		||||
 | 
			
		||||
    This specifically does *not* implement helpers such as `__bool__` to ensure
 | 
			
		||||
    that callers do the correct checks.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    _entities: Optional[List[EntityType]]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def hit(self) -> bool:
 | 
			
		||||
        return self._entities is not None
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def entities(self) -> List[EntityType]:
 | 
			
		||||
        assert self._entities is not None
 | 
			
		||||
        return self._entities
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StreamChangeCache:
 | 
			
		||||
    """
 | 
			
		||||
    Keeps track of the stream positions of the latest change in a set of entities.
 | 
			
		||||
@ -153,19 +177,19 @@ class StreamChangeCache:
 | 
			
		||||
            This will be all entities if the given stream position is at or earlier
 | 
			
		||||
            than the earliest known stream position.
 | 
			
		||||
        """
 | 
			
		||||
        changed_entities = self.get_all_entities_changed(stream_pos)
 | 
			
		||||
        if changed_entities is not None:
 | 
			
		||||
        cache_result = self.get_all_entities_changed(stream_pos)
 | 
			
		||||
        if cache_result.hit:
 | 
			
		||||
            # We now do an intersection, trying to do so in the most efficient
 | 
			
		||||
            # way possible (some of these sets are *large*). First check in the
 | 
			
		||||
            # given iterable is already a set that we can reuse, otherwise we
 | 
			
		||||
            # create a set of the *smallest* of the two iterables and call
 | 
			
		||||
            # `intersection(..)` on it (this can be twice as fast as the reverse).
 | 
			
		||||
            if isinstance(entities, (set, frozenset)):
 | 
			
		||||
                result = entities.intersection(changed_entities)
 | 
			
		||||
            elif len(changed_entities) < len(entities):
 | 
			
		||||
                result = set(changed_entities).intersection(entities)
 | 
			
		||||
                result = entities.intersection(cache_result.entities)
 | 
			
		||||
            elif len(cache_result.entities) < len(entities):
 | 
			
		||||
                result = set(cache_result.entities).intersection(entities)
 | 
			
		||||
            else:
 | 
			
		||||
                result = set(entities).intersection(changed_entities)
 | 
			
		||||
                result = set(entities).intersection(cache_result.entities)
 | 
			
		||||
            self.metrics.inc_hits()
 | 
			
		||||
        else:
 | 
			
		||||
            result = set(entities)
 | 
			
		||||
@ -202,12 +226,12 @@ class StreamChangeCache:
 | 
			
		||||
        self.metrics.inc_hits()
 | 
			
		||||
        return stream_pos < self._cache.peekitem()[0]
 | 
			
		||||
 | 
			
		||||
    def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
 | 
			
		||||
    def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
 | 
			
		||||
        """
 | 
			
		||||
        Returns all entities that have had changes after the given position.
 | 
			
		||||
 | 
			
		||||
        If the stream change cache does not go far enough back, i.e. the position
 | 
			
		||||
        is too old, it will return None.
 | 
			
		||||
        If the stream change cache does not go far enough back, i.e. the
 | 
			
		||||
        position is too old, it will return None.
 | 
			
		||||
 | 
			
		||||
        Returns the entities in the order that they were changed.
 | 
			
		||||
 | 
			
		||||
@ -215,23 +239,21 @@ class StreamChangeCache:
 | 
			
		||||
            stream_pos: The stream position to check for changes after.
 | 
			
		||||
 | 
			
		||||
        Return:
 | 
			
		||||
            Entities which have changed after the given stream position.
 | 
			
		||||
 | 
			
		||||
            None if the given stream position is at or earlier than the earliest
 | 
			
		||||
            known stream position.
 | 
			
		||||
            A class indicating if we have the requested data cached, and if so
 | 
			
		||||
            includes the entities in the order they were changed.
 | 
			
		||||
        """
 | 
			
		||||
        assert isinstance(stream_pos, int)
 | 
			
		||||
 | 
			
		||||
        # _cache is not valid at or before the earliest known stream position, so
 | 
			
		||||
        # return None to mark that it is unknown if an entity has changed.
 | 
			
		||||
        if stream_pos <= self._earliest_known_stream_pos:
 | 
			
		||||
            return None
 | 
			
		||||
            return AllEntitiesChangedResult(None)
 | 
			
		||||
 | 
			
		||||
        changed_entities: List[EntityType] = []
 | 
			
		||||
 | 
			
		||||
        for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
 | 
			
		||||
            changed_entities.extend(self._cache[k])
 | 
			
		||||
        return changed_entities
 | 
			
		||||
        return AllEntitiesChangedResult(changed_entities)
 | 
			
		||||
 | 
			
		||||
    def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
@ -73,8 +73,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
 | 
			
		||||
        # The oldest item has been popped off
 | 
			
		||||
        self.assertTrue("user@foo.com" not in cache._entity_to_key)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
 | 
			
		||||
        self.assertIsNone(cache.get_all_entities_changed(2))
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
 | 
			
		||||
        )
 | 
			
		||||
        self.assertFalse(cache.get_all_entities_changed(2).hit)
 | 
			
		||||
 | 
			
		||||
        # If we update an existing entity, it keeps the two existing entities
 | 
			
		||||
        cache.entity_has_changed("bar@baz.net", 5)
 | 
			
		||||
@ -82,10 +84,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
 | 
			
		||||
            {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            cache.get_all_entities_changed(3),
 | 
			
		||||
            cache.get_all_entities_changed(3).entities,
 | 
			
		||||
            ["user@elsewhere.org", "bar@baz.net"],
 | 
			
		||||
        )
 | 
			
		||||
        self.assertIsNone(cache.get_all_entities_changed(2))
 | 
			
		||||
        self.assertFalse(cache.get_all_entities_changed(2).hit)
 | 
			
		||||
 | 
			
		||||
    def test_get_all_entities_changed(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
@ -105,10 +107,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
 | 
			
		||||
        # Results are ordered so either of these are valid.
 | 
			
		||||
        ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"]
 | 
			
		||||
        ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"]
 | 
			
		||||
        self.assertTrue(r == ok1 or r == ok2)
 | 
			
		||||
        self.assertTrue(r.entities == ok1 or r.entities == ok2)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
 | 
			
		||||
        self.assertEqual(cache.get_all_entities_changed(1), None)
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
 | 
			
		||||
        )
 | 
			
		||||
        self.assertFalse(cache.get_all_entities_changed(1).hit)
 | 
			
		||||
 | 
			
		||||
        # ... later, things gest more updates
 | 
			
		||||
        cache.entity_has_changed("user@foo.com", 5)
 | 
			
		||||
@ -128,7 +132,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
 | 
			
		||||
            "anotheruser@foo.com",
 | 
			
		||||
        ]
 | 
			
		||||
        r = cache.get_all_entities_changed(3)
 | 
			
		||||
        self.assertTrue(r == ok1 or r == ok2)
 | 
			
		||||
        self.assertTrue(r.entities == ok1 or r.entities == ok2)
 | 
			
		||||
 | 
			
		||||
    def test_has_any_entity_changed(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user