Some refactors around receipts stream (#16426)

This commit is contained in:
Erik Johnston 2023-10-04 18:28:40 +03:00 committed by GitHub
parent a01ee24734
commit 80ec81dcc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 111 additions and 80 deletions

1
changelog.d/16426.misc Normal file
View File

@ -0,0 +1 @@
Refactor some code to simplify and better type receipts stream adjacent code.

View File

@ -216,7 +216,7 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral( def notify_interested_services_ephemeral(
self, self,
stream_key: str, stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken], new_token: Union[int, RoomStreamToken],
users: Collection[Union[str, UserID]], users: Collection[Union[str, UserID]],
) -> None: ) -> None:
@ -326,7 +326,7 @@ class ApplicationServicesHandler:
async def _notify_interested_services_ephemeral( async def _notify_interested_services_ephemeral(
self, self,
services: List[ApplicationService], services: List[ApplicationService],
stream_key: str, stream_key: StreamKeyType,
new_token: int, new_token: int,
users: Collection[Union[str, UserID]], users: Collection[Union[str, UserID]],
) -> None: ) -> None:

View File

@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.push_rule import RuleNotFoundException from synapse.storage.push_rule import RuleNotFoundException
from synapse.synapse_rust.push import get_base_rule_ids from synapse.synapse_rust.push import get_base_rule_ids
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, StreamKeyType, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -114,7 +114,9 @@ class PushRulesHandler:
user_id: the user ID the change is for. user_id: the user ID the change is for.
""" """
stream_id = self._main_store.get_max_push_rules_stream_id() stream_id = self._main_store.get_max_push_rules_stream_id()
self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) self._notifier.on_new_event(
StreamKeyType.PUSH_RULES, stream_id, users=[user_id]
)
async def push_rules_for_user( async def push_rules_for_user(
self, user: UserID self, user: UserID

View File

@ -130,11 +130,10 @@ class ReceiptsHandler:
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier.""" """Takes a list of receipts, stores them and informs the notifier."""
min_batch_id: Optional[int] = None
max_batch_id: Optional[int] = None
receipts_persisted: List[ReadReceipt] = []
for receipt in receipts: for receipt in receipts:
res = await self.store.insert_receipt( stream_id = await self.store.insert_receipt(
receipt.room_id, receipt.room_id,
receipt.receipt_type, receipt.receipt_type,
receipt.user_id, receipt.user_id,
@ -143,30 +142,26 @@ class ReceiptsHandler:
receipt.data, receipt.data,
) )
if not res: if stream_id is None:
# res will be None if this receipt is 'old' # stream_id will be None if this receipt is 'old'
continue continue
stream_id, max_persisted_id = res receipts_persisted.append(receipt)
if min_batch_id is None or stream_id < min_batch_id: if not receipts_persisted:
min_batch_id = stream_id
if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id
# Either both of these should be None or neither.
if min_batch_id is None or max_batch_id is None:
# no new receipts # no new receipts
return False return False
affected_room_ids = list({r.room_id for r in receipts}) max_batch_id = self.store.get_max_receipt_stream_id()
affected_room_ids = list({r.room_id for r in receipts_persisted})
self.notifier.on_new_event( self.notifier.on_new_event(
StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
) )
# Note that the min here shouldn't be relied upon to be accurate. # Note that the min here shouldn't be relied upon to be accurate.
await self.hs.get_pusherpool().on_new_receipts( await self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids {r.user_id for r in receipts_persisted}
) )
return True return True

View File

@ -126,7 +126,7 @@ class _NotifierUserStream:
def notify( def notify(
self, self,
stream_key: str, stream_key: StreamKeyType,
stream_id: Union[int, RoomStreamToken], stream_id: Union[int, RoomStreamToken],
time_now_ms: int, time_now_ms: int,
) -> None: ) -> None:
@ -454,7 +454,7 @@ class Notifier:
def on_new_event( def on_new_event(
self, self,
stream_key: str, stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken], new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None, users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None, rooms: Optional[StrCollection] = None,
@ -655,30 +655,29 @@ class Notifier:
events: List[Union[JsonDict, EventBase]] = [] events: List[Union[JsonDict, EventBase]] = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.get_sources(): for keyname, source in self.event_sources.sources.get_sources():
keyname = "%s_key" % name before_id = before_token.get_field(keyname)
before_id = getattr(before_token, keyname) after_id = after_token.get_field(keyname)
after_id = getattr(after_token, keyname)
if before_id == after_id: if before_id == after_id:
continue continue
new_events, new_key = await source.get_new_events( new_events, new_key = await source.get_new_events(
user=user, user=user,
from_key=getattr(from_token, keyname), from_key=from_token.get_field(keyname),
limit=limit, limit=limit,
is_guest=is_peeking, is_guest=is_peeking,
room_ids=room_ids, room_ids=room_ids,
explicit_room_id=explicit_room_id, explicit_room_id=explicit_room_id,
) )
if name == "room": if keyname == StreamKeyType.ROOM:
new_events = await filter_events_for_client( new_events = await filter_events_for_client(
self._storage_controllers, self._storage_controllers,
user.to_string(), user.to_string(),
new_events, new_events,
is_peeking=is_peeking, is_peeking=is_peeking,
) )
elif name == "presence": elif keyname == StreamKeyType.PRESENCE:
now = self.clock.time_msec() now = self.clock.time_msec()
new_events[:] = [ new_events[:] = [
{ {

View File

@ -182,7 +182,7 @@ class Pusher(metaclass=abc.ABCMeta):
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: def on_new_receipts(self) -> None:
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod

View File

@ -99,7 +99,7 @@ class EmailPusher(Pusher):
pass pass
self.timed_call = None self.timed_call = None
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: def on_new_receipts(self) -> None:
# We could wake up and cancel the timer but there tend to be quite a # We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the # lot of read receipts so it's probably less work to just let the
# timer fire # timer fire

View File

@ -160,7 +160,7 @@ class HttpPusher(Pusher):
if should_check_for_notifs: if should_check_for_notifs:
self._start_processing() self._start_processing()
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: def on_new_receipts(self) -> None:
# Note that the min here shouldn't be relied upon to be accurate. # Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here, # We could check the receipts are actually m.read receipts here,

View File

@ -292,20 +292,12 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
async def on_new_receipts( async def on_new_receipts(self, users_affected: StrCollection) -> None:
self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
) -> None:
if not self.pushers: if not self.pushers:
# nothing to do here. # nothing to do here.
return return
try: try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
users_affected = await self.store.get_users_sent_receipts_between(
min_stream_id - 1, max_stream_id
)
for u in users_affected: for u in users_affected:
# Don't push if the user account has expired # Don't push if the user account has expired
expired = await self._account_validity_handler.is_user_expired(u) expired = await self._account_validity_handler.is_user_expired(u)
@ -314,7 +306,7 @@ class PusherPool:
if u in self.pushers: if u in self.pushers:
for p in self.pushers[u].values(): for p in self.pushers[u].values():
p.on_new_receipts(min_stream_id, max_stream_id) p.on_new_receipts()
except Exception: except Exception:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")

View File

@ -129,9 +129,7 @@ class ReplicationDataHandler:
self.notifier.on_new_event( self.notifier.on_new_event(
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows] StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
) )
await self._pusher_pool.on_new_receipts( await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
token, token, {row.room_id for row in rows}
)
elif stream_name == ToDeviceStream.NAME: elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")] entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities: if entities:

View File

@ -208,7 +208,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
"message": "Set room key", "message": "Set room key",
"room_id": room_id, "room_id": room_id,
"session_id": session_id, "session_id": session_id,
StreamKeyType.ROOM: room_key, StreamKeyType.ROOM.value: room_key,
} }
) )

View File

@ -742,7 +742,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str], event_ids: List[str],
thread_id: Optional[str], thread_id: Optional[str],
data: dict, data: dict,
) -> Optional[Tuple[int, int]]: ) -> Optional[int]:
"""Insert a receipt, either from local client or remote server. """Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph Automatically does conversion between linearized and graph
@ -804,9 +804,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data, data,
) )
max_persisted_id = self._receipts_id_gen.get_current_token() return stream_id
return stream_id, max_persisted_id
async def _insert_graph_receipt( async def _insert_graph_receipt(
self, self,

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Iterator, Tuple from typing import TYPE_CHECKING, Sequence, Tuple
import attr import attr
@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace
from synapse.streams import EventSource from synapse.streams import EventSource
from synapse.types import StreamToken from synapse.types import StreamKeyType, StreamToken
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -37,9 +37,14 @@ class _EventSourcesInner:
receipt: ReceiptEventSource receipt: ReceiptEventSource
account_data: AccountDataEventSource account_data: AccountDataEventSource
def get_sources(self) -> Iterator[Tuple[str, EventSource]]: def get_sources(self) -> Sequence[Tuple[StreamKeyType, EventSource]]:
for attribute in attr.fields(_EventSourcesInner): return [
yield attribute.name, getattr(self, attribute.name) (StreamKeyType.ROOM, self.room),
(StreamKeyType.PRESENCE, self.presence),
(StreamKeyType.TYPING, self.typing),
(StreamKeyType.RECEIPT, self.receipt),
(StreamKeyType.ACCOUNT_DATA, self.account_data),
]
class EventSources: class EventSources:

View File

@ -22,8 +22,8 @@ from typing import (
Any, Any,
ClassVar, ClassVar,
Dict, Dict,
Final,
List, List,
Literal,
Mapping, Mapping,
Match, Match,
MutableMapping, MutableMapping,
@ -34,6 +34,7 @@ from typing import (
Type, Type,
TypeVar, TypeVar,
Union, Union,
overload,
) )
import attr import attr
@ -649,20 +650,20 @@ class RoomStreamToken:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
class StreamKeyType: class StreamKeyType(Enum):
"""Known stream types. """Known stream types.
A stream is a list of entities ordered by an incrementing "stream token". A stream is a list of entities ordered by an incrementing "stream token".
""" """
ROOM: Final = "room_key" ROOM = "room_key"
PRESENCE: Final = "presence_key" PRESENCE = "presence_key"
TYPING: Final = "typing_key" TYPING = "typing_key"
RECEIPT: Final = "receipt_key" RECEIPT = "receipt_key"
ACCOUNT_DATA: Final = "account_data_key" ACCOUNT_DATA = "account_data_key"
PUSH_RULES: Final = "push_rules_key" PUSH_RULES = "push_rules_key"
TO_DEVICE: Final = "to_device_key" TO_DEVICE = "to_device_key"
DEVICE_LIST: Final = "device_list_key" DEVICE_LIST = "device_list_key"
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
@ -784,7 +785,7 @@ class StreamToken:
def room_stream_id(self) -> int: def room_stream_id(self) -> int:
return self.room_key.stream return self.room_key.stream
def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the """Advance the given key in the token to a new value if and only if the
new value is after the old value. new value is after the old value.
@ -797,16 +798,44 @@ class StreamToken:
return new_token return new_token
new_token = self.copy_and_replace(key, new_value) new_token = self.copy_and_replace(key, new_value)
new_id = int(getattr(new_token, key)) new_id = new_token.get_field(key)
old_id = int(getattr(self, key)) old_id = self.get_field(key)
if old_id < new_id: if old_id < new_id:
return new_token return new_token
else: else:
return self return self
def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": def copy_and_replace(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
return attr.evolve(self, **{key: new_value}) return attr.evolve(self, **{key.value: new_value})
@overload
def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken:
...
@overload
def get_field(
self,
key: Literal[
StreamKeyType.ACCOUNT_DATA,
StreamKeyType.DEVICE_LIST,
StreamKeyType.PRESENCE,
StreamKeyType.PUSH_RULES,
StreamKeyType.RECEIPT,
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
],
) -> int:
...
@overload
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
...
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
"""Returns the stream ID for the given key."""
return getattr(self, key.value)
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0) StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)

View File

@ -31,7 +31,7 @@ from synapse.appservice import (
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, RoomStreamToken from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
from synapse.util import Clock from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
) )
self.handler.notify_interested_services_ephemeral( self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"] StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
) )
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, ephemeral=[event] interested_service, ephemeral=[event]
@ -332,7 +332,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
) )
self.handler.notify_interested_services_ephemeral( self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"] StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
) )
# This method will be called, but with an empty list of events # This method will be called, but with an empty list of events
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@ -634,7 +634,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.get_success( self.get_success(
self.hs.get_application_service_handler()._notify_interested_services_ephemeral( self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
services=[interested_appservice], services=[interested_appservice],
stream_key="receipt_key", stream_key=StreamKeyType.RECEIPT,
new_token=stream_token, new_token=stream_token,
users=[self.exclusive_as_user], users=[self.exclusive_as_user],
) )

View File

@ -28,7 +28,7 @@ from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.typing import TypingWriterHandler from synapse.handlers.typing import TypingWriterHandler
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -203,7 +203,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
) )
) )
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
)
self.assertEqual(self.event_source.get_current_key(), 1) self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success( events = self.get_success(
@ -273,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
)
self.assertEqual(self.event_source.get_current_key(), 1) self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success( events = self.get_success(
@ -349,7 +353,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
) )
) )
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
)
self.mock_federation_client.put_json.assert_called_once_with( self.mock_federation_client.put_json.assert_called_once_with(
"farm", "farm",
@ -399,7 +405,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
) )
) )
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
)
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEqual(self.event_source.get_current_key(), 1) self.assertEqual(self.event_source.get_current_key(), 1)
@ -425,7 +433,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.reactor.pump([16]) self.reactor.pump([16])
self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 2, rooms=[ROOM_ID])]
)
self.assertEqual(self.event_source.get_current_key(), 2) self.assertEqual(self.event_source.get_current_key(), 2)
events = self.get_success( events = self.get_success(
@ -459,7 +469,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
) )
) )
self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 3, rooms=[ROOM_ID])]
)
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEqual(self.event_source.get_current_key(), 3) self.assertEqual(self.event_source.get_current_key(), 3)