diff --git a/changelog.d/13031.feature b/changelog.d/13031.feature new file mode 100644 index 0000000000..fee8e9d1ff --- /dev/null +++ b/changelog.d/13031.feature @@ -0,0 +1 @@ +Implement [MSC3827](https://github.com/matrix-org/matrix-spec-proposals/pull/3827): Filtering of /publicRooms by room type. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index e1d31cabed..2653764119 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -259,3 +259,13 @@ class ReceiptTypes: READ: Final = "m.read" READ_PRIVATE: Final = "org.matrix.msc2285.read.private" FULLY_READ: Final = "m.fully_read" + + +class PublicRoomsFilterFields: + """Fields in the search filter for `/publicRooms` that we understand. + + As defined in https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3publicrooms + """ + + GENERIC_SEARCH_TERM: Final = "generic_search_term" + ROOM_TYPES: Final = "org.matrix.msc3827.room_types" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 0a285dba31..ee443cea00 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -87,3 +87,6 @@ class ExperimentalConfig(Config): # MSC3715: dir param on /relations. self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False) + + # MSC3827: Filtering of /publicRooms by room type + self.msc3827_enabled: bool = experimental.get("msc3827_enabled", False) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 183d4ae3c4..29868eb743 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -25,6 +25,7 @@ from synapse.api.constants import ( GuestAccess, HistoryVisibility, JoinRules, + PublicRoomsFilterFields, ) from synapse.api.errors import ( Codes, @@ -181,6 +182,7 @@ class RoomListHandler: == HistoryVisibility.WORLD_READABLE, "guest_can_join": room["guest_access"] == "can_join", "join_rule": room["join_rules"], + "org.matrix.msc3827.room_type": room["room_type"], } # Filter out Nones – rather omit the field altogether @@ -239,7 +241,9 @@ class RoomListHandler: response["chunk"] = results response["total_room_count_estimate"] = await self.store.count_public_rooms( - network_tuple, ignore_non_federatable=from_federation + network_tuple, + ignore_non_federatable=from_federation, + search_filter=search_filter, ) return response @@ -508,8 +512,21 @@ class RoomListNextBatch: def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool: - if search_filter and search_filter.get("generic_search_term", None): - generic_search_term = search_filter["generic_search_term"].upper() + """Determines whether the given search filter matches a room entry returned over + federation. + + Only used if the remote server does not support MSC2197 remote-filtered search, and + hence does not support MSC3827 filtering of `/publicRooms` by room type either. + + In this case, we cannot apply the `room_type` filter since no `room_type` field is + returned. + """ + if search_filter and search_filter.get( + PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None + ): + generic_search_term = search_filter[ + PublicRoomsFilterFields.GENERIC_SEARCH_TERM + ].upper() if generic_search_term in room_entry.get("name", "").upper(): return True elif generic_search_term in room_entry.get("topic", "").upper(): diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index f45e06eb0e..5c01482acf 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -271,6 +271,9 @@ class StatsHandler: room_state["is_federatable"] = ( event_content.get(EventContentFields.FEDERATE, True) is True ) + room_type = event_content.get(EventContentFields.ROOM_TYPE) + if isinstance(room_type, str): + room_state["room_type"] = room_type elif typ == EventTypes.JoinRules: room_state["join_rules"] = event_content.get("join_rule") elif typ == EventTypes.RoomHistoryVisibility: diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c1bd775fec..f4f06563dd 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -95,6 +95,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving private read receipts as per MSC2285 "org.matrix.msc2285": self.config.experimental.msc2285_enabled, + # Supports filtering of /publicRooms by room type MSC3827 + "org.matrix.msc3827": self.config.experimental.msc3827_enabled, # Adds support for importing historical messages as per MSC2716 "org.matrix.msc2716": self.config.experimental.msc2716_enabled, # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 5760d3428e..d8026e3fac 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -32,12 +32,17 @@ from typing import ( import attr -from synapse.api.constants import EventContentFields, EventTypes, JoinRules +from synapse.api.constants import ( + EventContentFields, + EventTypes, + JoinRules, + PublicRoomsFilterFields, +) from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -199,10 +204,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): desc="get_public_room_ids", ) + def _construct_room_type_where_clause( + self, room_types: Union[List[Union[str, None]], None] + ) -> Tuple[Union[str, None], List[str]]: + if not room_types or not self.config.experimental.msc3827_enabled: + return None, [] + else: + # We use None when we want get rooms without a type + is_null_clause = "" + if None in room_types: + is_null_clause = "OR room_type IS NULL" + room_types = [value for value in room_types if value is not None] + + list_clause, args = make_in_list_sql_clause( + self.database_engine, "room_type", room_types + ) + + return f"({list_clause} {is_null_clause})", args + async def count_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], ignore_non_federatable: bool, + search_filter: Optional[dict], ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. @@ -210,11 +234,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): Args: network_tuple ignore_non_federatable: If true filters out non-federatable rooms + search_filter """ def _count_public_rooms_txn(txn: LoggingTransaction) -> int: query_args = [] + room_type_clause, args = self._construct_room_type_where_clause( + search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) + if search_filter + else None + ) + room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" + query_args += args + if network_tuple: if network_tuple.appservice_id: published_sql = """ @@ -249,6 +282,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): OR join_rules = '{JoinRules.KNOCK_RESTRICTED}' OR history_visibility = 'world_readable' ) + {room_type_clause} AND joined_members > 0 """ @@ -347,8 +381,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): if ignore_non_federatable: where_clauses.append("is_federatable") - if search_filter and search_filter.get("generic_search_term", None): - search_term = "%" + search_filter["generic_search_term"] + "%" + if search_filter and search_filter.get( + PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None + ): + search_term = ( + "%" + search_filter[PublicRoomsFilterFields.GENERIC_SEARCH_TERM] + "%" + ) where_clauses.append( """ @@ -365,6 +403,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): search_term.lower(), ] + room_type_clause, args = self._construct_room_type_where_clause( + search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) + if search_filter + else None + ) + if room_type_clause: + where_clauses.append(room_type_clause) + query_args += args + where_clause = "" if where_clauses: where_clause = " AND " + " AND ".join(where_clauses) @@ -373,7 +420,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): sql = f""" SELECT room_id, name, topic, canonical_alias, joined_members, - avatar, history_visibility, guest_access, join_rules + avatar, history_visibility, guest_access, join_rules, room_type FROM ( {published_sql} ) published @@ -1166,6 +1213,7 @@ class _BackgroundUpdates: POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2" REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth" POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column" + ADD_ROOM_TYPE_COLUMN = "add_room_type_column" _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( @@ -1200,6 +1248,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_add_rooms_room_version_column, ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, + self._background_add_room_type_column, + ) + # BG updates to change the type of room_depth.min_depth self.db_pool.updates.register_background_update_handler( _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2, @@ -1569,6 +1622,69 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size + async def _background_add_room_type_column( + self, progress: JsonDict, batch_size: int + ) -> int: + """Background update to go and add room_type information to `room_stats_state` + table from `event_json` table. + """ + + last_room_id = progress.get("room_id", "") + + def _background_add_room_type_column_txn( + txn: LoggingTransaction, + ) -> bool: + sql = """ + SELECT state.room_id, json FROM event_json + INNER JOIN current_state_events AS state USING (event_id) + WHERE state.room_id > ? AND type = 'm.room.create' + ORDER BY state.room_id + LIMIT ? + """ + + txn.execute(sql, (last_room_id, batch_size)) + room_id_to_create_event_results = txn.fetchall() + + new_last_room_id = None + for room_id, event_json in room_id_to_create_event_results: + event_dict = db_to_json(event_json) + + room_type = event_dict.get("content", {}).get( + EventContentFields.ROOM_TYPE, None + ) + if isinstance(room_type, str): + self.db_pool.simple_update_txn( + txn, + table="room_stats_state", + keyvalues={"room_id": room_id}, + updatevalues={"room_type": room_type}, + ) + + new_last_room_id = room_id + + if new_last_room_id is None: + return True + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, + {"room_id": new_last_room_id}, + ) + + return False + + end = await self.db_pool.runInteraction( + "_background_add_room_type_column", + _background_add_room_type_column_txn, + ) + + if end: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN + ) + + return batch_size + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): def __init__( diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 82851ffa95..b4c652acf3 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,7 +16,7 @@ import logging from enum import Enum from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast from typing_extensions import Counter @@ -238,6 +238,7 @@ class StatsStore(StateDeltasStore): * avatar * canonical_alias * guest_access + * room_type A is_federatable key can also be included with a boolean value. @@ -263,6 +264,7 @@ class StatsStore(StateDeltasStore): "avatar", "canonical_alias", "guest_access", + "room_type", ): field = fields.get(col, sentinel) if field is not sentinel and (not isinstance(field, str) or "\0" in field): @@ -572,7 +574,7 @@ class StatsStore(StateDeltasStore): state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined] - room_state = { + room_state: Dict[str, Union[None, bool, str]] = { "join_rules": None, "history_visibility": None, "encryption": None, @@ -581,6 +583,7 @@ class StatsStore(StateDeltasStore): "avatar": None, "canonical_alias": None, "is_federatable": True, + "room_type": None, } for event in state_event_map.values(): @@ -604,6 +607,9 @@ class StatsStore(StateDeltasStore): room_state["is_federatable"] = ( event.content.get(EventContentFields.FEDERATE, True) is True ) + room_type = event.content.get(EventContentFields.ROOM_TYPE) + if isinstance(room_type, str): + room_state["room_type"] = room_type await self.update_room_state(room_id, room_state) diff --git a/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql new file mode 100644 index 0000000000..d5e0765471 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql @@ -0,0 +1,19 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_stats_state ADD room_type TEXT; + +INSERT INTO background_updates (update_name, progress_json) + VALUES ('add_room_type_column', '{}'); diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 35c59ee9e0..1ccd96a207 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,7 +18,7 @@ """Tests REST events for /rooms paths.""" import json -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from unittest.mock import Mock, call from urllib import parse as urlparse @@ -33,7 +33,9 @@ from synapse.api.constants import ( EventContentFields, EventTypes, Membership, + PublicRoomsFilterFields, RelationTypes, + RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus @@ -1858,6 +1860,90 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) +class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + + config = self.default_config() + config["allow_public_rooms_without_auth"] = True + config["experimental_features"] = {"msc3827_enabled": True} + self.hs = self.setup_test_homeserver(config=config) + self.url = b"/_matrix/client/r0/publicRooms" + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + user = self.register_user("alice", "pass") + self.token = self.login(user, "pass") + + # Create a room + self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=self.token, + ) + # Create a space + self.helper.create_room_as( + user, + is_public=True, + extra_content={ + "visibility": "public", + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}, + }, + tok=self.token, + ) + + def make_public_rooms_request( + self, room_types: Union[List[Union[str, None]], None] + ) -> Tuple[List[Dict[str, Any]], int]: + channel = self.make_request( + "POST", + self.url, + {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}, + self.token, + ) + chunk = channel.json_body["chunk"] + count = channel.json_body["total_room_count_estimate"] + + self.assertEqual(len(chunk), count) + + return chunk, count + + def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None: + chunk, count = self.make_public_rooms_request(None) + + self.assertEqual(count, 2) + + def test_returns_only_rooms_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request([None]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), None) + + def test_returns_only_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space"]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), "m.space") + + def test_returns_both_rooms_and_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space", None]) + + self.assertEqual(count, 2) + + def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None: + chunk, count = self.make_public_rooms_request([]) + + self.assertEqual(count, 2) + + class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): """Test that we correctly fallback to local filtering if a remote server doesn't support search. @@ -1882,7 +1968,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): "Simple test for searching rooms over federation" self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1911,7 +1997,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): make_awaitable({}), ) - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 9abd0cb446..1edb619630 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + +from synapse.api.constants import RoomTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.storage.databases.main.room import _BackgroundUpdates @@ -91,3 +94,69 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ) ) self.assertEqual(room_creator_after, self.user_id) + + def test_background_add_room_type_column(self): + """Test that the background update to populate the `room_type` column in + `room_stats_state` works properly. + """ + + # Create a room without a type + room_id = self._generate_room() + + # Get event_id of the m.room.create event + event_id = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="current_state_events", + keyvalues={ + "room_id": room_id, + "type": "m.room.create", + }, + retcol="event_id", + ) + ) + + # Fake a room creation event with a room type + event = { + "content": { + "creator": "@user:server.org", + "room_version": "9", + "type": RoomTypes.SPACE, + }, + "type": "m.room.create", + } + self.get_success( + self.store.db_pool.simple_update( + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": json.dumps(event)}, + desc="test", + ) + ) + + # Insert and run the background update + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, + "progress_json": "{}", + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() + + # Make sure the background update filled in the room type + room_type_after = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="room_stats_state", + keyvalues={"room_id": room_id}, + retcol="room_type", + allow_none=True, + ) + ) + self.assertEqual(room_type_after, RoomTypes.SPACE)