Refactor `filter_events_for_server` (#15240)
* Tweak docstring and type hint * Flip logic and provide better name * Separate decision from action * Track a set of strings, not EventBases * Require explicit boolean options from callers * Add explicit option for partial state rooms * Changelog * Rename param
This commit is contained in:
parent
e157c63f68
commit
4bb26c95a9
|
@ -0,0 +1 @@
|
||||||
|
Refactor `filter_events_for_server`.
|
|
@ -547,6 +547,8 @@ class PerDestinationQueue:
|
||||||
self._server_name,
|
self._server_name,
|
||||||
new_pdus,
|
new_pdus,
|
||||||
redact=False,
|
redact=False,
|
||||||
|
filter_out_erased_senders=True,
|
||||||
|
filter_out_remote_partial_state_events=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If we've filtered out all the extremities, fall back to
|
# If we've filtered out all the extremities, fall back to
|
||||||
|
|
|
@ -392,7 +392,7 @@ class FederationHandler:
|
||||||
get_prev_content=False,
|
get_prev_content=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We set `check_history_visibility_only` as we might otherwise get false
|
# We unset `filter_out_erased_senders` as we might otherwise get false
|
||||||
# positives from users having been erased.
|
# positives from users having been erased.
|
||||||
filtered_extremities = await filter_events_for_server(
|
filtered_extremities = await filter_events_for_server(
|
||||||
self._storage_controllers,
|
self._storage_controllers,
|
||||||
|
@ -400,7 +400,8 @@ class FederationHandler:
|
||||||
self.server_name,
|
self.server_name,
|
||||||
events_to_check,
|
events_to_check,
|
||||||
redact=False,
|
redact=False,
|
||||||
check_history_visibility_only=True,
|
filter_out_erased_senders=False,
|
||||||
|
filter_out_remote_partial_state_events=False,
|
||||||
)
|
)
|
||||||
if filtered_extremities:
|
if filtered_extremities:
|
||||||
extremities_to_request.append(bp.event_id)
|
extremities_to_request.append(bp.event_id)
|
||||||
|
@ -1331,7 +1332,13 @@ class FederationHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
events = await filter_events_for_server(
|
events = await filter_events_for_server(
|
||||||
self._storage_controllers, origin, self.server_name, events
|
self._storage_controllers,
|
||||||
|
origin,
|
||||||
|
self.server_name,
|
||||||
|
events,
|
||||||
|
redact=True,
|
||||||
|
filter_out_erased_senders=True,
|
||||||
|
filter_out_remote_partial_state_events=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return events
|
return events
|
||||||
|
@ -1362,7 +1369,13 @@ class FederationHandler:
|
||||||
await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
|
await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
|
||||||
|
|
||||||
events = await filter_events_for_server(
|
events = await filter_events_for_server(
|
||||||
self._storage_controllers, origin, self.server_name, [event]
|
self._storage_controllers,
|
||||||
|
origin,
|
||||||
|
self.server_name,
|
||||||
|
[event],
|
||||||
|
redact=True,
|
||||||
|
filter_out_erased_senders=True,
|
||||||
|
filter_out_remote_partial_state_events=True,
|
||||||
)
|
)
|
||||||
event = events[0]
|
event = events[0]
|
||||||
return event
|
return event
|
||||||
|
@ -1390,7 +1403,13 @@ class FederationHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
missing_events = await filter_events_for_server(
|
missing_events = await filter_events_for_server(
|
||||||
self._storage_controllers, origin, self.server_name, missing_events
|
self._storage_controllers,
|
||||||
|
origin,
|
||||||
|
self.server_name,
|
||||||
|
missing_events,
|
||||||
|
redact=True,
|
||||||
|
filter_out_erased_senders=True,
|
||||||
|
filter_out_remote_partial_state_events=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return missing_events
|
return missing_events
|
||||||
|
|
|
@ -14,7 +14,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Collection, Dict, FrozenSet, List, Optional, Tuple
|
from typing import (
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
FrozenSet,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import Final
|
from typing_extensions import Final
|
||||||
|
@ -565,29 +575,43 @@ async def filter_events_for_server(
|
||||||
storage: StorageControllers,
|
storage: StorageControllers,
|
||||||
target_server_name: str,
|
target_server_name: str,
|
||||||
local_server_name: str,
|
local_server_name: str,
|
||||||
events: List[EventBase],
|
events: Sequence[EventBase],
|
||||||
redact: bool = True,
|
*,
|
||||||
check_history_visibility_only: bool = False,
|
redact: bool,
|
||||||
|
filter_out_erased_senders: bool,
|
||||||
|
filter_out_remote_partial_state_events: bool,
|
||||||
) -> List[EventBase]:
|
) -> List[EventBase]:
|
||||||
"""Filter a list of events based on whether given server is allowed to
|
"""Filter a list of events based on whether the target server is allowed to
|
||||||
see them.
|
see them.
|
||||||
|
|
||||||
|
For a fully stated room, the target server is allowed to see an event E if:
|
||||||
|
- the state at E has world readable or shared history vis, OR
|
||||||
|
- the state at E says that the target server is in the room.
|
||||||
|
|
||||||
|
For a partially stated room, the target server is allowed to see E if:
|
||||||
|
- E was created by this homeserver, AND:
|
||||||
|
- the partial state at E has world readable or shared history vis, OR
|
||||||
|
- the partial state at E says that the target server is in the room.
|
||||||
|
|
||||||
|
TODO: state before or state after?
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
storage
|
storage
|
||||||
server_name
|
target_server_name
|
||||||
|
local_server_name
|
||||||
events
|
events
|
||||||
redact: Whether to return a redacted version of the event, or
|
redact: Controls what to do with events which have been filtered out.
|
||||||
to filter them out entirely.
|
If True, include their redacted forms; if False, omit them entirely.
|
||||||
check_history_visibility_only: Whether to only check the
|
filter_out_erased_senders: If true, also filter out events whose sender has been
|
||||||
history visibility, rather than things like if the sender has been
|
|
||||||
erased. This is used e.g. during pagination to decide whether to
|
erased. This is used e.g. during pagination to decide whether to
|
||||||
backfill or not.
|
backfill or not.
|
||||||
|
filter_out_remote_partial_state_events: If True, also filter out events in
|
||||||
|
partial state rooms created by other homeservers.
|
||||||
Returns
|
Returns
|
||||||
The filtered events.
|
The filtered events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
|
def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool:
|
||||||
if erased_senders and erased_senders[event.sender]:
|
if erased_senders and erased_senders[event.sender]:
|
||||||
logger.info("Sender of %s has been erased, redacting", event.event_id)
|
logger.info("Sender of %s has been erased, redacting", event.event_id)
|
||||||
return True
|
return True
|
||||||
|
@ -616,7 +640,7 @@ async def filter_events_for_server(
|
||||||
# server has no users in the room: redact
|
# server has no users in the room: redact
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not check_history_visibility_only:
|
if filter_out_erased_senders:
|
||||||
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
|
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
|
||||||
else:
|
else:
|
||||||
# We don't want to check whether users are erased, which is equivalent
|
# We don't want to check whether users are erased, which is equivalent
|
||||||
|
@ -631,15 +655,15 @@ async def filter_events_for_server(
|
||||||
# otherwise a room could be fully joined after we retrieve those, which would then bypass
|
# otherwise a room could be fully joined after we retrieve those, which would then bypass
|
||||||
# this check but would base the filtering on an outdated view of the membership events.
|
# this check but would base the filtering on an outdated view of the membership events.
|
||||||
|
|
||||||
partial_state_invisible_events = set()
|
partial_state_invisible_event_ids: Set[str] = set()
|
||||||
if not check_history_visibility_only:
|
if filter_out_remote_partial_state_events:
|
||||||
for e in events:
|
for e in events:
|
||||||
sender_domain = get_domain_from_id(e.sender)
|
sender_domain = get_domain_from_id(e.sender)
|
||||||
if (
|
if (
|
||||||
sender_domain != local_server_name
|
sender_domain != local_server_name
|
||||||
and await storage.main.is_partial_state_room(e.room_id)
|
and await storage.main.is_partial_state_room(e.room_id)
|
||||||
):
|
):
|
||||||
partial_state_invisible_events.add(e)
|
partial_state_invisible_event_ids.add(e.event_id)
|
||||||
|
|
||||||
# Let's check to see if all the events have a history visibility
|
# Let's check to see if all the events have a history visibility
|
||||||
# of "shared" or "world_readable". If that's the case then we don't
|
# of "shared" or "world_readable". If that's the case then we don't
|
||||||
|
@ -658,17 +682,20 @@ async def filter_events_for_server(
|
||||||
target_server_name,
|
target_server_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
to_return = []
|
def include_event_in_output(e: EventBase) -> bool:
|
||||||
for e in events:
|
|
||||||
erased = is_sender_erased(e, erased_senders)
|
erased = is_sender_erased(e, erased_senders)
|
||||||
visible = check_event_is_visible(
|
visible = check_event_is_visible(
|
||||||
event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
|
event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
|
||||||
)
|
)
|
||||||
|
|
||||||
if e in partial_state_invisible_events:
|
if e.event_id in partial_state_invisible_event_ids:
|
||||||
visible = False
|
visible = False
|
||||||
|
|
||||||
if visible and not erased:
|
return visible and not erased
|
||||||
|
|
||||||
|
to_return = []
|
||||||
|
for e in events:
|
||||||
|
if include_event_in_output(e):
|
||||||
to_return.append(e)
|
to_return.append(e)
|
||||||
elif redact:
|
elif redact:
|
||||||
to_return.append(prune_event(e))
|
to_return.append(prune_event(e))
|
||||||
|
|
|
@ -63,7 +63,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
filtered = self.get_success(
|
filtered = self.get_success(
|
||||||
filter_events_for_server(
|
filter_events_for_server(
|
||||||
self._storage_controllers, "test_server", "hs", events_to_filter
|
self._storage_controllers,
|
||||||
|
"test_server",
|
||||||
|
"hs",
|
||||||
|
events_to_filter,
|
||||||
|
redact=True,
|
||||||
|
filter_out_erased_senders=True,
|
||||||
|
filter_out_remote_partial_state_events=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -85,7 +91,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.get_success(
|
self.get_success(
|
||||||
filter_events_for_server(
|
filter_events_for_server(
|
||||||
self._storage_controllers, "remote_hs", "hs", [outlier]
|
self._storage_controllers,
|
||||||
|
"remote_hs",
|
||||||
|
"hs",
|
||||||
|
[outlier],
|
||||||
|
redact=True,
|
||||||
|
filter_out_erased_senders=True,
|
||||||
|
filter_out_remote_partial_state_events=True,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
[outlier],
|
[outlier],
|
||||||
|
@ -96,7 +108,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
filtered = self.get_success(
|
filtered = self.get_success(
|
||||||
filter_events_for_server(
|
filter_events_for_server(
|
||||||
self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
|
self._storage_controllers,
|
||||||
|
"remote_hs",
|
||||||
|
"local_hs",
|
||||||
|
[outlier, evt],
|
||||||
|
redact=True,
|
||||||
|
filter_out_erased_senders=True,
|
||||||
|
filter_out_remote_partial_state_events=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
|
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
|
||||||
|
@ -108,7 +126,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
# be redacted)
|
# be redacted)
|
||||||
filtered = self.get_success(
|
filtered = self.get_success(
|
||||||
filter_events_for_server(
|
filter_events_for_server(
|
||||||
self._storage_controllers, "other_server", "local_hs", [outlier, evt]
|
self._storage_controllers,
|
||||||
|
"other_server",
|
||||||
|
"local_hs",
|
||||||
|
[outlier, evt],
|
||||||
|
redact=True,
|
||||||
|
filter_out_erased_senders=True,
|
||||||
|
filter_out_remote_partial_state_events=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(filtered[0], outlier)
|
self.assertEqual(filtered[0], outlier)
|
||||||
|
@ -143,7 +167,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
# ... and the filtering happens.
|
# ... and the filtering happens.
|
||||||
filtered = self.get_success(
|
filtered = self.get_success(
|
||||||
filter_events_for_server(
|
filter_events_for_server(
|
||||||
self._storage_controllers, "test_server", "local_hs", events_to_filter
|
self._storage_controllers,
|
||||||
|
"test_server",
|
||||||
|
"local_hs",
|
||||||
|
events_to_filter,
|
||||||
|
redact=True,
|
||||||
|
filter_out_erased_senders=True,
|
||||||
|
filter_out_remote_partial_state_events=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue