Avoid sending massive replication updates when purging a room. (#16510)

This commit is contained in:
Patrick Cloke 2023-10-18 12:26:01 -04:00 committed by GitHub
parent bcff01b406
commit 49c9745b45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 114 additions and 29 deletions

1
changelog.d/16510.misc Normal file
View File

@ -0,0 +1 @@
Improve replication performance when purging rooms.

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import heapq import heapq
from collections import defaultdict
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr import attr
@ -51,8 +52,19 @@ data part are:
* The state_key of the state which has changed * The state_key of the state which has changed
* The event id of the new state * The event id of the new state
A "state-all" row is sent whenever the "current state" in a room changes, but there are
too many state updates for a particular room in the same update. This replaces any
"state" rows on a per-room basis. The fields in the data part are:
* The room id for the state changes
""" """
# Any room with more than _MAX_STATE_UPDATES_PER_ROOM will send a EventsStreamAllStateRow
# instead of individual EventsStreamEventRow. This is predominantly useful when
# purging large rooms.
_MAX_STATE_UPDATES_PER_ROOM = 150
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamRow: class EventsStreamRow:
@ -111,9 +123,17 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id: Optional[str] event_id: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamAllStateRow(BaseEventsStreamRow):
TypeId = "state-all"
room_id: str
_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = ( _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
EventsStreamEventRow, EventsStreamEventRow,
EventsStreamCurrentStateRow, EventsStreamCurrentStateRow,
EventsStreamAllStateRow,
) )
TypeToRow = {Row.TypeId: Row for Row in _EventRows} TypeToRow = {Row.TypeId: Row for Row in _EventRows}
@ -213,9 +233,28 @@ class EventsStream(Stream):
if stream_id <= upper_limit if stream_id <= upper_limit
) )
# Separate out rooms that have many state updates, listeners should clear
# all state for those rooms.
state_updates_by_room = defaultdict(list)
for stream_id, room_id, _type, _state_key, _event_id in state_rows:
state_updates_by_room[room_id].append(stream_id)
state_all_rows = [
(stream_ids[-1], room_id)
for room_id, stream_ids in state_updates_by_room.items()
if len(stream_ids) >= _MAX_STATE_UPDATES_PER_ROOM
]
state_all_updates: Iterable[Tuple[int, Tuple]] = (
(max_stream_id, (EventsStreamAllStateRow.TypeId, (room_id,)))
for (max_stream_id, room_id) in state_all_rows
)
# Any remaining state updates are sent individually.
state_all_rooms = {room_id for _, room_id in state_all_rows}
state_updates: Iterable[Tuple[int, Tuple]] = ( state_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) (stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
for (stream_id, *rest) in state_rows for (stream_id, *rest) in state_rows
if rest[0] not in state_all_rooms
) )
ex_outliers_updates: Iterable[Tuple[int, Tuple]] = ( ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
@ -224,7 +263,11 @@ class EventsStream(Stream):
) )
# we need to return a sorted list, so merge them together. # we need to return a sorted list, so merge them together.
updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) updates = list(
heapq.merge(
event_updates, state_all_updates, state_updates, ex_outliers_updates
)
)
return updates, upper_limit, limited return updates, upper_limit, limited
@classmethod @classmethod

View File

@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.replication.tcp.streams import BackfillStream, CachesStream from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import ( from synapse.replication.tcp.streams.events import (
EventsStream, EventsStream,
EventsStreamAllStateRow,
EventsStreamCurrentStateRow, EventsStreamCurrentStateRow,
EventsStreamEventRow, EventsStreamEventRow,
EventsStreamRow, EventsStreamRow,
@ -264,6 +265,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
(data.state_key,) (data.state_key,)
) )
self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined] self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined]
elif row.type == EventsStreamAllStateRow.TypeId:
assert isinstance(data, EventsStreamAllStateRow)
# Similar to the above, but the entire caches are invalidated. This is
# unfortunate for the membership caches, but should recover quickly.
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined]
self.get_rooms_for_user_with_stream_ordering.invalidate_all() # type: ignore[attr-defined]
self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined]
else: else:
raise Exception("Unknown events stream row type %s" % (row.type,)) raise Exception("Unknown events stream row type %s" % (row.type,))

View File

@ -14,6 +14,8 @@
from typing import Any, List, Optional from typing import Any, List, Optional
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -21,6 +23,8 @@ from synapse.events import EventBase
from synapse.replication.tcp.commands import RdataCommand from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import ( from synapse.replication.tcp.streams.events import (
_MAX_STATE_UPDATES_PER_ROOM,
EventsStreamAllStateRow,
EventsStreamCurrentStateRow, EventsStreamCurrentStateRow,
EventsStreamEventRow, EventsStreamEventRow,
EventsStreamRow, EventsStreamRow,
@ -106,11 +110,21 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows) self.assertEqual([], received_rows)
def test_update_function_huge_state_change(self) -> None: @parameterized.expand(
[(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)]
)
def test_update_function_huge_state_change(
self, num_state_changes: int, collapse_state_changes: bool
) -> None:
"""Test replication with many state events """Test replication with many state events
Ensures that all events are correctly replicated when there are lots of Ensures that all events are correctly replicated when there are lots of
state change rows to be replicated. state change rows to be replicated.
Args:
num_state_changes: The number of state changes to create.
collapse_state_changes: Whether the state changes are expected to be
collapsed or not.
""" """
# we want to generate lots of state changes at a single stream ID. # we want to generate lots of state changes at a single stream ID.
@ -145,7 +159,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
events = [ events = [
self._inject_state_event(sender=OTHER_USER) self._inject_state_event(sender=OTHER_USER)
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) for _ in range(num_state_changes)
] ]
self.replicate() self.replicate()
@ -202,8 +216,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
row for row in self.test_handler.received_rdata_rows if row[0] == "events" row for row in self.test_handler.received_rdata_rows if row[0] == "events"
] ]
# first check the first two rows, which should be state1 # first check the first two rows, which should be the state1 event.
stream_name, token, row = received_rows.pop(0) stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name) self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow) self.assertIsInstance(row, EventsStreamRow)
@ -217,7 +230,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertIsInstance(row.data, EventsStreamCurrentStateRow) self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state1.event_id) self.assertEqual(row.data.event_id, state1.event_id)
# now the last two rows, which should be state2 # now the last two rows, which should be the state2 event.
stream_name, token, row = received_rows.pop(-2) stream_name, token, row = received_rows.pop(-2)
self.assertEqual("events", stream_name) self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow) self.assertIsInstance(row, EventsStreamRow)
@ -231,6 +244,26 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertIsInstance(row.data, EventsStreamCurrentStateRow) self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state2.event_id) self.assertEqual(row.data.event_id, state2.event_id)
# Based on the number of
if collapse_state_changes:
# that should leave us with the rows for the PL event, the state changes
# get collapsed into a single row.
self.assertEqual(len(received_rows), 2)
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_event.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state-all")
self.assertIsInstance(row.data, EventsStreamAllStateRow)
self.assertEqual(row.data.room_id, state2.room_id)
else:
# that should leave us with the rows for the PL event # that should leave us with the rows for the PL event
self.assertEqual(len(received_rows), len(events) + 2) self.assertEqual(len(received_rows), len(events) + 2)