Make `backfill` and `get_missing_events` use the same codepath (#10645)

Given that backfill and get_missing_events are basically the same thing, it's somewhat crazy that we have entirely separate code paths for them. This makes backfill use the existing get_missing_events code, and then clears out all the now-unused code.
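For orientation, the new `backfill` body reduces to roughly the following shape. This is a condensed sketch assembled from the diff below: the initial `federation_client.backfill` fetch line is paraphrased from context, while the room-id check and the `_process_pulled_events` call are taken from the diff itself.

    async def backfill(
        self, dest: str, room_id: str, limit: int, extremities: List[str]
    ) -> None:
        # Fetch a batch of events from the remote server (paraphrased).
        events = await self.federation_client.backfill(
            dest, room_id, limit=limit, extremities=extremities
        )
        if not events:
            return

        # If there are any events in the wrong room, the remote server is
        # buggy and should not be trusted.
        for ev in events:
            if ev.room_id != room_id:
                raise InvalidResponseError(
                    f"Remote server {dest} returned event {ev.event_id} which is in "
                    f"room {ev.room_id}, when we were backfilling in {room_id}"
                )

        # Everything else goes through the same codepath that
        # get_missing_events already uses.
        await self._process_pulled_events(dest, events, backfilled=True)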
Richard van der Hoff 2021-08-26 18:34:57 +01:00 committed by GitHub
parent 40f619eaa5
commit 96715d7633
3 changed files with 42 additions and 233 deletions

changelog.d/10645.misc (new file)

@@ -0,0 +1 @@
+Make `backfill` and `get_missing_events` use the same codepath.

synapse/handlers/federation.py

@@ -65,6 +65,7 @@ from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
+from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers._base import BaseHandler
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
@@ -116,10 +117,6 @@ class _NewEventInfo:
     Attributes:
         event: the received event

-        state: the state at that event, according to /state_ids from a remote
-            homeserver. Only populated for backfilled events which are going to be a
-            new backwards extremity.
-
         claimed_auth_event_map: a map of (type, state_key) => event for the event's
             claimed auth_events.
@@ -134,7 +131,6 @@ class _NewEventInfo:
     """

     event: EventBase
-    state: Optional[Sequence[EventBase]]
     claimed_auth_event_map: StateMap[EventBase]
@@ -443,113 +439,7 @@ class FederationHandler(BaseHandler):
             return

         logger.info("Got %d prev_events", len(missing_events))
-        await self._process_pulled_events(origin, missing_events)
+        await self._process_pulled_events(origin, missing_events, backfilled=False)

-    async def _get_state_for_room(
-        self,
-        destination: str,
-        room_id: str,
-        event_id: str,
-    ) -> List[EventBase]:
-        """Requests all of the room state at a given event from a remote
-        homeserver.
-
-        Will also fetch any missing events reported in the `auth_chain_ids`
-        section of `/state_ids`.
-
-        Args:
-            destination: The remote homeserver to query for the state.
-            room_id: The id of the room we're interested in.
-            event_id: The id of the event we want the state at.
-
-        Returns:
-            A list of events in the state, not including the event itself.
-        """
-        (
-            state_event_ids,
-            auth_event_ids,
-        ) = await self.federation_client.get_room_state_ids(
-            destination, room_id, event_id=event_id
-        )
-
-        # Fetch the state events from the DB, and check we have the auth events.
-        event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
-        auth_events_in_store = await self.store.have_seen_events(
-            room_id, auth_event_ids
-        )
-
-        # Check for missing events. We handle state and auth event seperately,
-        # as we want to pull the state from the DB, but we don't for the auth
-        # events. (Note: we likely won't use the majority of the auth chain, and
-        # it can be *huge* for large rooms, so it's worth ensuring that we don't
-        # unnecessarily pull it from the DB).
-        missing_state_events = set(state_event_ids) - set(event_map)
-        missing_auth_events = set(auth_event_ids) - set(auth_events_in_store)
-        if missing_state_events or missing_auth_events:
-            await self._get_events_and_persist(
-                destination=destination,
-                room_id=room_id,
-                events=missing_state_events | missing_auth_events,
-            )
-
-            if missing_state_events:
-                new_events = await self.store.get_events(
-                    missing_state_events, allow_rejected=True
-                )
-                event_map.update(new_events)
-
-                missing_state_events.difference_update(new_events)
-
-                if missing_state_events:
-                    logger.warning(
-                        "Failed to fetch missing state events for %s %s",
-                        event_id,
-                        missing_state_events,
-                    )
-
-            if missing_auth_events:
-                auth_events_in_store = await self.store.have_seen_events(
-                    room_id, missing_auth_events
-                )
-                missing_auth_events.difference_update(auth_events_in_store)
-
-                if missing_auth_events:
-                    logger.warning(
-                        "Failed to fetch missing auth events for %s %s",
-                        event_id,
-                        missing_auth_events,
-                    )
-
-        remote_state = list(event_map.values())
-
-        # check for events which were in the wrong room.
-        #
-        # this can happen if a remote server claims that the state or
-        # auth_events at an event in room A are actually events in room B
-        bad_events = [
-            (event.event_id, event.room_id)
-            for event in remote_state
-            if event.room_id != room_id
-        ]
-
-        for bad_event_id, bad_room_id in bad_events:
-            # This is a bogus situation, but since we may only discover it a long time
-            # after it happened, we try our best to carry on, by just omitting the
-            # bad events from the returned auth/state set.
-            logger.warning(
-                "Remote server %s claims event %s in room %s is an auth/state "
-                "event in room %s",
-                destination,
-                bad_event_id,
-                bad_room_id,
-                room_id,
-            )
-
-        if bad_events:
-            remote_state = [e for e in remote_state if e.room_id == room_id]
-
-        return remote_state
-
     async def _get_state_after_missing_prev_event(
         self,
@@ -567,10 +457,6 @@ class FederationHandler(BaseHandler):
         Returns:
             A list of events in the state, including the event itself
         """
-        # TODO: This function is basically the same as _get_state_for_room. Can
-        #   we make backfill() use it, rather than having two code paths? I think the
-        #   only difference is that backfill() persists the prev events separately.
         (
             state_event_ids,
             auth_event_ids,
@@ -681,6 +567,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
+        backfilled: bool = False,
     ) -> None:
         """Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
@@ -693,6 +580,9 @@ class FederationHandler(BaseHandler):
             state: Normally None, but if we are handling a gap in the graph
                 (ie, we are missing one or more prev_events), the resolved state at the
                 event
+            backfilled: True if this is part of a historical batch of events (inhibits
+                notification to clients, and validation of device keys.)
+
         """
         logger.debug("Processing event: %s", event)
@@ -700,10 +590,15 @@ class FederationHandler(BaseHandler):
             context = await self.state_handler.compute_event_context(
                 event, old_state=state
             )
-            await self._auth_and_persist_event(origin, event, context, state=state)
+            await self._auth_and_persist_event(
+                origin, event, context, state=state, backfilled=backfilled
+            )
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)

+        if backfilled:
+            return
+
         # For encrypted messages we check that we know about the sending device,
         # if we don't then we mark the device cache for that user as stale.
         if event.type == EventTypes.Encrypted:
@@ -868,7 +763,7 @@ class FederationHandler(BaseHandler):
     @log_function
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: List[str]
-    ) -> List[EventBase]:
+    ) -> None:
         """Trigger a backfill request to `dest` for the given `room_id`

         This will attempt to get more events from the remote. If the other side
@@ -878,6 +773,9 @@ class FederationHandler(BaseHandler):
         sanity-checking on them. If any of the backfilled events are invalid,
         this method throws a SynapseError.

+        We might also raise an InvalidResponseError if the response from the remote
+        server is just bogus.
+
         TODO: make this more useful to distinguish failures of the remote
         server from invalid events (there is probably no point in trying to
         re-fetch invalid events from every other HS in the room.)
@@ -890,111 +788,18 @@ class FederationHandler(BaseHandler):
         )

         if not events:
-            return []
+            return

-        # ideally we'd sanity check the events here for excess prev_events etc,
-        # but it's hard to reject events at this point without completely
-        # breaking backfill in the same way that it is currently broken by
-        # events whose signature we cannot verify (#3121).
-        #
-        # So for now we accept the events anyway. #3124 tracks this.
-        #
-        # for ev in events:
-        #     self._sanity_check_event(ev)
-
-        # Don't bother processing events we already have.
-        seen_events = await self.store.have_events_in_timeline(
-            {e.event_id for e in events}
-        )
-        events = [e for e in events if e.event_id not in seen_events]
-        if not events:
-            return []
-
-        event_map = {e.event_id: e for e in events}
-        event_ids = {e.event_id for e in events}
-
-        # build a list of events whose prev_events weren't in the batch.
-        # (XXX: this will include events whose prev_events we already have; that doesn't
-        # sound right?)
-        edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
-
-        logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
-
-        # For each edge get the current state.
-        state_events = {}
-        events_to_state = {}
-        for e_id in edges:
-            state = await self._get_state_for_room(
-                destination=dest,
-                room_id=room_id,
-                event_id=e_id,
-            )
-            state_events.update({s.event_id: s for s in state})
-            events_to_state[e_id] = state
-
-        required_auth = {
-            a_id
-            for event in events + list(state_events.values())
-            for a_id in event.auth_event_ids()
-        }
-        auth_events = await self.store.get_events(required_auth, allow_rejected=True)
-        auth_events.update(
-            {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
-        )
-
-        ev_infos = []
-
-        # Step 1: persist the events in the chunk we fetched state for (i.e.
-        # the backwards extremities), with custom auth events and state
-        for e_id in events_to_state:
-            # For paranoia we ensure that these events are marked as
-            # non-outliers
-            ev = event_map[e_id]
-            assert not ev.internal_metadata.is_outlier()
-
-            ev_infos.append(
-                _NewEventInfo(
-                    event=ev,
-                    state=events_to_state[e_id],
-                    claimed_auth_event_map={
-                        (
-                            auth_events[a_id].type,
-                            auth_events[a_id].state_key,
-                        ): auth_events[a_id]
-                        for a_id in ev.auth_event_ids()
-                        if a_id in auth_events
-                    },
-                )
-            )
-
-        if ev_infos:
-            await self._auth_and_persist_events(
-                dest, room_id, ev_infos, backfilled=True
-            )
-
-        # Step 2: Persist the rest of the events in the chunk one by one
-        events.sort(key=lambda e: e.depth)
-
-        for event in events:
-            if event in events_to_state:
-                continue
-
-            # For paranoia we ensure that these events are marked as
-            # non-outliers
-            assert not event.internal_metadata.is_outlier()
-
-            context = await self.state_handler.compute_event_context(event)
-
-            # We store these one at a time since each event depends on the
-            # previous to work out the state.
-            # TODO: We can probably do something more clever here.
-            await self._auth_and_persist_event(dest, event, context, backfilled=True)
-
-        return events
+        # if there are any events in the wrong room, the remote server is buggy and
+        # should not be trusted.
+        for ev in events:
+            if ev.room_id != room_id:
+                raise InvalidResponseError(
+                    f"Remote server {dest} returned event {ev.event_id} which is in "
+                    f"room {ev.room_id}, when we were backfilling in {room_id}"
+                )
+
+        await self._process_pulled_events(dest, events, backfilled=True)

     async def maybe_backfill(
         self, room_id: str, current_depth: int, limit: int
@@ -1197,7 +1002,7 @@ class FederationHandler(BaseHandler):
                 # appropriate stuff.
                 # TODO: We can probably do something more intelligent here.
                 return True
-            except SynapseError as e:
+            except (SynapseError, InvalidResponseError) as e:
                 logger.info("Failed to backfill from %s because %s", dom, e)
                 continue
             except HttpResponseException as e:
@@ -1351,7 +1156,7 @@ class FederationHandler(BaseHandler):
             else:
                 logger.info("Missing auth event %s", auth_event_id)

-            event_infos.append(_NewEventInfo(event, None, auth))
+            event_infos.append(_NewEventInfo(event, auth))

         if event_infos:
             await self._auth_and_persist_events(
@@ -1361,7 +1166,7 @@ class FederationHandler(BaseHandler):
             )

     async def _process_pulled_events(
-        self, origin: str, events: Iterable[EventBase]
+        self, origin: str, events: Iterable[EventBase], backfilled: bool
     ) -> None:
         """Process a batch of events we have pulled from a remote server
@@ -1373,6 +1178,8 @@ class FederationHandler(BaseHandler):
         Params:
             origin: The server we received these events from
             events: The received events.
+            backfilled: True if this is part of a historical batch of events (inhibits
+                notification to clients, and validation of device keys.)
         """

         # We want to sort these by depth so we process them and
@@ -1381,9 +1188,11 @@ class FederationHandler(BaseHandler):
         for ev in sorted_events:
             with nested_logging_context(ev.event_id):
-                await self._process_pulled_event(origin, ev)
+                await self._process_pulled_event(origin, ev, backfilled=backfilled)

-    async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
+    async def _process_pulled_event(
+        self, origin: str, event: EventBase, backfilled: bool
+    ) -> None:
         """Process a single event that we have pulled from a remote server

         Pulls in any events required to auth the event, persists the received event,
@@ -1400,6 +1209,8 @@ class FederationHandler(BaseHandler):
         Params:
             origin: The server we received this event from
             events: The received event
+            backfilled: True if this is part of a historical batch of events (inhibits
+                notification to clients, and validation of device keys.)
         """

         logger.info("Processing pulled event %s", event)
@@ -1428,7 +1239,9 @@ class FederationHandler(BaseHandler):
         try:
             state = await self._resolve_state_at_missing_prevs(origin, event)
-            await self._process_received_pdu(origin, event, state=state)
+            await self._process_received_pdu(
+                origin, event, state=state, backfilled=backfilled
+            )
         except FederationError as e:
             if e.code == 403:
                 logger.warning("Pulled event %s failed history check.", event_id)
@@ -2451,7 +2264,6 @@ class FederationHandler(BaseHandler):
         origin: str,
         room_id: str,
         event_infos: Collection[_NewEventInfo],
-        backfilled: bool = False,
     ) -> None:
         """Creates the appropriate contexts and persists events. The events
         should not depend on one another, e.g. this should be used to persist
@@ -2467,16 +2279,12 @@ class FederationHandler(BaseHandler):
         async def prep(ev_info: _NewEventInfo):
             event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
-                res = await self.state_handler.compute_event_context(
-                    event, old_state=ev_info.state
-                )
+                res = await self.state_handler.compute_event_context(event)
                 res = await self._check_event_auth(
                     origin,
                     event,
                     res,
-                    state=ev_info.state,
                     claimed_auth_event_map=ev_info.claimed_auth_event_map,
-                    backfilled=backfilled,
                 )
             return res
@@ -2493,7 +2301,6 @@ class FederationHandler(BaseHandler):
                     (ev_info.event, context)
                     for ev_info, context in zip(event_infos, contexts)
                 ],
-                backfilled=backfilled,
             )

     async def _persist_auth_tree(
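On the caller side, `backfill` now returns `None` and signals a bogus response by raising `InvalidResponseError`, so a caller trying successive servers catches it alongside `SynapseError`, as the `maybe_backfill` hunk above shows. A minimal sketch of that pattern (the surrounding `likely_domains` loop is abridged from context, not quoted verbatim):

    for dom in likely_domains:
        try:
            # backfill() persists whatever it fetched; nothing is returned.
            await self.backfill(dom, room_id, limit=100, extremities=extremities)
            return True
        except (SynapseError, InvalidResponseError) as e:
            # Invalid events, or a bogus response from the remote:
            # try the next candidate server.
            logger.info("Failed to backfill from %s because %s", dom, e)
            continue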

synapse/storage/databases/main/purge_events.py

@@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             self._invalidate_cache_and_stream(
                 txn, self.have_seen_event, (room_id, event_id)
             )
+            self._invalidate_get_event_cache(event_id)

         logger.info("[purge] done")