diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index cbddc80ca9..477e16e0fa 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -23,7 +23,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.http.site import SynapseSite from synapse.federation import send_queue -from synapse.federation.units import Edu from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.events import SlavedEventStore @@ -33,7 +32,6 @@ from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.storage.engines import create_engine -from synapse.storage.presence import UserPresenceState from synapse.util.async import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn @@ -51,7 +49,6 @@ from daemonize import Daemonize import sys import logging import gc -import ujson as json logger = logging.getLogger("synapse.app.appservice") @@ -278,70 +275,7 @@ class FederationSenderHandler(object): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. if stream_name == "federation": - # The federation stream containis a bunch of different types of - # rows that need to be handled differently. We parse the rows, put - # them into the appropriate collection and then send them off. - presence_to_send = {} - keyed_edus = {} - edus = {} - failures = {} - device_destinations = set() - - # Parse the rows in the stream - for row in rows: - typ = row.type - content_js = row.data - content = json.loads(content_js) - - if typ == send_queue.PRESENCE_TYPE: - destination = content["destination"] - state = UserPresenceState.from_dict(content["state"]) - - presence_to_send.setdefault(destination, []).append(state) - elif typ == send_queue.KEYED_EDU_TYPE: - key = content["key"] - edu = Edu(**content["edu"]) - - keyed_edus.setdefault( - edu.destination, {} - )[(edu.destination, tuple(key))] = edu - elif typ == send_queue.EDU_TYPE: - edu = Edu(**content) - - edus.setdefault(edu.destination, []).append(edu) - elif typ == send_queue.FAILURE_TYPE: - destination = content["destination"] - failure = content["failure"] - - failures.setdefault(destination, []).append(failure) - elif typ == send_queue.DEVICE_MESSAGE_TYPE: - device_destinations.add(content["destination"]) - else: - raise Exception("Unrecognised federation type: %r", typ) - - # We've finished collecting, send everything off - for destination, states in presence_to_send.items(): - self.federation_sender.send_presence(destination, states) - - for destination, edu_map in keyed_edus.items(): - for key, edu in edu_map.items(): - self.federation_sender.send_edu( - edu.destination, edu.edu_type, edu.content, key=key, - ) - - for destination, edu_list in edus.items(): - for edu in edu_list: - self.federation_sender.send_edu( - edu.destination, edu.edu_type, edu.content, key=None, - ) - - for destination, failure_list in failures.items(): - for failure in failure_list: - self.federation_sender.send_failure(destination, failure) - - for destination in device_destinations: - self.federation_sender.send_device_messages(destination) - + send_queue.process_rows_for_federation(self.federation_sender, rows) preserve_fn(self.update_token)(token) # We also need to poke the federation sender when new events happen diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index a1ef5dfa77..d39e3161fe 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -62,7 +62,6 @@ import sys import logging import contextlib import gc -import ujson as json logger = logging.getLogger("synapse.app.synchrotron") @@ -120,12 +119,53 @@ class SynchrotronPresence(object): for state in active_presence } + # user_id -> last_sync_ms. Lists the users that have stopped syncing + # but we haven't notified the master of that yet + self.users_going_offline = {} + + self._send_stop_syncing_loop = self.clock.looping_call( + self.send_stop_syncing, 10 * 1000 + ) + self.process_id = random_string(16) logger.info("Presence process_id is %r", self.process_id) def send_user_sync(self, user_id, is_syncing, last_sync_ms): self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms) + def mark_as_coming_online(self, user_id): + """A user has started syncing. Send a UserSync to the master, unless they + had recently stopped syncing. + + Args: + user_id (str) + """ + going_offline = self.users_going_offline.pop(user_id, None) + if not going_offline: + # Safe to skip because we haven't yet told the master they were offline + self.send_user_sync(user_id, True, self.clock.time_msec()) + + def mark_as_going_offline(self, user_id): + """A user has stopped syncing. We wait before notifying the master as + its likely they'll come back soon. This allows us to avoid sending + a stopped syncing immediately followed by a started syncing notification + to the master + + Args: + user_id (str) + """ + self.users_going_offline[user_id] = self.clock.time_msec() + + def send_stop_syncing(self): + """Check if there are any users who have stopped syncing a while ago + and haven't come back yet. If there are poke the master about them. + """ + now = self.clock.time_msec() + for user_id, last_sync_ms in self.users_going_offline.items(): + if now - last_sync_ms > 10 * 1000: + self.users_going_offline.pop(user_id, None) + self.send_user_sync(user_id, False, last_sync_ms) + def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? pass @@ -142,8 +182,7 @@ class SynchrotronPresence(object): # If we went from no in flight sync to some, notify replication if self.user_to_num_current_syncs[user_id] == 1: - now = self.clock.time_msec() - self.send_user_sync(user_id, True, now) + self.mark_as_coming_online(user_id) def _end(): # We check that the user_id is in user_to_num_current_syncs because @@ -154,8 +193,7 @@ class SynchrotronPresence(object): # If we went from one in flight sync to non, notify replication if self.user_to_num_current_syncs[user_id] == 0: - now = self.clock.time_msec() - self.send_user_sync(user_id, False, now) + self.mark_as_going_offline(user_id) @contextlib.contextmanager def _user_syncing(): @@ -215,9 +253,8 @@ class SynchrotronTyping(object): self._latest_room_serial = token for row in rows: - typing = json.loads(row.user_ids) self._room_serials[row.room_id] = token - self._room_typing[row.room_id] = typing + self._room_typing[row.room_id] = row.user_ids class SynchrotronApplicationService(object): diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index e8218d01ad..8223734845 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -125,7 +125,7 @@ def main(): "configfile", nargs="?", default="homeserver.yaml", - help="the homeserver config file, defaults to homserver.yaml", + help="the homeserver config file, defaults to homeserver.yaml", ) parser.add_argument( "-w", "--worker", diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 4bde66fbf8..748548bbe2 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -31,23 +31,21 @@ Events are replicated via a separate events stream. from .units import Edu +from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure import synapse.metrics from blist import sorteddict -import ujson +from collections import namedtuple + +import logging + +logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -PRESENCE_TYPE = "p" -KEYED_EDU_TYPE = "k" -EDU_TYPE = "e" -FAILURE_TYPE = "f" -DEVICE_MESSAGE_TYPE = "d" - - class FederationRemoteSendQueue(object): """A drop in replacement for TransactionQueue""" @@ -240,6 +238,8 @@ class FederationRemoteSendQueue(object): if from_token > self.pos: from_token = -1 + # list of tuple(int, BaseFederationRow), where the first is the position + # of the federation stream. rows = [] # There should be only one reader, so lets delete everything its @@ -258,10 +258,10 @@ class FederationRemoteSendQueue(object): ) for (key, (dest, user_id)) in dest_user_ids: - rows.append((key, PRESENCE_TYPE, ujson.dumps({ - "destination": dest, - "state": self.presence_map[user_id].as_dict(), - }))) + rows.append((key, PresenceRow( + destination=dest, + state=self.presence_map[user_id], + ))) # Fetch changes keyed edus keys = self.keyed_edu_changed.keys() @@ -270,12 +270,10 @@ class FederationRemoteSendQueue(object): keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:j]) for (pos, (destination, edu_key)) in keyed_edus: - rows.append( - (pos, KEYED_EDU_TYPE, ujson.dumps({ - "key": edu_key, - "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(), - })) - ) + rows.append((pos, KeyedEduRow( + key=edu_key, + edu=self.keyed_edu[(destination, edu_key)], + ))) # Fetch changed edus keys = self.edus.keys() @@ -284,7 +282,7 @@ class FederationRemoteSendQueue(object): edus = set((k, self.edus[k]) for k in keys[i:j]) for (pos, edu) in edus: - rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict()))) + rows.append((pos, EduRow(edu))) # Fetch changed failures keys = self.failures.keys() @@ -293,10 +291,10 @@ class FederationRemoteSendQueue(object): failures = set((k, self.failures[k]) for k in keys[i:j]) for (pos, (destination, failure)) in failures: - rows.append((pos, FAILURE_TYPE, ujson.dumps({ - "destination": destination, - "failure": failure, - }))) + rows.append((pos, FailureRow( + destination=destination, + failure=failure, + ))) # Fetch changed device messages keys = self.device_messages.keys() @@ -305,11 +303,229 @@ class FederationRemoteSendQueue(object): device_messages = set((k, self.device_messages[k]) for k in keys[i:j]) for (pos, destination) in device_messages: - rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({ - "destination": destination, - }))) + rows.append((pos, DeviceRow( + destination=destination, + ))) # Sort rows based on pos rows.sort() - return rows + return [(pos, row.TypeId, row.to_data()) for pos, row in rows] + + +class BaseFederationRow(object): + """Base class for rows to be sent in the federation stream. + + Specifies how to identify, serialize and deserialize the different types. + """ + + TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + + @staticmethod + def from_data(data): + """Parse the data from the federation stream into a row. + + Args: + data: The value of ``data`` from FederationStreamRow.data, type + depends on the type of stream + """ + raise NotImplementedError() + + def to_data(self): + """Serialize this row to be sent over the federation stream. + + Returns: + The value to be sent in FederationStreamRow.data. The type depends + on the type of stream. + """ + raise NotImplementedError() + + def add_to_buffer(self, buff): + """Add this row to the appropriate field in the buffer ready for this + to be sent over federation. + + We use a buffer so that we can batch up events that have come in at + the same time and send them all at once. + + Args: + buff (BufferedToSend) + """ + raise NotImplementedError() + + +class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( + "destination", # str + "state", # UserPresenceState +))): + TypeId = "p" + + @staticmethod + def from_data(data): + return PresenceRow( + destination=data["destination"], + state=UserPresenceState.from_dict(data["state"]) + ) + + def to_data(self): + return { + "destination": self.destination, + "state": self.state.as_dict() + } + + def add_to_buffer(self, buff): + buff.presence.setdefault(self.destination, []).append(self.state) + + +class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( + "key", # tuple(str) - the edu key passed to send_edu + "edu", # Edu +))): + TypeId = "k" + + @staticmethod + def from_data(data): + return KeyedEduRow( + key=tuple(data["key"]), + edu=Edu(**data["edu"]), + ) + + def to_data(self): + return { + "key": self.key, + "edu": self.edu.get_internal_dict(), + } + + def add_to_buffer(self, buff): + buff.keyed_edus.setdefault( + self.edu.destination, {} + )[self.key] = self.edu + + +class EduRow(BaseFederationRow, namedtuple("EduRow", ( + "edu", # Edu +))): + TypeId = "e" + + @staticmethod + def from_data(data): + return EduRow(Edu(**data)) + + def to_data(self): + return self.edu.get_internal_dict() + + def add_to_buffer(self, buff): + buff.edus.setdefault(self.edu.destination, []).append(self.edu) + + +class FailureRow(BaseFederationRow, namedtuple("FailureRow", ( + "destination", # str + "failure", +))): + TypeId = "f" + + @staticmethod + def from_data(data): + return FailureRow( + destination=data["destination"], + failure=data["failure"], + ) + + def to_data(self): + return { + "destination": self.destination, + "failure": self.failure, + } + + def add_to_buffer(self, buff): + buff.failures.setdefault(self.destination, []).append(self.failure) + + +class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( + "destination", # str +))): + TypeId = "d" + + @staticmethod + def from_data(data): + return DeviceRow(destination=data["destination"]) + + def to_data(self): + return {"destination": self.destination} + + def add_to_buffer(self, buff): + buff.device_destinations.add(self.destination) + + +TypeToRow = { + Row.TypeId: Row + for Row in ( + PresenceRow, + KeyedEduRow, + EduRow, + FailureRow, + DeviceRow, + ) +} + + +ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( + "presence", # dict of destination -> [UserPresenceState] + "keyed_edus", # dict of destination -> { key -> Edu } + "edus", # dict of destination -> [Edu] + "failures", # dict of destination -> [failures] + "device_destinations", # set of destinations +)) + + +def process_rows_for_federation(transaction_queue, rows): + """Parse a list of rows from the federation stream and put them in the + transaction queue ready for sending to the relevant homeservers. + + Args: + transaction_queue (TransactionQueue) + rows (list(synapse.replication.tcp.streams.FederationStreamRow)) + """ + + # The federation stream contains a bunch of different types of + # rows that need to be handled differently. We parse the rows, put + # them into the appropriate collection and then send them off. + + buff = ParsedFederationStreamData( + presence={}, + keyed_edus={}, + edus={}, + failures={}, + device_destinations=set(), + ) + + # Parse the rows in the stream and add to the buffer + for row in rows: + if row.type not in TypeToRow: + logger.error("Unrecognized federation row type %r", row.type) + continue + + RowType = TypeToRow[row.type] + parsed_row = RowType.from_data(row.data) + parsed_row.add_to_buffer(buff) + + for destination, states in buff.presence.iteritems(): + transaction_queue.send_presence(destination, states) + + for destination, edu_map in buff.keyed_edus.iteritems(): + for key, edu in edu_map.items(): + transaction_queue.send_edu( + edu.destination, edu.edu_type, edu.content, key=key, + ) + + for destination, edu_list in buff.edus.iteritems(): + for edu in edu_list: + transaction_queue.send_edu( + edu.destination, edu.edu_type, edu.content, key=None, + ) + + for destination, failure_list in buff.failures.iteritems(): + for failure in failure_list: + transaction_queue.send_failure(destination, failure) + + for destination in buff.device_destinations: + transaction_queue.send_device_messages(destination) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index d64e16670c..3946b01dd5 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -599,14 +599,14 @@ class PresenceHandler(object): for user_id in user_ids } - missing = [user_id for user_id, state in states.items() if not state] + missing = [user_id for user_id, state in states.iteritems() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = yield self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.items() if not state] + missing = [user_id for user_id, state in states.iteritems() if not state] if missing: new = { user_id: UserPresenceState.default(user_id) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d6809862e0..3b7818af5c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -24,7 +24,6 @@ from synapse.types import UserID, get_domain_from_id import logging from collections import namedtuple -import ujson as json logger = logging.getLogger(__name__) @@ -288,8 +287,7 @@ class TypingHandler(object): for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: typing = self._room_typing[room_id] - typing_bytes = json.dumps(list(typing), ensure_ascii=False) - rows.append((serial, room_id, typing_bytes)) + rows.append((serial, room_id, list(typing))) rows.sort() return rows diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 251d3afcf4..90fb6c1336 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -175,7 +175,7 @@ class ReplicationClientHandler(object): def send_invalidate_cache(self, cache_func, keys): """Poke the master to invalidate a cache. """ - cmd = InvalidateCacheCommand(cache_func, keys) + cmd = InvalidateCacheCommand(cache_func.__name__, keys) self.send_command(cmd) def await_sync(self, data): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index d4d672aafe..5770b7125a 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -85,6 +85,8 @@ logger = logging.getLogger(__name__) PING_TIME = 5000 +PING_TIMEOUT_MULTIPLIER = 5 +PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER class ConnectionStates(object): @@ -166,7 +168,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): now = self.clock.time_msec() if self.time_we_closed: - if now - self.time_we_closed > PING_TIME * 3: + if now - self.time_we_closed > PING_TIMEOUT_MS: logger.info( "[%s] Failed to close connection gracefully, aborting", self.id() ) @@ -175,7 +177,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if now - self.last_sent_command >= PING_TIME: self.send_command(PingCommand(now)) - if self.received_ping and now - self.last_received_command > PING_TIME * 3: + if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS: logger.info( "[%s] Connection hasn't received command in %r ms. Closing.", self.id(), now - self.last_received_command @@ -220,6 +222,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): logger.exception("[%s] Failed to handle line: %r", self.id(), line) def close(self): + logger.warn("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() self.transport.loseConnection() self.on_connection_closed() @@ -505,7 +508,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) - self.transport.abortConnection() + self.send_error("Wrong remote") def on_RDATA(self, cmd): try: diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index 4de4ebe84d..369d5f2428 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -36,34 +36,82 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 10000 -EventStreamRow = namedtuple("EventStreamRow", - ("event_id", "room_id", "type", "state_key", "redacts")) -BackfillStreamRow = namedtuple("BackfillStreamRow", - ("event_id", "room_id", "type", "state_key", "redacts")) -PresenceStreamRow = namedtuple("PresenceStreamRow", - ("user_id", "state", "last_active_ts", - "last_federation_update_ts", "last_user_sync_ts", - "status_msg", "currently_active")) -TypingStreamRow = namedtuple("TypingStreamRow", - ("room_id", "user_ids")) -ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", - ("room_id", "receipt_type", "user_id", "event_id", - "data")) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) -PushersStreamRow = namedtuple("PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted",)) -CachesStreamRow = namedtuple("CachesStreamRow", - ("cache_func", "keys", "invalidation_ts",)) -PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", - ("room_id", "visibility", "appservice_id", - "network_id",)) -DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ("user_id", "destination",)) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) -FederationStreamRow = namedtuple("FederationStreamRow", ("type", "data",)) -TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", - ("user_id", "room_id", "data")) -AccountDataStreamRow = namedtuple("AccountDataStream", - ("user_id", "room_id", "data_type", "data")) +EventStreamRow = namedtuple("EventStreamRow", ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional +)) +BackfillStreamRow = namedtuple("BackfillStreamRow", ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional +)) +PresenceStreamRow = namedtuple("PresenceStreamRow", ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool +)) +TypingStreamRow = namedtuple("TypingStreamRow", ( + "room_id", # str + "user_ids", # list(str) +)) +ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict +)) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ( + "user_id", # str +)) +PushersStreamRow = namedtuple("PushersStreamRow", ( + "user_id", # str + "app_id", # str + "pushkey", # str + "deleted", # bool +)) +CachesStreamRow = namedtuple("CachesStreamRow", ( + "cache_func", # str + "keys", # list(str) + "invalidation_ts", # int +)) +PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional +)) +DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( + "user_id", # str + "destination", # str +)) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( + "entity", # str +)) +FederationStreamRow = namedtuple("FederationStreamRow", ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow +)) +TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( + "user_id", # str + "room_id", # str + "data", # dict +)) +AccountDataStreamRow = namedtuple("AccountDataStream", ( + "user_id", # str + "room_id", # str + "data_type", # str + "data", # dict +)) class Stream(object): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index fb23f6f462..acd69944c4 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,7 +14,7 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches import intern_string from synapse.storage.engines import PostgresEngine @@ -69,17 +69,33 @@ class StateStore(SQLBaseStore): where_clause="type='m.room.member'", ) - @cachedInlineCallbacks(max_entries=100000, iterable=True) + @cached(max_entries=100000, iterable=True) def get_current_state_ids(self, room_id): - rows = yield self._simple_select_list( - table="current_state_events", - keyvalues={"room_id": room_id}, - retcols=["event_id", "type", "state_key"], - desc="_calculate_state_delta", + """Get the current state event ids for a room based on the + current_state_events table. + + Args: + room_id (str) + + Returns: + deferred: dict of (type, state_key) -> event_id + """ + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,) + ) + + return { + (r[0], r[1]): r[2] for r in txn + } + + return self.runInteraction( + "get_current_state_ids", + _get_current_state_ids_txn, ) - defer.returnValue({ - (r["type"], r["state_key"]): r["event_id"] for r in rows - }) @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids):