Make `backfill` and `get_missing_events` use the same codepath (#10645)

Given that backfill and get_missing_events are basically the same thing, it's somewhat crazy that we have entirely separate code paths for them. This makes backfill use the existing get_missing_events code, and then clears up all the unused code.
This commit is contained in:
Richard van der Hoff 2021-08-26 18:34:57 +01:00 committed by GitHub
parent 40f619eaa5
commit 96715d7633
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 233 deletions

1
changelog.d/10645.misc Normal file
View File

@ -0,0 +1 @@
Make `backfill` and `get_missing_events` use the same codepath.

View File

@ -65,6 +65,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers._base import BaseHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
@ -116,10 +117,6 @@ class _NewEventInfo:
Attributes:
event: the received event
state: the state at that event, according to /state_ids from a remote
homeserver. Only populated for backfilled events which are going to be a
new backwards extremity.
claimed_auth_event_map: a map of (type, state_key) => event for the event's
claimed auth_events.
@ -134,7 +131,6 @@ class _NewEventInfo:
"""
event: EventBase
state: Optional[Sequence[EventBase]]
claimed_auth_event_map: StateMap[EventBase]
@ -443,113 +439,7 @@ class FederationHandler(BaseHandler):
return
logger.info("Got %d prev_events", len(missing_events))
await self._process_pulled_events(origin, missing_events)
async def _get_state_for_room(
self,
destination: str,
room_id: str,
event_id: str,
) -> List[EventBase]:
"""Requests all of the room state at a given event from a remote
homeserver.
Will also fetch any missing events reported in the `auth_chain_ids`
section of `/state_ids`.
Args:
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
Returns:
A list of events in the state, not including the event itself.
"""
(
state_event_ids,
auth_event_ids,
) = await self.federation_client.get_room_state_ids(
destination, room_id, event_id=event_id
)
# Fetch the state events from the DB, and check we have the auth events.
event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
auth_events_in_store = await self.store.have_seen_events(
room_id, auth_event_ids
)
# Check for missing events. We handle state and auth event seperately,
# as we want to pull the state from the DB, but we don't for the auth
# events. (Note: we likely won't use the majority of the auth chain, and
# it can be *huge* for large rooms, so it's worth ensuring that we don't
# unnecessarily pull it from the DB).
missing_state_events = set(state_event_ids) - set(event_map)
missing_auth_events = set(auth_event_ids) - set(auth_events_in_store)
if missing_state_events or missing_auth_events:
await self._get_events_and_persist(
destination=destination,
room_id=room_id,
events=missing_state_events | missing_auth_events,
)
if missing_state_events:
new_events = await self.store.get_events(
missing_state_events, allow_rejected=True
)
event_map.update(new_events)
missing_state_events.difference_update(new_events)
if missing_state_events:
logger.warning(
"Failed to fetch missing state events for %s %s",
event_id,
missing_state_events,
)
if missing_auth_events:
auth_events_in_store = await self.store.have_seen_events(
room_id, missing_auth_events
)
missing_auth_events.difference_update(auth_events_in_store)
if missing_auth_events:
logger.warning(
"Failed to fetch missing auth events for %s %s",
event_id,
missing_auth_events,
)
remote_state = list(event_map.values())
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
bad_events = [
(event.event_id, event.room_id)
for event in remote_state
if event.room_id != room_id
]
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned auth/state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
)
if bad_events:
remote_state = [e for e in remote_state if e.room_id == room_id]
return remote_state
await self._process_pulled_events(origin, missing_events, backfilled=False)
async def _get_state_after_missing_prev_event(
self,
@ -567,10 +457,6 @@ class FederationHandler(BaseHandler):
Returns:
A list of events in the state, including the event itself
"""
# TODO: This function is basically the same as _get_state_for_room. Can
# we make backfill() use it, rather than having two code paths? I think the
# only difference is that backfill() persists the prev events separately.
(
state_event_ids,
auth_event_ids,
@ -681,6 +567,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
backfilled: bool = False,
) -> None:
"""Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@ -693,6 +580,9 @@ class FederationHandler(BaseHandler):
state: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)
"""
logger.debug("Processing event: %s", event)
@ -700,10 +590,15 @@ class FederationHandler(BaseHandler):
context = await self.state_handler.compute_event_context(
event, old_state=state
)
await self._auth_and_persist_event(origin, event, context, state=state)
await self._auth_and_persist_event(
origin, event, context, state=state, backfilled=backfilled
)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
if backfilled:
return
# For encrypted messages we check that we know about the sending device,
# if we don't then we mark the device cache for that user as stale.
if event.type == EventTypes.Encrypted:
@ -868,7 +763,7 @@ class FederationHandler(BaseHandler):
@log_function
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: List[str]
) -> List[EventBase]:
) -> None:
"""Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@ -878,6 +773,9 @@ class FederationHandler(BaseHandler):
sanity-checking on them. If any of the backfilled events are invalid,
this method throws a SynapseError.
We might also raise an InvalidResponseError if the response from the remote
server is just bogus.
TODO: make this more useful to distinguish failures of the remote
server from invalid events (there is probably no point in trying to
re-fetch invalid events from every other HS in the room.)
@ -890,111 +788,18 @@ class FederationHandler(BaseHandler):
)
if not events:
return []
return
# ideally we'd sanity check the events here for excess prev_events etc,
# but it's hard to reject events at this point without completely
# breaking backfill in the same way that it is currently broken by
# events whose signature we cannot verify (#3121).
#
# So for now we accept the events anyway. #3124 tracks this.
#
# for ev in events:
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
seen_events = await self.store.have_events_in_timeline(
{e.event_id for e in events}
)
events = [e for e in events if e.event_id not in seen_events]
if not events:
return []
event_map = {e.event_id: e for e in events}
event_ids = {e.event_id for e in events}
# build a list of events whose prev_events weren't in the batch.
# (XXX: this will include events whose prev_events we already have; that doesn't
# sound right?)
edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
# For each edge get the current state.
state_events = {}
events_to_state = {}
for e_id in edges:
state = await self._get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id,
)
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
required_auth = {
a_id
for event in events + list(state_events.values())
for a_id in event.auth_event_ids()
}
auth_events = await self.store.get_events(required_auth, allow_rejected=True)
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
ev_infos = []
# Step 1: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities), with custom auth events and state
for e_id in events_to_state:
# For paranoia we ensure that these events are marked as
# non-outliers
ev = event_map[e_id]
assert not ev.internal_metadata.is_outlier()
ev_infos.append(
_NewEventInfo(
event=ev,
state=events_to_state[e_id],
claimed_auth_event_map={
(
auth_events[a_id].type,
auth_events[a_id].state_key,
): auth_events[a_id]
for a_id in ev.auth_event_ids()
if a_id in auth_events
},
# if there are any events in the wrong room, the remote server is buggy and
# should not be trusted.
for ev in events:
if ev.room_id != room_id:
raise InvalidResponseError(
f"Remote server {dest} returned event {ev.event_id} which is in "
f"room {ev.room_id}, when we were backfilling in {room_id}"
)
)
if ev_infos:
await self._auth_and_persist_events(
dest, room_id, ev_infos, backfilled=True
)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
for event in events:
if event in events_to_state:
continue
# For paranoia we ensure that these events are marked as
# non-outliers
assert not event.internal_metadata.is_outlier()
context = await self.state_handler.compute_event_context(event)
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
await self._auth_and_persist_event(dest, event, context, backfilled=True)
return events
await self._process_pulled_events(dest, events, backfilled=True)
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int
@ -1197,7 +1002,7 @@ class FederationHandler(BaseHandler):
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
return True
except SynapseError as e:
except (SynapseError, InvalidResponseError) as e:
logger.info("Failed to backfill from %s because %s", dom, e)
continue
except HttpResponseException as e:
@ -1351,7 +1156,7 @@ class FederationHandler(BaseHandler):
else:
logger.info("Missing auth event %s", auth_event_id)
event_infos.append(_NewEventInfo(event, None, auth))
event_infos.append(_NewEventInfo(event, auth))
if event_infos:
await self._auth_and_persist_events(
@ -1361,7 +1166,7 @@ class FederationHandler(BaseHandler):
)
async def _process_pulled_events(
self, origin: str, events: Iterable[EventBase]
self, origin: str, events: Iterable[EventBase], backfilled: bool
) -> None:
"""Process a batch of events we have pulled from a remote server
@ -1373,6 +1178,8 @@ class FederationHandler(BaseHandler):
Params:
origin: The server we received these events from
events: The received events.
backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)
"""
# We want to sort these by depth so we process them and
@ -1381,9 +1188,11 @@ class FederationHandler(BaseHandler):
for ev in sorted_events:
with nested_logging_context(ev.event_id):
await self._process_pulled_event(origin, ev)
await self._process_pulled_event(origin, ev, backfilled=backfilled)
async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
async def _process_pulled_event(
self, origin: str, event: EventBase, backfilled: bool
) -> None:
"""Process a single event that we have pulled from a remote server
Pulls in any events required to auth the event, persists the received event,
@ -1400,6 +1209,8 @@ class FederationHandler(BaseHandler):
Params:
origin: The server we received this event from
events: The received event
backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)
"""
logger.info("Processing pulled event %s", event)
@ -1428,7 +1239,9 @@ class FederationHandler(BaseHandler):
try:
state = await self._resolve_state_at_missing_prevs(origin, event)
await self._process_received_pdu(origin, event, state=state)
await self._process_received_pdu(
origin, event, state=state, backfilled=backfilled
)
except FederationError as e:
if e.code == 403:
logger.warning("Pulled event %s failed history check.", event_id)
@ -2451,7 +2264,6 @@ class FederationHandler(BaseHandler):
origin: str,
room_id: str,
event_infos: Collection[_NewEventInfo],
backfilled: bool = False,
) -> None:
"""Creates the appropriate contexts and persists events. The events
should not depend on one another, e.g. this should be used to persist
@ -2467,16 +2279,12 @@ class FederationHandler(BaseHandler):
async def prep(ev_info: _NewEventInfo):
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = await self.state_handler.compute_event_context(
event, old_state=ev_info.state
)
res = await self.state_handler.compute_event_context(event)
res = await self._check_event_auth(
origin,
event,
res,
state=ev_info.state,
claimed_auth_event_map=ev_info.claimed_auth_event_map,
backfilled=backfilled,
)
return res
@ -2493,7 +2301,6 @@ class FederationHandler(BaseHandler):
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
],
backfilled=backfilled,
)
async def _persist_auth_tree(

View File

@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
self._invalidate_get_event_cache(event_id)
logger.info("[purge] done")