Convert state delta processing from a dict to attrs. (#16469)

For improved type checking & memory usage.
This commit is contained in:
Patrick Cloke 2023-10-16 07:35:22 -04:00 committed by GitHub
parent 4fe73f8f2f
commit e3e0ae4ab1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 111 additions and 109 deletions

1
changelog.d/16469.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -110,6 +110,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.streams import EventSource from synapse.streams import EventSource
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
@ -1499,9 +1500,9 @@ class PresenceHandler(BasePresenceHandler):
# We may get multiple deltas for different rooms, but we want to # We may get multiple deltas for different rooms, but we want to
# handle them on a room by room basis, so we batch them up by # handle them on a room by room basis, so we batch them up by
# room. # room.
deltas_by_room: Dict[str, List[JsonDict]] = {} deltas_by_room: Dict[str, List[StateDelta]] = {}
for delta in deltas: for delta in deltas:
deltas_by_room.setdefault(delta["room_id"], []).append(delta) deltas_by_room.setdefault(delta.room_id, []).append(delta)
for room_id, deltas_for_room in deltas_by_room.items(): for room_id, deltas_for_room in deltas_by_room.items():
await self._handle_state_delta(room_id, deltas_for_room) await self._handle_state_delta(room_id, deltas_for_room)
@ -1513,7 +1514,7 @@ class PresenceHandler(BasePresenceHandler):
max_pos max_pos
) )
async def _handle_state_delta(self, room_id: str, deltas: List[JsonDict]) -> None: async def _handle_state_delta(self, room_id: str, deltas: List[StateDelta]) -> None:
"""Process current state deltas for the room to find new joins that need """Process current state deltas for the room to find new joins that need
to be handled. to be handled.
""" """
@ -1524,31 +1525,30 @@ class PresenceHandler(BasePresenceHandler):
newly_joined_users = set() newly_joined_users = set()
for delta in deltas: for delta in deltas:
assert room_id == delta["room_id"] assert room_id == delta.room_id
typ = delta["type"] logger.debug(
state_key = delta["state_key"] "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
event_id = delta["event_id"] )
prev_event_id = delta["prev_event_id"]
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
# Drop any event that isn't a membership join # Drop any event that isn't a membership join
if typ != EventTypes.Member: if delta.event_type != EventTypes.Member:
continue continue
if event_id is None: if delta.event_id is None:
# state has been deleted, so this is not a join. We only care about # state has been deleted, so this is not a join. We only care about
# joins. # joins.
continue continue
event = await self.store.get_event(event_id, allow_none=True) event = await self.store.get_event(delta.event_id, allow_none=True)
if not event or event.content.get("membership") != Membership.JOIN: if not event or event.content.get("membership") != Membership.JOIN:
# We only care about joins # We only care about joins
continue continue
if prev_event_id: if delta.prev_event_id:
prev_event = await self.store.get_event(prev_event_id, allow_none=True) prev_event = await self.store.get_event(
delta.prev_event_id, allow_none=True
)
if ( if (
prev_event prev_event
and prev_event.content.get("membership") == Membership.JOIN and prev_event.content.get("membership") == Membership.JOIN
@ -1556,7 +1556,7 @@ class PresenceHandler(BasePresenceHandler):
# Ignore changes to join events. # Ignore changes to join events.
continue continue
newly_joined_users.add(state_key) newly_joined_users.add(delta.state_key)
if not newly_joined_users: if not newly_joined_users:
# If nobody has joined then there's nothing to do. # If nobody has joined then there's nothing to do.

View File

@ -16,7 +16,7 @@ import abc
import logging import logging
import random import random
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
from synapse import types from synapse import types
from synapse.api.constants import ( from synapse.api.constants import (
@ -44,6 +44,7 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.metrics import event_processing_positions from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
Requester, Requester,
@ -2146,24 +2147,18 @@ class RoomForgetterHandler(StateDeltasHandler):
await self._store.update_room_forgetter_stream_pos(max_pos) await self._store.update_room_forgetter_stream_pos(max_pos)
async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None: async def _handle_deltas(self, deltas: List[StateDelta]) -> None:
"""Called with the state deltas to process""" """Called with the state deltas to process"""
for delta in deltas: for delta in deltas:
typ = delta["type"] if delta.event_type != EventTypes.Member:
state_key = delta["state_key"]
room_id = delta["room_id"]
event_id = delta["event_id"]
prev_event_id = delta["prev_event_id"]
if typ != EventTypes.Member:
continue continue
if not self._hs.is_mine_id(state_key): if not self._hs.is_mine_id(delta.state_key):
continue continue
change = await self._get_key_change( change = await self._get_key_change(
prev_event_id, delta.prev_event_id,
event_id, delta.event_id,
key_name="membership", key_name="membership",
public_value=Membership.JOIN, public_value=Membership.JOIN,
) )
@ -2172,7 +2167,7 @@ class RoomForgetterHandler(StateDeltasHandler):
if is_leave: if is_leave:
try: try:
await self._room_member_handler.forget( await self._room_member_handler.forget(
UserID.from_string(state_key), room_id UserID.from_string(delta.state_key), delta.room_id
) )
except SynapseError as e: except SynapseError as e:
if e.code == 400: if e.code == 400:

View File

@ -27,6 +27,7 @@ from typing import (
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.metrics import event_processing_positions from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import JsonDict from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
@ -142,7 +143,7 @@ class StatsHandler:
self.pos = max_pos self.pos = max_pos
async def _handle_deltas( async def _handle_deltas(
self, deltas: Iterable[JsonDict] self, deltas: Iterable[StateDelta]
) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]: ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
"""Called with the state deltas to process """Called with the state deltas to process
@ -157,51 +158,50 @@ class StatsHandler:
room_to_state_updates: Dict[str, Dict[str, Any]] = {} room_to_state_updates: Dict[str, Dict[str, Any]] = {}
for delta in deltas: for delta in deltas:
typ = delta["type"] logger.debug(
state_key = delta["state_key"] "Handling: %r, %r %r, %s",
room_id = delta["room_id"] delta.room_id,
event_id = delta["event_id"] delta.event_type,
stream_id = delta["stream_id"] delta.state_key,
prev_event_id = delta["prev_event_id"] delta.event_id,
)
logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id) token = await self.store.get_earliest_token_for_stats("room", delta.room_id)
token = await self.store.get_earliest_token_for_stats("room", room_id)
# If the earliest token to begin from is larger than our current # If the earliest token to begin from is larger than our current
# stream ID, skip processing this delta. # stream ID, skip processing this delta.
if token is not None and token >= stream_id: if token is not None and token >= delta.stream_id:
logger.debug( logger.debug(
"Ignoring: %s as earlier than this room's initial ingestion event", "Ignoring: %s as earlier than this room's initial ingestion event",
event_id, delta.event_id,
) )
continue continue
if event_id is None and prev_event_id is None: if delta.event_id is None and delta.prev_event_id is None:
logger.error( logger.error(
"event ID is None and so is the previous event ID. stream_id: %s", "event ID is None and so is the previous event ID. stream_id: %s",
stream_id, delta.stream_id,
) )
continue continue
event_content: JsonDict = {} event_content: JsonDict = {}
if event_id is not None: if delta.event_id is not None:
event = await self.store.get_event(event_id, allow_none=True) event = await self.store.get_event(delta.event_id, allow_none=True)
if event: if event:
event_content = event.content or {} event_content = event.content or {}
# All the values in this dict are deltas (RELATIVE changes) # All the values in this dict are deltas (RELATIVE changes)
room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter()) room_stats_delta = room_to_stats_deltas.setdefault(delta.room_id, Counter())
room_state = room_to_state_updates.setdefault(room_id, {}) room_state = room_to_state_updates.setdefault(delta.room_id, {})
if prev_event_id is None: if delta.prev_event_id is None:
# this state event doesn't overwrite another, # this state event doesn't overwrite another,
# so it is a new effective/current state event # so it is a new effective/current state event
room_stats_delta["current_state_events"] += 1 room_stats_delta["current_state_events"] += 1
if typ == EventTypes.Member: if delta.event_type == EventTypes.Member:
# we could use StateDeltasHandler._get_key_change here but it's # we could use StateDeltasHandler._get_key_change here but it's
# a bit inefficient given we're not testing for a specific # a bit inefficient given we're not testing for a specific
# result; might as well just grab the prev_membership and # result; might as well just grab the prev_membership and
@ -210,9 +210,9 @@ class StatsHandler:
# in the absence of a previous event because we do not want to # in the absence of a previous event because we do not want to
# reduce the leave count when a new-to-the-room user joins. # reduce the leave count when a new-to-the-room user joins.
prev_membership = None prev_membership = None
if prev_event_id is not None: if delta.prev_event_id is not None:
prev_event = await self.store.get_event( prev_event = await self.store.get_event(
prev_event_id, allow_none=True delta.prev_event_id, allow_none=True
) )
if prev_event: if prev_event:
prev_event_content = prev_event.content prev_event_content = prev_event.content
@ -256,7 +256,7 @@ class StatsHandler:
else: else:
raise ValueError("%r is not a valid membership" % (membership,)) raise ValueError("%r is not a valid membership" % (membership,))
user_id = state_key user_id = delta.state_key
if self.is_mine_id(user_id): if self.is_mine_id(user_id):
# this accounts for transitions like leave → ban and so on. # this accounts for transitions like leave → ban and so on.
has_changed_joinedness = (prev_membership == Membership.JOIN) != ( has_changed_joinedness = (prev_membership == Membership.JOIN) != (
@ -272,30 +272,30 @@ class StatsHandler:
room_stats_delta["local_users_in_room"] += membership_delta room_stats_delta["local_users_in_room"] += membership_delta
elif typ == EventTypes.Create: elif delta.event_type == EventTypes.Create:
room_state["is_federatable"] = ( room_state["is_federatable"] = (
event_content.get(EventContentFields.FEDERATE, True) is True event_content.get(EventContentFields.FEDERATE, True) is True
) )
room_type = event_content.get(EventContentFields.ROOM_TYPE) room_type = event_content.get(EventContentFields.ROOM_TYPE)
if isinstance(room_type, str): if isinstance(room_type, str):
room_state["room_type"] = room_type room_state["room_type"] = room_type
elif typ == EventTypes.JoinRules: elif delta.event_type == EventTypes.JoinRules:
room_state["join_rules"] = event_content.get("join_rule") room_state["join_rules"] = event_content.get("join_rule")
elif typ == EventTypes.RoomHistoryVisibility: elif delta.event_type == EventTypes.RoomHistoryVisibility:
room_state["history_visibility"] = event_content.get( room_state["history_visibility"] = event_content.get(
"history_visibility" "history_visibility"
) )
elif typ == EventTypes.RoomEncryption: elif delta.event_type == EventTypes.RoomEncryption:
room_state["encryption"] = event_content.get("algorithm") room_state["encryption"] = event_content.get("algorithm")
elif typ == EventTypes.Name: elif delta.event_type == EventTypes.Name:
room_state["name"] = event_content.get("name") room_state["name"] = event_content.get("name")
elif typ == EventTypes.Topic: elif delta.event_type == EventTypes.Topic:
room_state["topic"] = event_content.get("topic") room_state["topic"] = event_content.get("topic")
elif typ == EventTypes.RoomAvatar: elif delta.event_type == EventTypes.RoomAvatar:
room_state["avatar"] = event_content.get("url") room_state["avatar"] = event_content.get("url")
elif typ == EventTypes.CanonicalAlias: elif delta.event_type == EventTypes.CanonicalAlias:
room_state["canonical_alias"] = event_content.get("alias") room_state["canonical_alias"] = event_content.get("alias")
elif typ == EventTypes.GuestAccess: elif delta.event_type == EventTypes.GuestAccess:
room_state["guest_access"] = event_content.get( room_state["guest_access"] = event_content.get(
EventContentFields.GUEST_ACCESS EventContentFields.GUEST_ACCESS
) )

View File

@ -14,7 +14,7 @@
import logging import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, List, Optional, Set, Tuple
from twisted.internet.interfaces import IDelayedCall from twisted.internet.interfaces import IDelayedCall
@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Memb
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.storage.databases.main.user_directory import SearchResult from synapse.storage.databases.main.user_directory import SearchResult
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import UserID from synapse.types import UserID
@ -247,32 +248,31 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_user_directory_stream_pos(max_pos) await self.store.update_user_directory_stream_pos(max_pos)
async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None: async def _handle_deltas(self, deltas: List[StateDelta]) -> None:
"""Called with the state deltas to process""" """Called with the state deltas to process"""
for delta in deltas: for delta in deltas:
typ = delta["type"] logger.debug(
state_key = delta["state_key"] "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
room_id = delta["room_id"] )
event_id: Optional[str] = delta["event_id"]
prev_event_id: Optional[str] = delta["prev_event_id"]
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
# For join rule and visibility changes we need to check if the room # For join rule and visibility changes we need to check if the room
# may have become public or not and add/remove the users in said room # may have become public or not and add/remove the users in said room
if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules): if delta.event_type in (
EventTypes.RoomHistoryVisibility,
EventTypes.JoinRules,
):
await self._handle_room_publicity_change( await self._handle_room_publicity_change(
room_id, prev_event_id, event_id, typ delta.room_id, delta.prev_event_id, delta.event_id, delta.event_type
) )
elif typ == EventTypes.Member: elif delta.event_type == EventTypes.Member:
await self._handle_room_membership_event( await self._handle_room_membership_event(
room_id, delta.room_id,
prev_event_id, delta.prev_event_id,
event_id, delta.event_id,
state_key, delta.state_key,
) )
else: else:
logger.debug("Ignoring irrelevant type: %r", typ) logger.debug("Ignoring irrelevant type: %r", delta.event_type)
async def _handle_room_publicity_change( async def _handle_room_publicity_change(
self, self,

View File

@ -16,7 +16,6 @@ from itertools import chain
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
AbstractSet, AbstractSet,
Any,
Callable, Callable,
Collection, Collection,
Dict, Dict,
@ -32,6 +31,7 @@ from typing import (
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace from synapse.logging.opentracing import tag_args, trace
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.storage.util.partial_state_events_tracker import ( from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker, PartialCurrentStateTracker,
@ -531,19 +531,9 @@ class StateStorageController:
@tag_args @tag_args
async def get_current_state_deltas( async def get_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]: ) -> Tuple[int, List[StateDelta]]:
"""Fetch a list of room state changes since the given stream id """Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
- stream_id (int)
- room_id (str)
- type (str): event type
- state_key (str):
- event_id (str|None): new event_id for this state key. None if the
state has been deleted.
- prev_event_id (str|None): previous event_id for this state key. None
if it's new state.
Args: Args:
prev_stream_id: point to get changes since (exclusive) prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted max_stream_id: the point that we know has been correctly persisted

View File

@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List, Tuple from typing import List, Optional, Tuple
import attr
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction from synapse.storage.database import LoggingTransaction
@ -22,6 +24,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class StateDelta:
stream_id: int
room_id: str
event_type: str
state_key: str
event_id: Optional[str]
"""new event_id for this state key. None if the state has been deleted."""
prev_event_id: Optional[str]
"""previous event_id for this state key. None if it's new state."""
class StateDeltasStore(SQLBaseStore): class StateDeltasStore(SQLBaseStore):
# This class must be mixed in with a child class which provides the following # This class must be mixed in with a child class which provides the following
# attribute. TODO: can we get static analysis to enforce this? # attribute. TODO: can we get static analysis to enforce this?
@ -29,19 +45,9 @@ class StateDeltasStore(SQLBaseStore):
async def get_partial_current_state_deltas( async def get_partial_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]: ) -> Tuple[int, List[StateDelta]]:
"""Fetch a list of room state changes since the given stream id """Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
- stream_id (int)
- room_id (str)
- type (str): event type
- state_key (str):
- event_id (str|None): new event_id for this state key. None if the
state has been deleted.
- prev_event_id (str|None): previous event_id for this state key. None
if it's new state.
This may be the partial state if we're lazy joining the room. This may be the partial state if we're lazy joining the room.
Args: Args:
@ -72,7 +78,7 @@ class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas_txn( def get_current_state_deltas_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[int, List[Dict[str, Any]]]: ) -> Tuple[int, List[StateDelta]]:
# First we calculate the max stream id that will give us less than # First we calculate the max stream id that will give us less than
# N results. # N results.
# We arbitrarily limit to 100 stream_id entries to ensure we don't # We arbitrarily limit to 100 stream_id entries to ensure we don't
@ -112,7 +118,17 @@ class StateDeltasStore(SQLBaseStore):
ORDER BY stream_id ASC ORDER BY stream_id ASC
""" """
txn.execute(sql, (prev_stream_id, clipped_stream_id)) txn.execute(sql, (prev_stream_id, clipped_stream_id))
return clipped_stream_id, self.db_pool.cursor_to_dict(txn) return clipped_stream_id, [
StateDelta(
stream_id=row[0],
room_id=row[1],
event_type=row[2],
state_key=row[3],
event_id=row[4],
prev_event_id=row[5],
)
for row in txn.fetchall()
]
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn "get_current_state_deltas", get_current_state_deltas_txn

View File

@ -174,7 +174,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return_value=1 return_value=1
) )
self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[method-assign] self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, [])) # type: ignore[method-assign]
self.datastore.get_to_device_stream_token = Mock( # type: ignore[method-assign] self.datastore.get_to_device_stream_token = Mock( # type: ignore[method-assign]
return_value=0 return_value=0